Skip to content

Motifscan

chromatinhd.data.motifscan.Motifscan

Bases: Flow

A sparse representation of locations of different motifs in regions of the genome

Source code in src/chromatinhd/data/motifscan/motifscan.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
class Motifscan(Flow):
    """
    A sparse representation of locations of different motifs in regions of the genome

    Each detected motif site is described by one entry in each of the parallel
    stores `coordinates`, `region_indices`, `indices`, `scores` and `strands`.
    `region_indptr` and `indptr` are index pointers into these stores that are
    created on demand by `create_region_indptr` and `create_indptr`.
    """

    regions = Linked()
    "The regions"

    indptr: TensorstoreInstance = Tensorstore(dtype="<i8", chunks=(10000,))
    "The index pointers for each position in the regions"

    region_indptr: TensorstoreInstance = Tensorstore(dtype="<i8", chunks=(1000,), compression=None)
    "The index pointers for region"

    coordinates: TensorstoreInstance = Tensorstore(dtype="<i4", chunks=(10000,))
    "Coordinate associated to each site"

    region_indices: TensorstoreInstance = Tensorstore(dtype="<i4", chunks=(10000,))
    "Region index associated to each site"

    indices: TensorstoreInstance = Tensorstore(dtype="<i4", chunks=(10000,))
    "Motif index associated to each site"

    # positions: TensorstoreInstance = Tensorstore(dtype="<i8", chunks=(10000,))
    # "Cumulative coordinate of each site"

    scores: TensorstoreInstance = Tensorstore(dtype="<f4", chunks=(10000,))
    "Scores associated with each detected site"

    strands: TensorstoreInstance = Tensorstore(dtype="<f4", chunks=(10000,))
    "Strand associated with each detected site"

    shape = Stored()

    motifs = StoredDataFrame()
    "Dataframe storing auxiliary information for each motif"

    @classmethod
    def from_pwms(
        cls,
        pwms: dict,
        regions: Regions,
        fasta_file: Union[str, pathlib.Path] = None,
        region_onehots: Dict[np.ndarray, torch.Tensor] = None,
        cutoffs: Union[int, float, dict, pd.Series] = None,
        cutoff_col: str = None,
        min_cutoff=3.0,
        motifs: pd.DataFrame = None,
        device: str = None,
        batch_size: int = 50000000,
        path: Union[str, pathlib.Path] = None,
        overwrite: bool = True,
        reuse: bool = False,
    ):
        """
        Create a motifscan object from a set of pwms and a set of regions

        Parameters:
            pwms:
                A dictionary of pwms, where the keys are the motif ids and the values are the pwms
            regions:
                A regions object
            fasta_file:
                The location of the fasta file containing the genome
            region_onehots:
                A dictionary containing the onehot encoding of each region. If not given, the onehot encoding will be extracted from the fasta file
            motifs:
                A dataframe containing auxiliary information for each motif. If not given, a dataframe indexed by the keys of `pwms` is created.
            cutoffs:
                The cutoffs for each motif: either a single number used for all motifs, or a dictionary/series with one cutoff per motif id.
            cutoff_col:
                The column in the motifs dataframe containing the cutoffs. Only used if `cutoffs` is not given.
            min_cutoff:
                Lower bound for the cutoffs: any motif cutoff below this value is raised to it.
            device:
                The device to use for the scanning
            batch_size:
                The batch size to use for scanning. Decrease this if the GPU runs out of memory
            path:
                The folder where the motifscan data will be stored.
            overwrite:
                Whether to overwrite existing motifscan data
            reuse:
                Whether to reuse existing motifscan data

        Returns:
            The motifscan object. If existing data is reused, it is returned without scanning.

        Raises:
            ValueError:
                If neither `cutoffs` nor `motifs` + `cutoff_col` are given, if `cutoffs` has an unsupported type, or if neither `fasta_file` nor `region_onehots` is given.
        """

        if device is None:
            device = get_default_device()

        self = cls(path)

        if ((reuse) or (not overwrite)) and self.o.coordinates.exists(self):
            if not reuse:
                import warnings

                warnings.warn("Motifscan already exists. Use overwrite=True to overwrite, reuse=True to ignore this warning.")
            return self

        if overwrite:
            self.reset()

        # check or create cutoffs
        if cutoffs is None:
            if cutoff_col is None:
                raise ValueError("Either motifs+cutoff_col or cutoffs need to be specified.")
            if motifs is None:
                raise ValueError("Either motifs+cutoff_col or cutoffs need to be specified. motifs is not given")

            cutoffs = motifs[cutoff_col].to_dict()
        else:
            if isinstance(cutoffs, (float, int, np.floating, np.integer)):
                cutoffs = {motif: cutoffs for motif in pwms.keys()}
            elif isinstance(cutoffs, pd.Series):
                cutoffs = cutoffs.to_dict()
            elif isinstance(cutoffs, dict):
                # a dictionary is documented (and announced in the error message below) as valid input
                cutoffs = dict(cutoffs)
            else:
                raise ValueError("cutoffs should be a float, int, dict or pd.Series")
            assert set(cutoffs.keys()) == set(pwms.keys())

        # check or create motifs
        if motifs is None:
            motifs = pd.DataFrame(
                {
                    "motif": list(pwms.keys()),
                }
            ).set_index("motif")

        # store the motifs only after the default was created, so that the stored table is never None
        self.motifs = motifs
        self.regions = regions

        # divide regions into batches according to batch size
        region_coordinates = regions.coordinates

        region_coordinates = divide_regions_in_batches(region_coordinates, batch_size=batch_size)

        # the sequences come either from precomputed onehots or from the fasta file
        if (fasta_file is None) and (region_onehots is None):
            raise ValueError("Either fasta_file or region_onehots need to be specified")

        self.indices.open_creator()
        self.scores.open_creator()
        self.strands.open_creator()
        self.coordinates.open_creator()
        self.region_indices.open_creator()

        # offset added to the within-region position of each site if the regions have a "tss" column
        # this does not depend on the batch or the motif, so it is computed only once
        tss_offsets = None
        if "tss" in regions.coordinates:
            tss_offsets = (self.regions.coordinates["start"] - self.regions.coordinates["tss"]).values

        # do the actual counting by looping over the batches, extract the sequences and scanning
        progress = tqdm.tqdm(region_coordinates.groupby("batch"))
        cur_region_index = 0
        for batch, region_coordinates_batch in progress:
            # extract onehot
            if region_onehots is None:
                # fasta_file is guaranteed to be given here, see the check above
                progress.set_description("Extracting sequences")
                region_onehots = create_region_onehots(regions, fasta_file)
            onehot = torch.stack([region_onehots[region] for region in region_coordinates_batch.index]).permute(0, 2, 1)
            onehot = onehot.to(device)

            progress.set_description(f"Scanning batch {batch} {region_coordinates_batch.index[0]}-{region_coordinates_batch.index[-1]}")

            assert onehot.shape[1] == 4
            assert onehot.shape[2] == region_coordinates_batch["len"].iloc[0], (
                onehot.shape[2],
                region_coordinates_batch["len"].iloc[0],
            )

            scores_raw = []
            indices_raw = []
            coordinates_raw = []
            strands_raw = []
            region_indices_raw = []
            for motif_ix, motif in tqdm.tqdm(enumerate(motifs.index)):
                cutoff = cutoffs[motif]

                if cutoff < min_cutoff:
                    cutoff = min_cutoff

                # get pwm
                pwm = pwms[motif]
                if not torch.is_tensor(pwm):
                    pwm = torch.from_numpy(pwm)
                pwm2 = pwm.to(dtype=torch.float32, device=onehot.device).transpose(1, 0)

                (
                    scores,
                    positions,
                    strands,
                ) = scan(onehot, pwm2, cutoff=cutoff)

                # positions are split into a position within the region and a region index,
                # using the length of the regions in this batch (the last dimension of onehot)
                coordinates = positions.astype(np.int32) % onehot.shape[-1]

                region_indices = positions // onehot.shape[-1] + cur_region_index

                if tss_offsets is not None:
                    coordinates = coordinates + tss_offsets[region_indices]

                coordinates_raw.append(coordinates)
                indices_raw.append(np.full_like(coordinates, motif_ix, dtype=np.int32))
                strands_raw.append(strands)
                scores_raw.append(scores)
                region_indices_raw.append(region_indices)

            # concatenate raw values (sorted by motif)
            coordinates = np.concatenate(coordinates_raw)
            indices = np.concatenate(indices_raw)
            strands = np.concatenate(strands_raw)
            scores = np.concatenate(scores_raw)
            region_indices = np.concatenate(region_indices_raw)

            # sort according to position: first by region index, then by coordinate
            sorted_idx = np.lexsort([coordinates, region_indices])
            indices = indices[sorted_idx]
            scores = scores[sorted_idx]
            strands = strands[sorted_idx]
            coordinates = coordinates[sorted_idx]
            region_indices = region_indices[sorted_idx]

            # store batch
            self.indices.extend(indices)
            self.scores.extend(scores)
            self.strands.extend(strands)
            self.coordinates.extend(coordinates)
            self.region_indices.extend(region_indices)

            # update current region index
            cur_region_index += len(region_coordinates_batch)

        return self

    def create_region_indptr(self, overwrite=False):
        """
        Populate the region_indptr

        Parameters:
            overwrite:
                Whether to recreate the region_indptr if it already exists
        """

        if self.o.region_indptr.exists(self) and not overwrite:
            return

        region_indices_reader = self.region_indices.open_reader()
        self.region_indptr = indices_to_indptr_chunked(region_indices_reader, self.regions.n_regions, dtype=np.int64)

    def create_indptr(self, overwrite=False):
        """
        Populate the indptr

        Reads `region_indptr`, so `create_region_indptr` has to be run first.

        Parameters:
            overwrite:
                Whether to recreate the indptr if it already exists
        """

        if self.o.indptr.exists(self) and not overwrite:
            return

        # value of the final pointer if there are no regions to loop over
        region_end = 0

        if self.regions.width is not None:
            # fixed width regions: each region takes up exactly `width` entries of the indptr
            indptr = self.indptr.open_creator(shape=((self.regions.n_regions * self.regions.width) + 1,), dtype=np.int64)
            region_width = self.regions.width
            for region_ix, (region_start, region_end) in tqdm.tqdm(enumerate(zip(self.region_indptr[:-1], self.region_indptr[1:]))):
                # the pointers of a region are relative to its first site, region_start makes them global
                indptr[region_ix * region_width : (region_ix + 1) * region_width] = indices_to_indptr(self.coordinates[region_start:region_end], region_width)[:-1] + region_start
            indptr[-1] = region_end
        else:
            # variable width regions: the entries of a region are located using the cumulative region lengths
            indptr = self.indptr.open_creator(shape=(self.regions.cumulative_region_lengths[-1] + 1,), dtype=np.int64)
            for region_ix, (region_start, region_end) in tqdm.tqdm(enumerate(zip(self.region_indptr[:-1], self.region_indptr[1:]))):
                region_start_position = self.regions.cumulative_region_lengths[region_ix]
                region_end_position = self.regions.cumulative_region_lengths[region_ix + 1]
                indptr[region_start_position:region_end_position] = (
                    indices_to_indptr_chunked(
                        self.coordinates[region_start:region_end],
                        region_end_position - region_start_position,
                    )[:-1]
                    + region_start
                )
            indptr[-1] = region_end

    @classmethod
    def from_positions(cls, positions, indices, scores, strands, regions, motifs, path=None):
        """
        Create a motifscan object from positions, indices, scores, strands, regions and motifs

        The positions, indices, scores and strands are sorted according to the positions before they are stored.
        """
        # NOTE(review): `positions` is no longer declared as a store on this class (see the commented-out
        # attribute above), so it is set here as a regular attribute - confirm whether this constructor is still in use
        self = cls(path=path)

        # sort the positions
        sorted_idx = np.argsort(positions)

        self.positions = positions[sorted_idx]
        self.indices = indices[sorted_idx]
        self.scores = scores[sorted_idx]
        self.strands = strands[sorted_idx]
        self.regions = regions
        self.motifs = motifs

        return self

    def filter(self, motif_ids, path=None):
        """
        Select a subset of motifs

        Parameters:
            motif_ids:
                The ids of the motifs to keep, as present in the index of the motifs dataframe
            path:
                The folder where the filtered motifscan will be stored

        Returns:
            A new motifscan object containing only the sites of the selected motifs
        """
        # NOTE(review): this reads `self.positions`, which is not declared as a store on this class,
        # and the motif indices of the selected sites are passed on without renumbering - confirm before use

        self.motifs["ix"] = np.arange(len(self.motifs))
        motif_ixs = self.motifs.loc[motif_ids, "ix"]

        selected_sites = np.isin(self.indices, motif_ixs)

        new = self.__class__(path=path).create(
            regions=self.regions,
            positions=self.positions[selected_sites],
            indices=self.indices[selected_sites],
            scores=self.scores[selected_sites],
            strands=self.strands[selected_sites],
            motifs=self.motifs.loc[motif_ids],
        )

        new.create_indptr()
        return new

    @property
    def n_motifs(self):
        """Number of motifs, i.e. the number of rows of the motifs dataframe"""
        return len(self.motifs)

    @property
    def scanned(self):
        """Whether the motif indices of the sites were already stored"""
        return self.o.indices.exists(self)

    def get_slice(
        self,
        region_ix=None,
        region_id=None,
        start=None,
        end=None,
        return_indptr=False,
        return_scores=True,
        return_strands=True,
        motif_ixs=None,
    ):
        """
        Get a slice of the motifscan

        Parameters:
            region_ix:
                Positional index of the region. Either region_ix or region_id has to be given.
            region_id:
                Region id
            start:
                Start of the slice, in region coordinates. Defaults to the start of the window for fixed width regions, required for variable width regions.
            end:
                End of the slice, in region coordinates. Defaults to the end of the window for fixed width regions, required for variable width regions.
            return_indptr:
                Whether to also return the index pointers of the slice. This requires that the indptr was created.
            return_scores:
                Whether to return the scores of the sites
            return_strands:
                Whether to return the strands of the sites
            motif_ixs:
                If given, only sites of these motif indices are returned

        Returns:
            A list with the coordinates and motif indices of the sites in the slice, followed by the scores, strands and indptr if requested

        Raises:
            ValueError:
                If neither region_id nor region_ix is given, or if return_indptr is requested while no indptr was created
        """
        if (region_id is None) and (region_ix is None):
            raise ValueError("Either region_id or region_ix should be provided")

        # the looked up row itself is not used, but the lookup fails loudly for an unknown region,
        # whereas get_indexer below would silently return -1
        if region_id is not None:
            self.regions.coordinates.loc[region_id]
        else:
            self.regions.coordinates.iloc[region_ix]
        if region_ix is None:
            region_ix = self.regions.coordinates.index.get_indexer([region_id])[0]

        width = self.regions.width
        if width is not None:
            # fixed width regions: default to the full window
            if start is None:
                start = self.regions.window[0]
            if end is None:
                end = self.regions.window[1]
        else:
            # variable width regions: there is no window to fall back on
            assert start is not None
            assert end is not None

        if self.o.indptr.exists(self):
            # offset of the first position of this region within the indptr
            if width is not None:
                offset = region_ix * width
            else:
                offset = self.regions.cumulative_region_lengths[region_ix]
            indptr = self.indptr[offset + start : offset + end + 1]
            indptr_start, indptr_end = indptr[0], indptr[-1]
            indptr = indptr - indptr[0]
        else:
            if return_indptr:
                raise ValueError("return_indptr=True requires an indptr, run create_indptr() first")
            region_start = self.region_indptr[region_ix]
            region_end = self.region_indptr[region_ix + 1]
            coordinates = self.coordinates[region_start:region_end]
            # NOTE(review): with the default side="left", searchsorted(start - 1) also selects sites
            # located at coordinate start - 1 - confirm that this matches the indptr-based selection
            indptr_start = coordinates.searchsorted(start - 1) + region_start
            indptr_end = coordinates.searchsorted(end) + region_start

        coordinates = self.coordinates[indptr_start:indptr_end]
        indices = self.indices[indptr_start:indptr_end]

        out = [coordinates, indices]
        if return_scores:
            out.append(self.scores[indptr_start:indptr_end])
        if return_strands:
            out.append(self.strands[indptr_start:indptr_end])

        if motif_ixs is not None:
            selection = np.isin(indices, motif_ixs)
            out = [x[selection] for x in out]

        if return_indptr:
            # the indptr is appended as is, it is not adapted to a selection on motif_ixs
            out.append(indptr)

        return out

    def count_slices(self, slices: pd.DataFrame) -> pd.DataFrame:
        """
        Get multiple slices of the motifscan

        Parameters:
            slices:
                DataFrame containing the slices to get. Each row should contain a region_ix, start and end column. The region_ix should refer to the index of the regions object. The start and end columns should contain the start and end of the slice, in region coordinates. If there is no region_ix column, it is determined from the region column, which should contain region ids. The dataframe itself is not modified.

        Returns:
            DataFrame containing the counts of each motif (columns) in each slice (rows)

        Raises:
            ValueError:
                If the region_ix is determined from the region column and some regions are not known
        """

        # if self.regions.window is None:
        #     raise NotImplementedError("count_slices is only implemented for regions with a window")

        if "region_ix" in slices:
            region_ixs = slices["region_ix"]
        else:
            # determine the region indices without adding a column to the dataframe of the caller
            region_ixs = self.regions.coordinates.index.get_indexer(slices["region"])
            if (region_ixs == -1).any():
                # -1 marks an unknown region, using it as an index would silently count in the last region
                raise ValueError("Some regions in slices are not present in the regions of the motifscan")

        progress = enumerate(zip(slices["start"], slices["end"], region_ixs))
        progress = tqdm.tqdm(
            progress,
            total=len(slices),
            leave=False,
            desc="Counting slices",
        )

        # determined once, as the number of motifs does not depend on the slice
        n_motifs = self.n_motifs

        motif_counts = np.zeros((len(slices), n_motifs), dtype=int)
        for i, (start, end, region_ix) in progress:
            _, indices = self.get_slice(
                region_ix=region_ix,
                start=start,
                end=end,
                return_scores=False,
                return_strands=False,
                return_indptr=False,
            )
            motif_counts[i] = np.bincount(indices, minlength=n_motifs)
        motif_counts = pd.DataFrame(motif_counts, index=slices.index, columns=self.motifs.index)
        return motif_counts

    def select_motif(self, x=None, symbol=None):
        """
        Get the id of the first motif matching a pattern or a symbol

        Parameters:
            x:
                Pattern that the motif id should contain. Only used if symbol is not given.
            symbol:
                Value that the symbol column of the motifs dataframe should be equal to

        Returns:
            The id of the first matching motif
        """
        if symbol is not None:
            return self.motifs.loc[self.motifs["symbol"] == symbol].index[0]
        return self.motifs.loc[self.motifs.index.str.contains(x)].index[0]

    def select_motifs(self, x=None, symbol=None):
        """
        Get the ids of all motifs matching a pattern or a symbol

        Parameters:
            x:
                Pattern that the motif ids should contain. Only used if symbol is not given.
            symbol:
                Value that the symbol column of the motifs dataframe should be equal to

        Returns:
            A list with the ids of the matching motifs
        """
        if symbol is not None:
            return self.motifs.loc[self.motifs["symbol"] == symbol].index.tolist()
        return self.motifs.loc[self.motifs.index.str.contains(x)].index.tolist()

coordinates: TensorstoreInstance = Tensorstore(dtype='<i4', chunks=(10000)) class-attribute instance-attribute

Coordinate associated to each site

indices: TensorstoreInstance = Tensorstore(dtype='<i4', chunks=(10000)) class-attribute instance-attribute

Motif index associated to each site

indptr: TensorstoreInstance = Tensorstore(dtype='<i8', chunks=(10000)) class-attribute instance-attribute

The index pointers for each position in the regions

motifs = StoredDataFrame() class-attribute instance-attribute

Dataframe storing auxiliary information for each motif

region_indices: TensorstoreInstance = Tensorstore(dtype='<i4', chunks=(10000)) class-attribute instance-attribute

Region index associated to each site

region_indptr: TensorstoreInstance = Tensorstore(dtype='<i8', chunks=(1000), compression=None) class-attribute instance-attribute

The index pointers for region

regions = Linked() class-attribute instance-attribute

The regions

scores: TensorstoreInstance = Tensorstore(dtype='<f4', chunks=(10000)) class-attribute instance-attribute

Scores associated with each detected site

strands: TensorstoreInstance = Tensorstore(dtype='<f4', chunks=(10000)) class-attribute instance-attribute

Strand associated with each detected site

count_slices(slices)

Get multiple slices of the motifscan

Parameters:

Name Type Description Default
slices DataFrame

DataFrame containing the slices to get. Each row should contain a region_ix, start and end column. The region_ix should refer to the index of the regions object. The start and end columns should contain the start and end of the slice, in region coordinates.

required

Returns:

Type Description
DataFrame

DataFrame containing the counts of each motif (columns) in each slice (rows)

Source code in src/chromatinhd/data/motifscan/motifscan.py
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
def count_slices(self, slices: pd.DataFrame) -> pd.DataFrame:
    """
    Get multiple slices of the motifscan

    Parameters:
        slices:
            DataFrame containing the slices to get. Each row should contain a region_ix, start and end column. The region_ix should refer to the index of the regions object. The start and end columns should contain the start and end of the slice, in region coordinates. If there is no region_ix column, it is determined from the region column, which should contain region ids. The dataframe itself is not modified.

    Returns:
        DataFrame containing the counts of each motif (columns) in each slice (rows)

    Raises:
        ValueError:
            If the region_ix is determined from the region column and some regions are not known
    """

    # if self.regions.window is None:
    #     raise NotImplementedError("count_slices is only implemented for regions with a window")

    if "region_ix" in slices:
        region_ixs = slices["region_ix"]
    else:
        # determine the region indices without adding a column to the dataframe of the caller
        region_ixs = self.regions.coordinates.index.get_indexer(slices["region"])
        if (region_ixs == -1).any():
            # -1 marks an unknown region, using it as an index would silently count in the last region
            raise ValueError("Some regions in slices are not present in the regions of the motifscan")

    progress = enumerate(zip(slices["start"], slices["end"], region_ixs))
    progress = tqdm.tqdm(
        progress,
        total=len(slices),
        leave=False,
        desc="Counting slices",
    )

    # determined once, as the number of motifs does not depend on the slice
    n_motifs = self.n_motifs

    motif_counts = np.zeros((len(slices), n_motifs), dtype=int)
    for i, (start, end, region_ix) in progress:
        _, indices = self.get_slice(
            region_ix=region_ix,
            start=start,
            end=end,
            return_scores=False,
            return_strands=False,
            return_indptr=False,
        )
        motif_counts[i] = np.bincount(indices, minlength=n_motifs)
    motif_counts = pd.DataFrame(motif_counts, index=slices.index, columns=self.motifs.index)
    return motif_counts

create_indptr(overwrite=False)

Populate the indptr

Source code in src/chromatinhd/data/motifscan/motifscan.py
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
def create_indptr(self, overwrite=False):
    """
    Populate the indptr

    Reads `region_indptr`, so `create_region_indptr` has to be run first.

    Parameters:
        overwrite:
            Whether to recreate the indptr if it already exists
    """

    if self.o.indptr.exists(self) and not overwrite:
        return

    # value of the final pointer if there are no regions to loop over
    region_end = 0

    if self.regions.width is not None:
        # fixed width regions: each region takes up exactly `width` entries of the indptr
        indptr = self.indptr.open_creator(shape=((self.regions.n_regions * self.regions.width) + 1,), dtype=np.int64)
        region_width = self.regions.width
        for region_ix, (region_start, region_end) in tqdm.tqdm(enumerate(zip(self.region_indptr[:-1], self.region_indptr[1:]))):
            # the pointers of a region are relative to its first site, region_start makes them global
            indptr[region_ix * region_width : (region_ix + 1) * region_width] = indices_to_indptr(self.coordinates[region_start:region_end], region_width)[:-1] + region_start
        indptr[-1] = region_end
    else:
        # variable width regions: the entries of a region are located using the cumulative region lengths
        indptr = self.indptr.open_creator(shape=(self.regions.cumulative_region_lengths[-1] + 1,), dtype=np.int64)
        for region_ix, (region_start, region_end) in tqdm.tqdm(enumerate(zip(self.region_indptr[:-1], self.region_indptr[1:]))):
            region_start_position = self.regions.cumulative_region_lengths[region_ix]
            region_end_position = self.regions.cumulative_region_lengths[region_ix + 1]
            indptr[region_start_position:region_end_position] = (
                indices_to_indptr_chunked(
                    self.coordinates[region_start:region_end],
                    region_end_position - region_start_position,
                )[:-1]
                + region_start
            )
        indptr[-1] = region_end

create_region_indptr(overwrite=False)

Populate the region_indptr

Source code in src/chromatinhd/data/motifscan/motifscan.py
281
282
283
284
285
286
287
288
289
290
def create_region_indptr(self, overwrite=False):
    """
    Build and store the region_indptr from the region index of each site

    Nothing is done if a region_indptr already exists, unless `overwrite` is set.
    """
    already_present = self.o.region_indptr.exists(self)
    if already_present and not overwrite:
        return None

    reader = self.region_indices.open_reader()
    n_regions = self.regions.n_regions
    self.region_indptr = indices_to_indptr_chunked(reader, n_regions, dtype=np.int64)

filter(motif_ids, path=None)

Select a subset of motifs

Source code in src/chromatinhd/data/motifscan/motifscan.py
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
def filter(self, motif_ids, path=None):
    """
    Create a new motifscan that only contains the sites of the given motifs

    An "ix" column with the positional index of each motif is added to the motifs
    dataframe of this object, and is used to find the sites that should be kept.
    The indptr of the new object is created before it is returned.
    """
    # NOTE(review): the motif indices of the kept sites are passed on as is,
    # while the new motifs table only holds the selected rows - confirm this is intended

    # positional index of each motif, which is what the site indices are matched against
    self.motifs["ix"] = np.arange(len(self.motifs))
    wanted_ixs = self.motifs.loc[motif_ids, "ix"]

    # boolean mask over all sites
    keep = np.isin(self.indices, wanted_ixs)

    target = self.__class__(path=path)
    filtered = target.create(
        regions=self.regions,
        positions=self.positions[keep],
        indices=self.indices[keep],
        scores=self.scores[keep],
        strands=self.strands[keep],
        motifs=self.motifs.loc[motif_ids],
    )
    filtered.create_indptr()

    return filtered

from_positions(positions, indices, scores, strands, regions, motifs, path=None) classmethod

Create a motifscan object from positions, indices, scores, strands, regions and motifs

Source code in src/chromatinhd/data/motifscan/motifscan.py
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
@classmethod
def from_positions(cls, positions, indices, scores, strands, regions, motifs, path=None):
    """
    Build a motifscan object from sites that were already detected

    The positions, indices, scores and strands are reordered by ascending position
    before they are assigned, the regions and motifs are assigned as given.
    """
    motifscan = cls(path=path)

    # one shared ordering keeps the per-site values aligned with each other
    order = np.argsort(positions)
    per_site = (
        ("positions", positions),
        ("indices", indices),
        ("scores", scores),
        ("strands", strands),
    )
    for attribute, values in per_site:
        setattr(motifscan, attribute, values[order])

    motifscan.regions = regions
    motifscan.motifs = motifs

    return motifscan

from_pwms(pwms, regions, fasta_file=None, region_onehots=None, cutoffs=None, cutoff_col=None, min_cutoff=3.0, motifs=None, device=None, batch_size=50000000, path=None, overwrite=True, reuse=False) classmethod

Create a motifscan object from a set of pwms and a set of regions

Parameters:

Name Type Description Default
pwms dict

A dictionary of pwms, where the keys are the motif ids and the values are the pwms

required
regions Regions

A regions object

required
fasta_file Union[str, Path]

The location of the fasta file containing the genome

None
region_onehots Dict[ndarray, Tensor]

A dictionary containing the onehot encoding of each region. If not given, the onehot encoding will be extracted from the fasta file

None
motifs DataFrame

A dataframe containing auxiliary information for each motif

None
cutoffs Union[int, float, Series]

A dictionary containing the cutoffs for each motif.

None
cutoff_col str

The column in the motifs dataframe containing the cutoffs

None
device str

The device to use for the scanning

None
batch_size int

The batch size to use for scanning. Decrease this if the GPU runs out of memory

50000000
path Union[str, Path]

The folder where the motifscan data will be stored.

None
overwrite bool

Whether to overwrite existing motifscan data

True
reuse bool

Whether to reuse existing motifscan data

False
Source code in src/chromatinhd/data/motifscan/motifscan.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
@classmethod
def from_pwms(
    cls,
    pwms: dict,
    regions: Regions,
    fasta_file: Union[str, pathlib.Path] = None,
    region_onehots: Dict[np.ndarray, torch.Tensor] = None,
    cutoffs: Union[int, float, dict, pd.Series] = None,
    cutoff_col: str = None,
    min_cutoff=3.0,
    motifs: pd.DataFrame = None,
    device: str = None,
    batch_size: int = 50000000,
    path: Union[str, pathlib.Path] = None,
    overwrite: bool = True,
    reuse: bool = False,
):
    """
    Create a motifscan object from a set of pwms and a set of regions

    Parameters:
        pwms:
            A dictionary of pwms, where the keys are the motif ids and the values are the pwms
        regions:
            A regions object
        fasta_file:
            The location of the fasta file containing the genome. Only used when region_onehots is not given
        region_onehots:
            A dictionary containing the onehot encoding of each region. If not given, the onehot encoding will be extracted from the fasta file
        motifs:
            A dataframe containing auxiliary information for each motif, indexed by motif id. If not given, a dataframe with only the motif ids of pwms as index is created
        cutoffs:
            The cutoffs for each motif: a single number applied to every motif, or a dict/Series keyed by motif id. If not given, the cutoffs are read from motifs[cutoff_col]
        cutoff_col:
            The column in the motifs dataframe containing the cutoffs
        min_cutoff:
            Lower bound on the cutoff; a motif cutoff below this value is replaced by it
        device:
            The device to use for the scanning
        batch_size:
            The batch size to use for scanning. Decrease this if the GPU runs out of memory
        path:
            The folder where the motifscan data will be stored.
        overwrite:
            Whether to overwrite existing motifscan data
        reuse:
            Whether to reuse existing motifscan data

    Returns:
        The motifscan object. If existing data is kept (reuse=True or overwrite=False), it is returned without rescanning

    Raises:
        ValueError: if neither cutoffs nor motifs+cutoff_col are given, if cutoffs has an unsupported type, or if neither fasta_file nor region_onehots is given
    """

    if device is None:
        device = get_default_device()

    self = cls(path)

    if ((reuse) or (not overwrite)) and self.o.coordinates.exists(self):
        if not reuse:
            import warnings

            warnings.warn("Motifscan already exists. Use overwrite=True to overwrite, reuse=True to ignore this warning.")
        return self

    if overwrite:
        self.reset()

    self.regions = regions

    # check or create cutoffs
    if cutoffs is None:
        if cutoff_col is None:
            raise ValueError("Either motifs+cutoff_col or cutoffs need to be specified.")
        if motifs is None:
            raise ValueError("Either motifs+cutoff_col or cutoffs need to be specified. motifs is not given")

        cutoffs = motifs[cutoff_col].to_dict()
    else:
        if isinstance(cutoffs, (float, int)):
            cutoffs = {motif: cutoffs for motif in pwms.keys()}
        elif isinstance(cutoffs, pd.Series):
            cutoffs = cutoffs.to_dict()
        elif not isinstance(cutoffs, dict):
            # a plain dict is used as is; anything else is unsupported
            raise ValueError("cutoffs should be a float, int, dict or pd.Series")
        assert set(cutoffs.keys()) == set(pwms.keys())

    # check or create motifs
    if motifs is None:
        motifs = pd.DataFrame(
            {
                "motif": list(pwms.keys()),
            }
        ).set_index("motif")
    # stored only now so that the default frame created above is kept as well;
    # the motif indices written below are positions in this frame's index
    self.motifs = motifs

    # divide regions into batches according to batch size
    region_coordinates = regions.coordinates

    region_coordinates = divide_regions_in_batches(region_coordinates, batch_size=batch_size)

    # the sequences come either from precomputed onehots or from the fasta file
    if (fasta_file is None) and (region_onehots is None):
        raise ValueError("Either fasta_file or region_onehots need to be specified")

    self.indices.open_creator()
    self.scores.open_creator()
    self.strands.open_creator()
    self.coordinates.open_creator()
    self.region_indices.open_creator()

    # do the actual counting by looping over the batches, extract the sequences and scanning
    progress = tqdm.tqdm(region_coordinates.groupby("batch"))
    cur_region_index = 0
    for batch, region_coordinates_batch in progress:
        # extract onehot; when no onehots were supplied they are built once
        # (on the first batch) from the fasta file and reused afterwards
        if region_onehots is None:
            progress.set_description("Extracting sequences")
            region_onehots = create_region_onehots(regions, fasta_file)
        onehot = torch.stack([region_onehots[region] for region in region_coordinates_batch.index]).permute(0, 2, 1)
        onehot = onehot.to(device)

        progress.set_description(f"Scanning batch {batch} {region_coordinates_batch.index[0]}-{region_coordinates_batch.index[-1]}")

        # after the permute the layout must be (region, nucleotide, position)
        assert onehot.shape[1] == 4
        assert onehot.shape[2] == region_coordinates_batch["len"].iloc[0], (
            onehot.shape[2],
            region_coordinates_batch["len"].iloc[0],
        )

        scores_raw = []
        indices_raw = []
        coordinates_raw = []
        strands_raw = []
        region_indices_raw = []
        for motif_ix, motif in tqdm.tqdm(enumerate(motifs.index)):
            cutoff = cutoffs[motif]

            if cutoff < min_cutoff:
                cutoff = min_cutoff

            # get pwm
            pwm = pwms[motif]
            if not torch.is_tensor(pwm):
                pwm = torch.from_numpy(pwm)
            pwm2 = pwm.to(dtype=torch.float32, device=onehot.device).transpose(1, 0)

            (
                scores,
                positions,
                strands,
            ) = scan(onehot, pwm2, cutoff=cutoff)

            # positions are decoded as flat offsets over the batch: the remainder
            # by the region length is the position within the region, the
            # quotient is the region's index within the batch
            coordinates = positions.astype(np.int32) % onehot.shape[-1]

            region_indices = positions // onehot.shape[-1] + cur_region_index

            if "tss" in regions.coordinates:
                # shift by (start - tss) of the region each hit belongs to
                coordinates = coordinates + (self.regions.coordinates["start"] - self.regions.coordinates["tss"]).values[region_indices]

            coordinates_raw.append(coordinates)
            indices_raw.append(np.full_like(coordinates, motif_ix, dtype=np.int32))
            strands_raw.append(strands)
            scores_raw.append(scores)
            region_indices_raw.append(region_indices)

        # concatenate raw values (sorted by motif)
        coordinates = np.concatenate(coordinates_raw)
        indices = np.concatenate(indices_raw)
        strands = np.concatenate(strands_raw)
        scores = np.concatenate(scores_raw)
        region_indices = np.concatenate(region_indices_raw)

        # sort according to position: by region first, then by coordinate
        # (np.lexsort uses the last key as the primary key)
        sorted_idx = np.lexsort([coordinates, region_indices])
        indices = indices[sorted_idx]
        scores = scores[sorted_idx]
        strands = strands[sorted_idx]
        coordinates = coordinates[sorted_idx]
        region_indices = region_indices[sorted_idx]

        # store batch
        self.indices.extend(indices)
        self.scores.extend(scores)
        self.strands.extend(strands)
        self.coordinates.extend(coordinates)
        self.region_indices.extend(region_indices)

        # update current region index
        cur_region_index += len(region_coordinates_batch)

    return self

get_slice(region_ix=None, region_id=None, start=None, end=None, return_indptr=False, return_scores=True, return_strands=True, motif_ixs=None)

Get a slice of the motifscan

Parameters:

Name Type Description Default
region_id

Region id. Either region_id or region_ix must be provided

None
start

Start of the slice, in region coordinates

None
end

End of the slice, in region coordinates

None

Returns:

Type Description

Motifs positions, indices, scores and strands of the slice

Source code in src/chromatinhd/data/motifscan/motifscan.py
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
def get_slice(
    self,
    region_ix=None,
    region_id=None,
    start=None,
    end=None,
    return_indptr=False,
    return_scores=True,
    return_strands=True,
    motif_ixs=None,
):
    """
    Get a slice of the motifscan

    Parameters:
        region_ix:
            Positional index of the region. Either region_ix or region_id should be provided
        region_id:
            Region id. If region_ix is not given, it is looked up from this id
        start:
            Start of the slice, in region coordinates. Defaults to the start of the window for fixed width regions, required otherwise
        end:
            End of the slice, in region coordinates. Defaults to the end of the window for fixed width regions, required otherwise
        return_indptr:
            Whether to append the index pointers of the slice, shifted to start at 0. Requires a stored indptr
        return_scores:
            Whether to include the scores of the slice
        return_strands:
            Whether to include the strands of the slice
        motif_ixs:
            If given, only sites whose motif index is in motif_ixs are returned

    Returns:
        A list with the motif positions and indices of the slice, followed by the scores, strands and indptr if requested

    Raises:
        ValueError: if neither region_id nor region_ix is provided, or if return_indptr is requested while no indptr is stored
    """
    regions = self.regions

    # The looked-up row is not used further; the lookup rejects an unknown
    # id (KeyError) or an out-of-range index (IndexError) early.
    if region_id is not None:
        _ = regions.coordinates.loc[region_id]
    elif region_ix is not None:
        _ = regions.coordinates.iloc[region_ix]
    else:
        raise ValueError("Either region_id or region_ix should be provided")
    if region_ix is None:
        region_ix = regions.coordinates.index.get_indexer([region_id])[0]

    has_indptr = self.o.indptr.exists(self)
    if return_indptr and not has_indptr:
        # without a stored indptr there is nothing to return; fail with a
        # clear message instead of an unbound local further down
        raise ValueError("return_indptr=True requires a stored indptr, which this motifscan does not have")

    width = regions.width
    if width is not None:
        # fixed width regions: default to the full window
        if start is None:
            start = regions.window[0]
        if end is None:
            end = regions.window[1]
    else:
        # variable width regions: there is no window to fall back on
        assert start is not None
        assert end is not None

    indptr = None
    if has_indptr:
        # offset of this region in the concatenated position axis
        if width is not None:
            offset = region_ix * width
        else:
            offset = regions.cumulative_region_lengths[region_ix]
        indptr = self.indptr[offset + start : offset + end + 1]
        indptr_start, indptr_end = indptr[0], indptr[-1]
        indptr = indptr - indptr[0]
    else:
        # search within the sites of this region only
        region_start = self.region_indptr[region_ix]
        region_end = self.region_indptr[region_ix + 1]
        region_coordinates = self.coordinates[region_start:region_end]
        # NOTE(review): the lower bound is start - 1, so a site located at
        # start - 1 is included in the slice; confirm this is intended
        indptr_start = region_coordinates.searchsorted(start - 1) + region_start
        indptr_end = region_coordinates.searchsorted(end) + region_start

    coordinates = self.coordinates[indptr_start:indptr_end]
    indices = self.indices[indptr_start:indptr_end]

    out = [coordinates, indices]
    if return_scores:
        out.append(self.scores[indptr_start:indptr_end])
    if return_strands:
        out.append(self.strands[indptr_start:indptr_end])

    if motif_ixs is not None:
        selection = np.isin(indices, motif_ixs)
        out = [x[selection] for x in out]

    # the indptr is appended after the motif filter and is not filtered itself
    if return_indptr:
        out.append(indptr)

    return out