# io.py
# Author: Jacob Schreiber <jmschreiber91@gmail.com>
# Code adapted from Alex Tseng, Avanti Shrikumar, and Ziga Avsec
import numpy
import torch
from tangermeme.io import extract_loci
[docs]
class PeakNegativeSampler(torch.utils.data.Dataset):
    """A data generator mimicking the BPNet data loading procedure.

    Here, a set of peaks and negatives are separately loaded. These sets can
    be any size. From these sets, examples are drawn such that each epoch is
    a mixture of peaks and negatives.

    Sampling is fully deterministic given ``random_state`` and the epoch
    number. ``__getitem__(idx)`` only reads arrays that were precomputed for
    the current epoch from ``(random_state, epoch)``, so every copy of this
    dataset that is at the same epoch returns the same data for the same
    index.

    Each peak is drawn exactly once per epoch; negatives are drawn with
    replacement. The peak/negative interleaving and all augmentations are
    reproducible from ``(random_state, epoch)``.

    The epoch counter is advanced implicitly: whenever ``__getitem__``
    receives an index that is strictly smaller than the previously requested
    one, this is treated as a wrap-around and the next epoch is prepared.
    Each copy of the dataset (e.g. one per DataLoader worker) tracks this
    independently. A repeated index is *not* treated as a wrap-around.

    In the documentation below, ``mj`` = max_jitter.


    Parameters
    ----------
    peak_sequences: torch.tensor, shape=(n_peaks, 4, in_window+2*mj)
        A tensor of peak sequences that are one-hot encoded.

    peak_signals: torch.tensor, shape=(n_peaks, t, out_window+2*mj)
        A tensor of signals to predict, usually base-pair resolution
        integer counts.

    negative_sequences: torch.tensor, shape=(n_negatives, 4, in_window)
        One-hot encoded negative sequences. Negatives are never jittered:
        the window is always read starting at offset 0, so any positions
        beyond ``in_window`` are ignored.

    negative_signals: torch.tensor, shape=(n_negatives, t, out_window)
        Negative sequence signals, read starting at offset 0.

    peak_controls: torch.tensor or None, optional
        Optional control input track for peak examples, indexed like
        ``peak_sequences``. Note that controls are sliced with ``in_window``
        (the same window as the sequence), not ``out_window``. If None, no
        controls are returned for any example, even if
        ``negative_controls`` is given. Default None.

    negative_controls: torch.tensor or None, optional
        Optional control input track for negative examples. Required when
        ``peak_controls`` is given and at least one negative is drawn per
        epoch. Default None.

    negative_ratio: float, optional
        Ratio of negatives to peaks per epoch. ``0`` means no negative
        draws. Default 0.1.

    in_window: int, optional
        The input window size. Default 2114.

    out_window: int, optional
        The output window size. Default 1000.

    max_jitter: int, optional
        Maximum jitter applied to peaks. The window offset is drawn
        uniformly from ``[0, 2*mj)``. Default 0.

    reverse_complement: bool, optional
        Whether to reverse complement-augment half of the data. Default
        False.

    shuffle: bool, optional
        Whether to shuffle the peak ordering each epoch. Default True.

    random_state: int or None, optional
        Base seed for the deterministic per-epoch RNG. If None, a random
        seed is captured once at construction time so that every copy made
        of this object afterwards shares it.


    Raises
    ------
    ValueError
        If ``max_jitter`` or ``negative_ratio`` is negative, or if
        ``peak_controls`` is given without ``negative_controls`` while
        negatives would be drawn.
    """

    def __init__(self, peak_sequences, peak_signals, negative_sequences,
                 negative_signals, peak_controls=None, negative_controls=None,
                 negative_ratio=0.1, in_window=2114, out_window=1000,
                 max_jitter=0, reverse_complement=False, shuffle=True,
                 random_state=None):
        if max_jitter < 0:
            raise ValueError("max_jitter must be non-negative, got {}"
                             .format(max_jitter))
        if negative_ratio < 0:
            raise ValueError("negative_ratio must be non-negative, got {}"
                             .format(negative_ratio))

        # Everything is held as numpy arrays; __getitem__ converts the
        # selected slices back to tensors.
        self.peak_sequences = peak_sequences.numpy(force=True)
        self.peak_signals = peak_signals.numpy(force=True)
        self.n_peaks = len(self.peak_sequences)

        self.negative_sequences = negative_sequences.numpy(force=True)
        self.negative_signals = negative_signals.numpy(force=True)
        self.n_negatives = len(self.negative_sequences)

        self.negative_ratio = negative_ratio
        self.in_window = in_window
        self.out_window = out_window
        self.max_jitter = max_jitter
        self.reverse_complement = reverse_complement
        self.shuffle = shuffle

        # Controls are only used when peak controls are supplied.
        if peak_controls is None:
            self.peak_controls = None
            self.negative_controls = None
        else:
            self.peak_controls = peak_controls.numpy(force=True)
            if negative_controls is not None:
                self.negative_controls = negative_controls.numpy(force=True)
            elif self._n_negative_draws() > 0:
                # Fail here with a clear message instead of an opaque
                # AttributeError on None.
                raise ValueError("negative_controls must be provided when "
                                 "peak_controls is provided and negatives "
                                 "are sampled")
            else:
                self.negative_controls = None

        # Capture one base seed at construction so every later copy of
        # this object inherits the same value.
        if random_state is None:
            random_state = int(numpy.random.randint(0, 2**31 - 1))
        self._base_seed = int(random_state) % (2**31 - 1)

        # _last_idx detects epoch boundaries by wrap-around (idx jumping
        # backward). Each copy of the dataset maintains its own value.
        self._last_idx = -1
        self._epoch = -1
        self._prepare_epoch(0)

    def _n_negative_draws(self):
        """Return how many negative examples are scheduled per epoch."""
        # With no negatives available nothing can be drawn, regardless of
        # the requested ratio; scheduling slots anyway would make
        # __getitem__ index into an empty array.
        if self.n_negatives == 0:
            return 0
        return int(self.n_peaks * self.negative_ratio)

    def __len__(self):
        """Return the number of examples (peaks + negatives) per epoch."""
        return self.n_peaks + self._n_negative_draws()

    def _prepare_epoch(self, epoch):
        """Recompute per-epoch arrays from the (base_seed, epoch) RNG."""
        self._epoch = epoch

        # 2654435761 is Knuth's multiplicative hash constant; it spreads
        # consecutive small epoch values across the seed space.
        seed = (self._base_seed + epoch * 2654435761) % (2**31 - 1)
        rng = numpy.random.RandomState(seed)
        n = len(self)

        # NOTE: the order of the rng calls below defines the sampled
        # sequence for a given seed; do not reorder them.

        # Peak ordering -- each peak appears exactly once. Kept as an
        # attribute for introspection.
        if self.shuffle:
            self.peak_ordering = rng.permutation(self.n_peaks)
        else:
            self.peak_ordering = numpy.arange(self.n_peaks)

        # Per-position label: True at exactly n_peaks slots, False at the
        # remaining negative slots.
        labels = rng.permutation(numpy.arange(n) < self.n_peaks)
        self._labels = labels
        negative_slots = ~labels

        # Per-position source index into the peak or negative arrays.
        # max(1, n_negatives) keeps randint's bounds valid when
        # n_negatives == 0; __len__ schedules no negative slots in that
        # case, so no values are written.
        source = numpy.empty(n, dtype=numpy.int64)
        source[labels] = self.peak_ordering
        source[negative_slots] = rng.randint(0, max(1, self.n_negatives),
                                             size=int(negative_slots.sum()))
        self._source_idx = source

        # Per-position jitter offset (0 at negative positions).
        if self.max_jitter > 0:
            jitters = rng.randint(0, self.max_jitter * 2, size=n)
            jitters[negative_slots] = 0
            self._jitters = jitters
        else:
            self._jitters = numpy.zeros(n, dtype=numpy.int64)

        # Per-position reverse-complement flag.
        if self.reverse_complement:
            self._rc_flags = rng.randint(0, 2, size=n).astype(bool)
        else:
            self._rc_flags = numpy.zeros(n, dtype=bool)

    def __getitem__(self, idx):
        """Return the example scheduled at position ``idx`` of this epoch.

        Returns
        -------
        tuple
            ``(X, y, label)`` or, when controls were provided,
            ``(X, X_ctl, y, label)``, where ``label`` is 1 for a peak and
            0 for a negative.
        """
        # An index going backwards means a new pass over the data started.
        if idx < self._last_idx:
            self._prepare_epoch(self._epoch + 1)
        self._last_idx = idx

        is_peak = bool(self._labels[idx])
        src = int(self._source_idx[idx])
        j = int(self._jitters[idx])

        if is_peak:
            X, y, X_ctl = (self.peak_sequences, self.peak_signals,
                           self.peak_controls)
        else:
            X, y, X_ctl = (self.negative_sequences, self.negative_signals,
                           self.negative_controls)

        Xi = torch.from_numpy(X[src][:, j:j+self.in_window])
        yi = torch.from_numpy(y[src][:, j:j+self.out_window])
        if self.peak_controls is not None:
            Xi_ctl = torch.from_numpy(X_ctl[src][:, j:j+self.in_window])
        else:
            Xi_ctl = None

        # Reverse-complement: flip both the channel axis and the length
        # axis of every returned tensor.
        if self._rc_flags[idx]:
            Xi = torch.flip(Xi, [0, 1])
            yi = torch.flip(yi, [0, 1])
            if Xi_ctl is not None:
                Xi_ctl = torch.flip(Xi_ctl, [0, 1])

        if Xi_ctl is not None:
            return Xi, Xi_ctl, yi, int(is_peak)
        return Xi, yi, int(is_peak)
[docs]
def PeakGenerator(peaks, negatives, sequences, signals, controls=None,
                  chroms=None, in_window=2114, out_window=1000, max_jitter=50,
                  negative_ratio=0.1, reverse_complement=True, shuffle=True,
                  min_counts=None, max_counts=None, summits=False,
                  exclusion_lists=None, random_state=None, pin_memory=True,
                  num_workers=1, batch_size=32, verbose=False):
    """This is a constructor function that handles all IO.

    This function will extract sequence, signal and (optionally) control
    data for the peak and negative loci, drop peaks with outlier counts,
    pass the result into a PeakNegativeSampler, and wrap that using a
    PyTorch data loader. This is the only function that needs to be used.

    Peaks whose total counts (summed over tasks and positions) exceed 1.2x
    the 99th percentile of all extracted peaks are treated as outliers and
    removed. Negatives are not outlier-filtered and are extracted without
    jitter padding.


    Parameters
    ----------
    peaks: str or pandas.DataFrame or list/tuple of such
        A BED-formatted file containing peak coordinates. This can be either
        the string path to the BED file or a pandas DataFrame object
        containing three columns: chrom, start, and end. Alternatively, this
        can be a list of such objects whose coordinates will be interleaved.

    negatives: str or pandas.DataFrame or list/tuple of such
        A BED-formatted file containing negative coordinates. This can be
        either the string path to the BED file or a pandas DataFrame object
        containing three columns: chrom, start, and end. Alternatively, this
        can be a list of such objects whose coordinates will be interleaved.

    sequences: str or dictionary
        Either the path to a fasta file to read from or a dictionary where
        the keys are the unique set of chromosomes and the values are
        one-hot encoded sequences as numpy arrays or memory maps.

    signals: list of strs or list of dictionaries
        A list of filepaths to bigwig files, where each filepath will be
        read using pyBigWig, or a list of dictionaries where the keys are
        the same set of unique chromosomes and the values are numpy arrays
        or memory maps.

    controls: list of strs or list of dictionaries or None, optional
        A list of filepaths to bigwig files, where each filepath will be
        read using pyBigWig, or a list of dictionaries where the keys are
        the same set of unique chromosomes and the values are numpy arrays
        or memory maps. If None, no control tensor is returned. Default is
        None.

    chroms: list or None, optional
        A set of chromosomes to extract loci from. Loci in other chromosomes
        in the locus file are ignored. If None, all loci are used. Default
        is None.

    in_window: int, optional
        The input window size. Default is 2114.

    out_window: int, optional
        The output window size. Default is 1000.

    max_jitter: int, optional
        The maximum amount of jitter to add, in either direction, to the
        midpoints that are passed in. Only applied to peaks. Default is 50.

    negative_ratio: float, optional
        The ratio of negatives compared to peaks. A value of 1 means that
        batches are balanced (on average), and a value of 10 means that
        there would be 10 negatives for each positive. Note that this is
        independent of the number of peaks and negatives provided. Even if
        the `peaks` input has 10x the number of coordinates as the
        `negatives` one, if the ratio is 1 each batch during training will
        be balanced (on average). Default is 0.1.

    reverse_complement: bool, optional
        Whether to reverse complement-augment half of the data. Default is
        True.

    shuffle: bool, optional
        Whether to randomly sample peaks, if True, or to proceed
        sequentially through them, if False. Negatives are always randomly
        sampled. Default is True.

    min_counts: float or None, optional
        The minimum number of counts, summed across the length of each
        example and across all tasks, needed to be kept. If None, no
        minimum. Default is None.

    max_counts: float or None, optional
        The maximum number of counts, summed across the length of each
        example and across all tasks, needed to be kept. If None, no
        maximum. Default is None.

    summits: bool, optional
        Whether to return a region centered around the summit instead of the
        center between the start and end. If True, it will add the 10th
        column (index 9) to the start to get the center of the window, and
        so the data must be in narrowPeak format. Only applied to peaks.
        Default is False.

    exclusion_lists: list or None, optional
        A list of strings of filenames to BED-formatted files containing
        exclusion lists, i.e., regions where overlapping loci should be
        filtered out. If None, no filtering is performed based on exclusion
        zones. Default is None.

    random_state: int or None, optional
        Base seed for the sampler's deterministic per-epoch RNG. If None,
        a random seed is chosen once when the sampler is constructed.

    pin_memory: bool, optional
        Whether to pin page memory to make data loading onto a GPU easier.
        Default is True.

    num_workers: int, optional
        The number of processes fetching data at a time to feed into a
        model. If 0, data is fetched from the main process. Workers are
        kept alive between epochs (``persistent_workers``) whenever
        ``num_workers > 0``. Default is 1.

    batch_size: int, optional
        The number of data elements per batch. Default is 32.

    verbose: bool, optional
        Whether to display a progress bar while loading and print the
        number of filtered loci. Default is False.


    Returns
    -------
    X: torch.utils.data.DataLoader
        A PyTorch DataLoader wrapping a PeakNegativeSampler object.
    """

    # Every uppercase letter except A, C, G and T; forwarded to
    # extract_loci as the characters to ignore.
    non_acgt = 'QWERYUIOPSDFHJKLZXVBNM'

    X_peaks = extract_loci(loci=peaks, sequences=sequences,
        signals=signals, in_signals=controls, chroms=chroms,
        in_window=in_window, out_window=out_window, max_jitter=max_jitter,
        min_counts=min_counts, max_counts=max_counts, summits=summits,
        exclusion_lists=exclusion_lists, ignore=list(non_acgt),
        return_mask=True, verbose=verbose)

    # Flag peaks whose total counts are far above the bulk of the data.
    # The per-locus totals are computed once and reused.
    loci_counts = X_peaks[1].sum(dim=(1, 2))
    outlier_threshold = torch.quantile(loci_counts, 0.99) * 1.2
    outlier_idxs = loci_counts > outlier_threshold
    keep_idxs = ~outlier_idxs

    # Negatives get no jitter padding and are never summit-centered.
    X_bg = extract_loci(loci=negatives, sequences=sequences,
        signals=signals, in_signals=controls, chroms=chroms,
        in_window=in_window, out_window=out_window, max_jitter=0,
        min_counts=min_counts, max_counts=max_counts, summits=False,
        exclusion_lists=exclusion_lists, ignore=list(non_acgt),
        return_mask=True, verbose=verbose)

    if verbose:
        # The last returned element is the keep-mask over the input loci.
        peak_mask, negative_mask = X_peaks[-1], X_bg[-1]
        n_filtered_peaks = int(len(peak_mask) - peak_mask.sum()
            + outlier_idxs.sum())
        n_filtered_negatives = int(len(negative_mask) - negative_mask.sum())

        print("\nFiltered Peaks: {}".format(n_filtered_peaks))
        print("Filtered Negatives: {}".format(n_filtered_negatives))

    has_controls = controls is not None
    X_gen = PeakNegativeSampler(
        peak_sequences=X_peaks[0][keep_idxs],
        peak_signals=X_peaks[1][keep_idxs],
        peak_controls=X_peaks[2][keep_idxs] if has_controls else None,
        negative_sequences=X_bg[0],
        negative_signals=X_bg[1],
        negative_controls=X_bg[2] if has_controls else None,
        negative_ratio=negative_ratio,
        in_window=in_window,
        out_window=out_window,
        max_jitter=max_jitter,
        reverse_complement=reverse_complement,
        shuffle=shuffle,
        random_state=random_state
    )

    # The sampler advances its epoch from inside each worker, so workers
    # must survive across epochs to keep that state.
    X_gen = torch.utils.data.DataLoader(X_gen, pin_memory=pin_memory,
        num_workers=num_workers, batch_size=batch_size,
        persistent_workers=num_workers > 0)

    return X_gen