Model

Cherimoya

class cherimoya.cherimoya.Cherimoya(*args, **kwargs)

Bases: Module

The Cherimoya sequence-to-function model.

Parameters:
  • n_filters (int, optional) – Width of the convolutional backbone (the channel dimension). Default is 96.

  • n_layers (int, optional) – Number of stacked Cheri Blocks. Block i uses dilation 2**i. Default is 9.

  • n_outputs (int, optional) – Number of output tracks for the profile head. Default is 1.

  • n_control_tracks (int, optional) – Number of control input tracks. If 0, the model takes only the one-hot sequence as input. Default is 0.

  • expansion (int, optional) – Channel-expansion factor for the MLP inside each Cheri Block. The inner projection maps n_filters -> expansion * n_filters and then back. Default is 2.

  • residual_scale (float, optional) – Fixed scalar applied to the MLP output of each Cheri Block before it is added back to the residual stream. Default is 0.15.

  • name (str or None, optional) – Display name used when saving model files. Defaults to "cherimoya.{n_filters}.{n_layers}".

  • trimming (int or None, optional) – Number of base pairs to trim from each side of the input when producing the output profile. If None, defaults to 46 + sum(2**i for i in range(n_layers)).

  • single_count_output (bool, optional) – If True, the count head returns a single scalar per example; otherwise it returns one count per output track. Default is True.

  • verbose (bool, optional) – Whether the training-progress logger prints to stdout. Default is True.
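The default name, the dilation schedule and the default trimming all follow directly from the parameters above. The plain-Python sketch below recomputes them. The helper names are hypothetical and are not part of the cherimoya API. The output-length helper assumes that exactly `trimming` positions are removed from each end of the input, as the `trimming` description states.

```python
def default_name(n_filters: int = 96, n_layers: int = 9) -> str:
    # Mirrors the documented default: "cherimoya.{n_filters}.{n_layers}"
    return f"cherimoya.{n_filters}.{n_layers}"


def dilations(n_layers: int = 9) -> list:
    # Block i uses dilation 2**i
    return [2 ** i for i in range(n_layers)]


def default_trimming(n_layers: int = 9) -> int:
    # Documented default when trimming=None: 46 + sum(2**i for i in range(n_layers))
    return 46 + sum(dilations(n_layers))


def output_length(input_length: int, trimming: int) -> int:
    # Assumption: `trimming` base pairs are removed from each side of the input
    return input_length - 2 * trimming


print(default_name())                           # cherimoya.96.9
print(dilations())                              # [1, 2, 4, 8, 16, 32, 64, 128, 256]
print(default_trimming())                       # 557
print(output_length(2114, default_trimming()))  # 1000
```

Under that assumption, a 2,114 bp input with the default settings yields a 1,000 bp output profile.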

Constructor

__init__(n_filters=96, n_layers=9, n_outputs=1, n_control_tracks=0, expansion=2, residual_scale=0.15, name=None, trimming=None, single_count_output=True, verbose=True)

forward(X, X_ctl=None)

A forward pass of the model.

This method takes in a nucleotide sequence X, optionally together with control data: a per-position value from a control track and a per-locus value from the same track. It then makes predictions for the profile and for the counts. The per-locus value is usually log(sum(X_ctl_profile) + 1) when the control is an experimental read track, but it can also be the output of another model.

Parameters:
  • X (torch.tensor, shape=(batch_size, 4, length)) – The one-hot encoded batch of sequences.

  • X_ctl (torch.tensor or None, shape=(batch_size, n_control_tracks, length)) – The control signal at each position in the sequence. If there are no controls, pass in None. Default is None.

Returns:

y_profile – The output predictions for each output track, trimmed to the output length.

Return type:

torch.tensor, shape=(batch_size, n_outputs, out_length)
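The per-locus control value described for forward, log(sum(X_ctl_profile) + 1), can be computed from a per-position control tensor as shown below. This is a sketch under two assumptions. First, the sum is taken over every control track and every position of each example. Second, `per_locus_control` is a hypothetical helper name. This section does not state whether Cherimoya computes the value internally from `X_ctl`.

```python
import torch


def per_locus_control(X_ctl: torch.Tensor) -> torch.Tensor:
    """Summarize a (batch_size, n_control_tracks, length) control signal
    as one value per example: log(sum(X_ctl_profile) + 1).

    Assumption: the sum runs over all tracks and all positions.
    """
    total = X_ctl.sum(dim=(1, 2))            # shape: (batch_size,)
    return torch.log1p(total).unsqueeze(-1)  # shape: (batch_size, 1)


X_ctl = torch.zeros(2, 2, 2114)
X_ctl[0, :, :10] = 1.0  # 20 reads in the first example, none in the second
print(per_locus_control(X_ctl))  # roughly log(21) = 3.0445 for the first row, 0.0 for the second
```

The `log1p` form keeps loci with zero control reads at exactly 0 rather than producing log(0).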

fit(training_data, muon_optimizer, adam_optimizer, muon_scheduler, adam_scheduler, X_valid, X_ctl_valid, y_valid, max_epochs=50, batch_size=64, dtype='float32', device='cuda', early_stopping=None)

Fit the model to data and validate it periodically.

This method controls the training of a Cherimoya model. It fits the model to examples generated by the training_data DataLoader. If validation data is provided, it evaluates the model on that data at the end of each epoch and returns those values.

Two versions of the model are saved using save(): the best model found during training according to the validation measures, and the final model at the end of training. A log of the training and validation statistics, such as time and performance, is also saved.
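The checkpointing behavior described above, together with the early_stopping parameter, can be summarized as a small control loop. This sketch shows control flow only. `train_one_epoch`, `validate` and `save` are hypothetical callables. They stand in for the real training step, the validation measure and save() respectively. The validation measure is assumed to be a loss, where lower is better, and the "best"/"final" labels are made up for illustration.

```python
from typing import Callable, List, Optional


def fit_loop(
    train_one_epoch: Callable[[], None],
    validate: Callable[[], float],
    save: Callable[[str], None],
    max_epochs: int = 50,
    early_stopping: Optional[int] = None,
) -> List[float]:
    """Sketch of fit()'s control flow: validate each epoch, save the best
    and final models, and optionally stop early."""
    history = []
    best_loss = float("inf")
    epochs_without_improvement = 0

    for _ in range(max_epochs):
        train_one_epoch()  # one full pass through the training data
        loss = validate()  # assumed: lower is better
        history.append(loss)

        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
            save("best")  # best model so far, by validation loss
        else:
            epochs_without_improvement += 1

        if early_stopping is not None and epochs_without_improvement >= early_stopping:
            break

    save("final")  # model at the end of training
    return history


# Toy run: the last improvement is in the second epoch, so early_stopping=3 ends training.
losses = iter([1.0, 0.8, 0.9, 0.85, 0.95, 0.7])
saved = []
history = fit_loop(lambda: None, lambda: next(losses), saved.append,
                   max_epochs=6, early_stopping=3)
print(history)  # [1.0, 0.8, 0.9, 0.85, 0.95]
print(saved)    # ['best', 'best', 'final']
```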

Parameters:
  • training_data (torch.utils.data.DataLoader) – A DataLoader that produces examples to train on. If n_control_tracks is greater than 0, it must produce two inputs (the sequence and the control); otherwise it must produce only one input (the sequence).

  • muon_optimizer (torch.optim.Optimizer) – A Muon optimizer that controls the training of the 2D layers outside the head and tail of the model. These are mostly the dense layers and depthwise convolutions of the Cheri Blocks.

  • adam_optimizer (torch.optim.Optimizer) – An Adam or AdamW optimizer that controls the training of the remaining parameters. These should be the head and tail layers, the bias terms, and any other parameters that are not 2D matrices.

  • muon_scheduler (torch.optim.lr_scheduler.LRScheduler) – The learning-rate scheduler for the Muon optimizer. This will usually be a cosine decay with a warmup phase.

  • adam_scheduler (torch.optim.lr_scheduler.LRScheduler) – The learning-rate scheduler for the Adam or AdamW optimizer. This will usually be the same cosine decay with warmup used for the Muon optimizer.

  • X_valid (torch.tensor, shape=(n, 4, length)) – A block of sequences to validate on at the end of each epoch.

  • X_ctl_valid (torch.tensor or None, shape=(n, n_control_tracks, length)) – A block of control signals used when making the validation set predictions at the end of each epoch. If n_control_tracks is 0, pass in None.

  • y_valid (torch.tensor or None, shape=(n, n_outputs, output_length)) – A block of signals to validate against at the end of each epoch.

  • max_epochs (int, optional) – The maximum number of epochs to train for, where one epoch is one full pass through training_data. Default is 50.

  • batch_size (int, optional) – The number of examples to include in each batch. Default is 64.

  • dtype (str or torch.dtype, optional) – The dtype to use during training, usually torch.float32 or torch.bfloat16. Default is 'float32'.

  • device (str, optional) – The device to use for training and inference. This is typically 'cuda', but any device supported by torch can be used. Default is 'cuda'.

  • early_stopping (int or None, optional) – The number of consecutive epochs without improvement in validation performance after which training stops. If None, training continues until max_epochs is reached. Default is None.
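The split of parameters between the two optimizers can be sketched with standard PyTorch pieces, and the same applies to the suggested warmup-plus-cosine schedule. This sketch rests on several assumptions:

- `split_parameters` and `warmup_cosine` are hypothetical helpers.
- Head and tail layers are identified by hypothetical name prefixes (`stem`, `head`), because Cherimoya's attribute names are not documented here.
- Any parameter with two or more dimensions outside the head and tail counts as matrix-like. Under this rule, depthwise convolution weights (shape (C, 1, 3) in PyTorch) are routed to Muon.

No Muon optimizer is constructed below. You would pass `muon_params` to whichever Muon implementation you use, together with the same kind of `LambdaLR` schedule.

```python
import math

import torch
from torch import nn


def split_parameters(model: nn.Module, head_tail_prefixes: tuple):
    """Matrix-like weights (ndim >= 2) outside the head/tail go to Muon;
    biases, 1D parameters, and head/tail weights go to Adam/AdamW."""
    muon_params, adam_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        in_head_or_tail = name.startswith(head_tail_prefixes)
        if param.ndim >= 2 and not in_head_or_tail:
            muon_params.append(param)
        else:
            adam_params.append(param)
    return muon_params, adam_params


def warmup_cosine(warmup_steps: int, total_steps: int):
    """LR multiplier for LambdaLR: linear warmup, then cosine decay to 0."""
    def factor(step: int) -> float:
        if step < warmup_steps:
            return (step + 1) / warmup_steps
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
    return factor


# Toy stand-in with a stem ("tail"), a body layer, and a head (not Cherimoya).
model = nn.ModuleDict({
    "stem": nn.Conv1d(4, 16, kernel_size=21, padding=10),
    "body": nn.Linear(16, 16),
    "head": nn.Linear(16, 1),
})
muon_params, adam_params = split_parameters(model, ("stem", "head"))

adam_optimizer = torch.optim.AdamW(adam_params, lr=1e-3)
adam_scheduler = torch.optim.lr_scheduler.LambdaLR(
    adam_optimizer, lr_lambda=warmup_cosine(warmup_steps=100, total_steps=1000)
)
print(len(muon_params), len(adam_params))  # 1 5
```

Because the multiplier is indexed by step, this sketch expects `scheduler.step()` to be called once per optimizer step rather than once per epoch.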

CheriBlock

class cherimoya.cherimoya.CheriBlock(*args, **kwargs)

Bases: Module

A single Cheri Block.

The Cheri Block is the core building block of the Cherimoya model. It adapts the ConvNeXt block to noisy genomics data, with the goal of mixing spatial and channel information cheaply while remaining stable to train.

The block performs the following operations on an input of shape (N, L, C):

  1. A 3-tap depthwise dilated convolution that mixes spatial information independently for each channel.

  2. A per-example layer normalization across the (length, channel) plane. The convolution and normalization are fused into one kernel.

  3. A pointwise expansion linear projection from C to expansion * C channels.

  4. A GELU non-linearity.

  5. A pointwise contraction linear projection back to C channels.

  6. A residual connection where the MLP output is scaled by a fixed constant (residual_scale) before being added to the input. The small constant keeps the residual path near-identity at initialization, which stabilizes training of deep stacks.
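The six steps can be written as an unfused PyTorch module, as in the hedged reference sketch below. It is not the library's implementation, because the real block fuses steps 1 and 2 into a single kernel. Several details are not specified in the text and are assumptions here: whether the convolution has a bias, whether the layer norm has learnable affine parameters, and which GELU approximation is used. `CheriBlockSketch` is a hypothetical name.

```python
import torch
import torch.nn.functional as F
from torch import nn


class CheriBlockSketch(nn.Module):
    """Unfused reference sketch of the six steps (not the real CheriBlock)."""

    def __init__(self, n_filters, dilation, expansion=2, residual_scale=0.15):
        super().__init__()
        # Step 1: 3-tap depthwise dilated conv; padding=dilation preserves length
        self.conv = nn.Conv1d(n_filters, n_filters, kernel_size=3,
                              dilation=dilation, padding=dilation, groups=n_filters)
        # Steps 3 and 5: pointwise expansion and contraction
        self.expand = nn.Linear(n_filters, expansion * n_filters)
        self.contract = nn.Linear(expansion * n_filters, n_filters)
        self.residual_scale = residual_scale

    def forward(self, X):  # X: (N, L, C)
        # Conv1d expects (N, C, L), so transpose around the convolution
        h = self.conv(X.transpose(1, 2)).transpose(1, 2)
        # Step 2: per-example layer norm over the (length, channel) plane
        h = F.layer_norm(h, h.shape[1:])
        # Steps 3-5: expand, GELU, contract
        h = self.contract(F.gelu(self.expand(h)))
        # Step 6: scaled residual connection
        return X + self.residual_scale * h


block = CheriBlockSketch(n_filters=8, dilation=4)
X = torch.randn(2, 100, 8)  # (N, L, C)
print(block(X).shape)  # torch.Size([2, 100, 8])
```

Because the input is channels-last, the pointwise projections are plain `nn.Linear` layers, and only the depthwise convolution needs a transpose.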

Parameters:
  • n_filters (int) – The number of channels (the C dimension).

  • dilation (int) – Dilation rate for the depthwise convolution. The kernel reads from positions (i - dilation, i, i + dilation) at each output position i, with zero padding outside the sequence.

  • expansion (int, optional) – The factor by which the inner MLP expands the channel dimension. The first projection maps n_filters -> expansion * n_filters and the second projects back. Default is 2.

  • residual_scale (float, optional) – Fixed scalar applied to the MLP output before it is added back to the residual stream. Default is 0.15.
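The tap pattern described for dilation can be checked with a plain-Python reference for a single channel. The sketch also computes how far a stack of such convolutions reaches when block i uses dilation 2**i. Both helper names are hypothetical. The receptive-field count covers only the dilated convolutions, because the pointwise layers do not widen it.

```python
def depthwise_dilated_conv(x, weights, dilation):
    """3-tap dilated convolution on one channel.

    Output position i reads x[i - dilation], x[i], x[i + dilation];
    positions outside the sequence contribute zero (zero padding).
    """
    out = []
    for i in range(len(x)):
        acc = 0.0
        for w, offset in zip(weights, (-dilation, 0, dilation)):
            j = i + offset
            if 0 <= j < len(x):
                acc += w * x[j]
        out.append(acc)
    return out


def stack_receptive_field(n_layers):
    """Receptive field of n stacked 3-tap convolutions with dilations 2**i."""
    # Each block extends the reach by its dilation on each side
    return 1 + 2 * sum(2 ** i for i in range(n_layers))


x = [1.0, 2.0, 3.0, 4.0, 5.0]
print(depthwise_dilated_conv(x, [1.0, 0.0, 0.0], dilation=2))  # [0.0, 0.0, 1.0, 2.0, 3.0]
print(stack_receptive_field(9))  # 1023
```

With nine blocks, the one-sided reach of the dilated stack is 511 positions. This equals the sum term in the Cherimoya class's default trimming formula.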

forward(X)

Run the block on an input of shape (N, L, C).

FusedDilatedConvNormFunc