# cherimoya.py
# Author: Jacob Schreiber <jmschreiber91@gmail.com>
"""
An implementation of the Cherimoya deep learning model, a compact
architecture for predicting genomic modalities from sequence alone.
"""
import time
import numpy
import torch
from .cheri import CheriBlock
from .losses import _mixture_loss
from .performance import calculate_performance_measures
from tangermeme.predict import predict
from bpnetlite.logging import Logger
# Module-level side effect: merely importing this file changes the global
# float32 matmul precision setting for the whole process. Per the PyTorch
# documentation, 'high' permits faster reduced-precision internals (e.g.
# TF32 on GPUs that support it) in exchange for slightly lower precision
# than the default 'highest'.
torch.set_float32_matmul_precision('high')
class EMA:
    """Exponential moving average of a model's parameters.

    Maintains a shadow copy of every trainable floating-point parameter
    that is updated as ``shadow = decay * shadow + (1 - decay) * parameter``
    after each training step. The shadow weights are typically used at
    evaluation time, where they tend to produce smoother and more stable
    predictions than the raw running weights.

    Typical usage during training:

    1. Create an EMA wrapper after the model is constructed.
    2. Call :meth:`update` after every optimizer step.
    3. Call :meth:`apply_shadow` before evaluation to swap the shadow
       weights into the model.
    4. Call :meth:`restore` after evaluation to put the training weights
       back.

    Parameters
    ----------
    model: torch.nn.Module
        The model whose parameters will be tracked.

    decay: float, optional
        The decay factor of the moving average. Larger values place more
        weight on the running shadow and less on each new update. Default
        is 0.999.

    Attributes
    ----------
    shadow: dict
        Maps parameter name to the averaged tensor. Only parameters that
        are floating point and have ``requires_grad=True`` when the
        wrapper is created are tracked; they stay tracked even if they
        are frozen later.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay

        # Training weights parked here while the shadow weights are swapped
        # in; an empty dict means "nothing is currently swapped".
        self._backup = {}

        # Detached clones so the shadow never shares storage (or autograd
        # history) with the live parameters.
        self.shadow = {
            name: p.detach().clone()
            for name, p in model.named_parameters()
            if p.requires_grad and p.is_floating_point()
        }

    @torch.no_grad()
    def update(self, model):
        """Update the shadow weights using the current model parameters.

        Parameters
        ----------
        model: torch.nn.Module
            The model being trained. Parameters whose names are not in
            ``shadow`` are ignored.
        """

        d = self.decay
        for name, p in model.named_parameters():
            if name in self.shadow:
                # In place: shadow <- d * shadow + (1 - d) * p
                self.shadow[name].mul_(d).add_(p.detach(), alpha=1.0 - d)

    @torch.no_grad()
    def apply_shadow(self, model):
        """Swap the model's parameters with the shadow weights.

        The original weights are kept in an internal backup so they can
        be restored after evaluation. Calling this method twice in a row
        without an intervening :meth:`restore` is an error.

        Parameters
        ----------
        model: torch.nn.Module
            The model to copy the shadow weights into.

        Raises
        ------
        AssertionError
            If a backup from a previous call is still pending.
        """

        # Raised explicitly rather than via `assert` so the guard survives
        # `python -O`. Without it a second call would overwrite the backup
        # with shadow weights and the training weights would be lost.
        if self._backup:
            raise AssertionError("apply_shadow() was called again before "
                "restore(); the training weights are still backed up.")

        for name, p in model.named_parameters():
            if name in self.shadow:
                self._backup[name] = p.detach().clone()
                p.data.copy_(self.shadow[name].data)

    @torch.no_grad()
    def restore(self, model):
        """Put the original training weights back into the model.

        Does nothing to the parameters when no backup is pending. The
        backup is cleared afterwards so :meth:`apply_shadow` can be called
        again.

        Parameters
        ----------
        model: torch.nn.Module
            The model to copy the backed-up training weights into.
        """

        for name, p in model.named_parameters():
            if name in self._backup:
                p.data.copy_(self._backup[name].data)

        self._backup = {}
class Cherimoya(torch.nn.Module):
    """The Cherimoya sequence-to-function model.

    The network applies a wide initial convolution and GELU to the one-hot
    sequence, passes the result through a stack of Cheri Blocks with
    exponentially growing dilation, and then makes two predictions: a
    per-position profile from a width-1 convolution and a count from a
    linear layer over the position-averaged features. Two learnable
    scalars, ``lw0`` and ``lw1``, weight the profile and count losses
    during :meth:`fit`.

    Parameters
    ----------
    n_filters: int, optional
        Width of the convolutional backbone (the channel dimension).
        Default is 96.

    n_layers: int, optional
        Number of stacked Cheri Blocks. Block ``i`` uses dilation
        ``2**i``. Default is 9.

    n_outputs: int, optional
        Number of output tracks for the profile head. Default is 1.

    n_control_tracks: int, optional
        Number of control input tracks. If 0, the model takes only the
        one-hot sequence as input. Default is 0.

    expansion: int, optional
        Channel-expansion factor for the MLP inside each Cheri Block. The
        inner projection maps ``n_filters -> expansion * n_filters`` and
        then back. Default is 2.

    residual_scale: float, optional
        Fixed scalar applied to the MLP output of each Cheri Block before
        it is added back to the residual stream. Default is 0.15.

    name: str or None, optional
        Display name used when saving model files. Defaults to
        ``"cherimoya.{n_filters}.{n_layers}"``.

    trimming: int or None, optional
        Number of base pairs to trim from each side of the input when
        producing the output profile. If None, defaults to
        ``46 + sum(2**i for i in range(n_layers))``.

    single_count_output: bool, optional
        If True, the count head returns a single scalar per example;
        otherwise it returns one count per output track. Default is True.

    verbose: bool, optional
        Whether the training-progress logger prints to stdout. Default
        is True.
    """

    def __init__(self, n_filters=96, n_layers=9, n_outputs=1,
        n_control_tracks=0, expansion=2, residual_scale=0.15, name=None,
        trimming=None, single_count_output=True, verbose=True):
        super().__init__()

        self.n_filters = n_filters
        self.n_layers = n_layers
        self.n_outputs = n_outputs
        self.n_control_tracks = n_control_tracks
        self.expansion = expansion
        self.residual_scale = residual_scale
        self.single_count_output = single_count_output

        self.name = name or "cherimoya.{}.{}".format(n_filters, n_layers)
        self.trimming = trimming if trimming is not None else (
            46 + sum(2**i for i in range(n_layers)))

        # Layers are created in a fixed order so that seeded initialization
        # stays reproducible.
        self.iconv = torch.nn.Conv1d(4, n_filters, kernel_size=21, padding=10)
        self.igelu = torch.nn.GELU(approximate='tanh')

        self.blocks = torch.nn.ModuleList([
            CheriBlock(n_filters, 2**i, expansion=expansion,
                residual_scale=residual_scale)
            for i in range(self.n_layers)
        ])

        # Profile head: the control tracks are concatenated onto the
        # backbone channels before this width-1 convolution.
        self.fconv = torch.nn.Conv1d(n_filters+n_control_tracks, n_outputs,
            kernel_size=1, padding=0)

        # Learnable loss weights; `fit` turns each into 1 / (2 * lw ** 2).
        self.lw0 = torch.nn.Parameter(torch.ones(1))
        self.lw1 = torch.nn.Parameter(torch.ones(1))

        # Count head: all control tracks are summed into a single scalar in
        # `forward`, hence at most one extra input feature.
        n_count_control = 1 if n_control_tracks > 0 else 0
        n_count_outputs = 1 if single_count_output else n_outputs
        self.linear = torch.nn.Linear(n_filters+n_count_control, n_count_outputs)

        torch.nn.init.trunc_normal_(self.iconv.weight, std=0.02)
        torch.nn.init.trunc_normal_(self.fconv.weight, std=0.02)
        torch.nn.init.trunc_normal_(self.linear.weight, std=0.02)
        torch.nn.init.zeros_(self.iconv.bias)
        torch.nn.init.zeros_(self.fconv.bias)
        torch.nn.init.zeros_(self.linear.bias)

        self.logger = Logger(["Epoch", "Iteration", "Training Time",
            "Validation Time", "Training MNLL", "Training Count MSE",
            "Validation MNLL", "Validation Profile Pearson",
            "Validation Count Pearson", "Validation Count MSE", "Saved?"],
            verbose=verbose)

    def _init_kwargs(self):
        """Return the kwargs needed to reconstruct this model.

        Returns
        -------
        kwargs: dict
            The constructor arguments of this instance. ``verbose`` is
            always stored as False, so a reloaded model has a quiet logger.
        """

        return {
            'n_filters': self.n_filters,
            'n_layers': self.n_layers,
            'n_outputs': self.n_outputs,
            'n_control_tracks': self.n_control_tracks,
            'expansion': self.expansion,
            'residual_scale': self.residual_scale,
            'name': self.name,
            'trimming': self.trimming,
            'single_count_output': self.single_count_output,
            'verbose': False,
        }

    def save(self, path):
        """Save the model to a file.

        The checkpoint stores the constructor arguments needed to rebuild
        the model along with its parameter state dict. This format can be
        loaded with ``weights_only=True`` and is robust to changes in
        source layout.

        Parameters
        ----------
        path: str
            The destination file path.
        """

        payload = {
            'config': self._init_kwargs(),
            'state_dict': self.state_dict(),
        }

        torch.save(payload, path)

    @classmethod
    def load(cls, path, device='cpu'):
        """Load a model previously saved with :meth:`save`.

        Parameters
        ----------
        path: str
            The checkpoint file path.

        device: str or torch.device, optional
            Device to map the parameters onto. Default is ``'cpu'``.

        Returns
        -------
        model: Cherimoya
            The reconstructed model, placed on ``device``.
        """

        payload = torch.load(path, map_location=device, weights_only=True)
        model = cls(**payload['config'])
        model.load_state_dict(payload['state_dict'])
        return model.to(device)

    @torch.compile(mode='max-autotune')
    def forward(self, X, X_ctl=None):
        """A forward pass of the model.

        This method takes in a nucleotide sequence X and, optionally, a
        corresponding per-position control track, and makes predictions
        for the profile and for the counts. The per-locus control value
        given to the count head is derived here as
        ``log(sum(X_ctl) + 1)`` over the trimmed output window and all
        control tracks.

        Parameters
        ----------
        X: torch.tensor, shape=(batch_size, 4, length)
            The one-hot encoded batch of sequences.

        X_ctl: torch.tensor or None, shape=(batch_size, n_control_tracks, length)
            A value representing the signal of the control at each position in
            the sequence. If no controls, pass in None. Default is None.

        Returns
        -------
        y_profile: torch.tensor, shape=(batch_size, n_outputs, out_length)
            The profile predictions for each output track, trimmed to
            ``out_length = length - 2 * trimming``.

        y_counts: torch.tensor, shape=(batch_size, n_count_outputs)
            The count predictions, where ``n_count_outputs`` is 1 if
            ``single_count_output`` is True and ``n_outputs`` otherwise.
        """

        start, end = self.trimming, X.shape[2] - self.trimming

        X = self.igelu(self.iconv(X))

        # The blocks are fed (batch, length, channels) and the result is
        # transposed back to (batch, channels, length) for the heads.
        X = X.transpose(1, 2).contiguous()
        for block in self.blocks:
            X = block(X)
        X = X.transpose(1, 2).contiguous()

        if X_ctl is None:
            X_w_ctl = X
        else:
            X_w_ctl = torch.cat([X, X_ctl], dim=1)

        y_profile = self.fconv(X_w_ctl)[:, :, start:end]

        # counts prediction: average the features over the output window,
        # always in float32 regardless of the autocast dtype
        X = torch.mean(X[:, :, start:end].float(), dim=2)
        if X_ctl is not None:
            X_ctl = torch.sum(X_ctl[:, :, start:end].float(), dim=(1, 2))
            X_ctl = X_ctl.unsqueeze(-1)
            X = torch.cat([X, torch.log(X_ctl+1)], dim=-1)

        y_counts = self.linear(X)
        return y_profile, y_counts

    def fit(self, training_data, muon_optimizer, adam_optimizer, muon_scheduler,
        adam_scheduler, X_valid, X_ctl_valid, y_valid, max_epochs=50, batch_size=64,
        dtype='float32', device='cuda', early_stopping=None):
        """Fit the model to data and validate it periodically.

        This method controls the training of a Cherimoya model. It will fit
        the model to examples generated by the `training_data` DataLoader
        object and will validate the model against the validation data at
        the end of each epoch, using an exponential moving average of the
        weights for the validation predictions.

        Two versions of the model will be saved using :meth:`save`: the best
        model found during training according to the validation count
        Pearson correlation, and the final model at the end of training.
        Both are saved with the moving-average weights, and the model is
        left holding those weights when this method returns. Additionally,
        a log will be saved of the training and validation statistics,
        e.g. time and performance.

        Parameters
        ----------
        training_data: torch.utils.data.DataLoader
            A generator that produces examples to train on. Each batch is
            indexed as ``(X, y, labels)`` or, when it has four elements,
            as ``(X, X_ctl, y, labels)``. Batches whose size differs from
            `batch_size` are skipped.

        muon_optimizer: torch.optim.Optimizer
            A Muon optimizer to control the training of the 2D non-head/non-tail layers
            in the model. This is mostly the dense layers and depth-wise convolutions of
            the Cheri blocks.

        adam_optimizer: torch.optim.Optimizer
            An Adam/W optimizer to control the training of the other parameters. This
            should be the head/tail layers, the bias terms, and any other parameters
            that are not 2D matrices.

        muon_scheduler: torch.optim.lr_scheduler
            The scheduler to use for the Muon optimizer. This should likely be a cosine
            decay with a warmup phase. It is stepped once per batch.

        adam_scheduler: torch.optim.lr_scheduler
            The scheduler to use for the Adam/W optimizer. This should likely be the
            same cosine decay with a warmup phase used for the Muon optimizer. It is
            stepped once per batch.

        X_valid: torch.tensor, shape=(n, 4, length)
            A block of sequences to validate on at the end of each epoch.

        X_ctl_valid: torch.tensor or None, shape=(n, n_control_tracks, length)
            A block of control sequences to use for making the validation set
            predictions at the end of each epoch. If n_control_tracks is 0, pass in
            None.

        y_valid: torch.tensor, shape=(n, n_outputs, output_length)
            A block of signals to validate against at the end of each epoch.

        max_epochs: int
            The maximum number of epochs to train for, as measured by the
            number of times that `training_data` is exhausted. Default is 50.

        batch_size: int, optional
            The number of examples to include in each batch. Default is 64.

        dtype: str or torch.dtype
            The torch.dtype to use when training, given either as a dtype or as
            its name. Usually, this will be torch.float32 or torch.bfloat16.
            Default is 'float32'.

        device: str or torch.device
            The device to use for training and inference. Typically, this will be
            'cuda' but can be anything supported by torch, including an indexed
            device such as 'cuda:1'. Default is 'cuda'.

        early_stopping: int or None, optional
            Whether to stop training early. If None, continue training until
            max_epochs is reached. If an integer, continue training until that
            number of epochs has been hit without improvement in performance.
            Default is None.

        Raises
        ------
        ValueError
            If no batch of size `batch_size` has been produced by the time
            the first validation would run.
        """

        if X_ctl_valid is not None:
            X_ctl_valid = (X_ctl_valid,)

        dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

        # torch.autocast needs the bare device type ('cuda'), not an indexed
        # string such as 'cuda:0' or a torch.device object.
        device_type = torch.device(device).type

        iteration = 0
        early_stop_count = 0
        best_corr = float("-inf")
        profile_loss, count_loss = None, None

        self.logger.start()
        ema = EMA(self, decay=0.999)

        for epoch in range(max_epochs):
            tic = time.time()

            for data in training_data:
                X, y = data[0], data[-2]
                if X.shape[0] != batch_size:
                    continue

                X_ctl = data[1].to(device) if len(data) == 4 else None
                X = X.to(device).float()
                y = y.to(device)

                # Clear the optimizers and set the model to training mode
                muon_optimizer.zero_grad()
                adam_optimizer.zero_grad()
                self.train()

                # Make one training step
                with torch.autocast(device_type=device_type, dtype=dtype):
                    y_hat_logits, y_hat_logcounts = self(X, X_ctl)

                    profile_loss, count_loss = _mixture_loss(y,
                        y_hat_logits.float(), y_hat_logcounts.float())

                w0 = 1.0 / (2.0 * self.lw0 ** 2)
                w1 = 1.0 / (2.0 * self.lw1 ** 2)
                loss = w0*profile_loss + w1*count_loss

                # While the loss weights are still being learned, penalize
                # their squared log so they are pulled toward 1.
                if self.lw0.requires_grad:
                    loss = loss + torch.sum(torch.log(self.lw0) ** 2
                        + torch.log(self.lw1) ** 2)

                loss.backward()

                muon_optimizer.step()
                adam_optimizer.step()
                muon_scheduler.step()
                adam_scheduler.step()

                ema.update(self)
                iteration += 1

            train_time = time.time() - tic

            if profile_loss is None:
                raise ValueError("training_data produced no batches with "
                    "batch_size={}, so no training step could be "
                    "made.".format(batch_size))

            # Freeze both loss weights once the gradient on the profile
            # weight has become small.
            if self.lw0.requires_grad and torch.abs(self.lw0.grad).sum() < 1:
                self.lw0.requires_grad = False
                self.lw1.requires_grad = False

            # Validate the model at the end of the epoch
            with torch.no_grad():
                self.eval()
                ema.apply_shadow(self)

                # The training weights must come back even if validation,
                # logging, or saving fails part-way through.
                try:
                    tic = time.time()
                    y_hat_logits, y_hat_logcounts = predict(self, X_valid,
                        args=X_ctl_valid, batch_size=batch_size, dtype=dtype,
                        device=device)

                    valid_profile_loss, valid_count_loss = _mixture_loss(
                        y_valid, y_hat_logits, y_hat_logcounts)

                    measures = calculate_performance_measures(y_hat_logits,
                        y_valid, y_hat_logcounts,
                        measures=['profile_pearson', 'count_pearson'])

                    valid_profile_corr = numpy.nan_to_num(
                        measures['profile_pearson'])
                    valid_count_corr = numpy.nan_to_num(
                        measures['count_pearson']).mean()
                    valid_time = time.time() - tic

                    improved = valid_count_corr > best_corr

                    self.logger.add([epoch,
                        iteration,
                        train_time,
                        valid_time,
                        profile_loss.item(),
                        count_loss.item(),
                        valid_profile_loss.item(),
                        valid_profile_corr.mean(),
                        valid_count_corr,
                        valid_count_loss.item(),
                        improved.item()])

                    self.logger.save("{}.log".format(self.name))

                    # The checkpoint is written while the shadow weights are
                    # in place, so the "best" model holds averaged weights.
                    if improved:
                        self.save("{}.torch".format(self.name))
                        best_corr = valid_count_corr
                        early_stop_count = 0
                    else:
                        early_stop_count += 1
                finally:
                    ema.restore(self)

            if early_stopping is not None and early_stop_count >= early_stopping:
                break

        # The final checkpoint also holds the averaged weights, and they are
        # deliberately left in the model afterwards.
        ema.apply_shadow(self)
        self.save("{}.final.torch".format(self.name))