cherimoya.cheri
The Cheri Block — Cherimoya’s adaptation of the ConvNeXt block — plus the fused dilated-conv-plus-layer-norm dispatcher used inside it. For the high-level model see cherimoya.cherimoya; for the architectural story behind these primitives see Architecture.
CheriBlock
- class cherimoya.cheri.CheriBlock(*args, **kwargs)
Bases: Module
A single Cheri Block.
The Cheri Block is the core building block of the Cherimoya model. It adapts the ConvNeXt block to noisy genomics data, with the goal of mixing spatial and channel information cheaply while remaining stable to train.
The block performs the following operations on an input of shape (N, L, C):
A 3-tap depthwise dilated convolution that mixes spatial information independently for each channel.
A per-example layer normalization across the (length, channel) plane. The convolution and normalization are fused into a single operation, fused_dilated_conv_norm.
A pointwise expansion linear projection from C to expansion * C channels.
A GELU non-linearity.
A pointwise contraction linear projection back to C channels.
A residual connection where the MLP output is scaled by a fixed constant (residual_scale) before being added to the input. The small constant keeps the residual path near-identity at initialization, which stabilizes training of deep stacks.
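To make the MLP and residual steps concrete, here is a minimal pure-Python sketch for a single example. It assumes the conv-plus-norm output (as returned by fused_dilated_conv_norm) is already available as y_norm, that both projections are bias-free, and that GELU is the exact erf form; none of these details are confirmed on this page, and mlp_residual is a hypothetical name. Nested lists stand in for tensors so the arithmetic is easy to trace.

```python
import math


def gelu(v: float) -> float:
    # Exact erf-based GELU (assumed; a tanh approximation would differ slightly).
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def mlp_residual(x, y_norm, w_up, w_down, residual_scale=0.15):
    """Expansion -> GELU -> contraction -> scaled residual, for one example.

    x, y_norm: L rows of C floats (block input and normalized conv output).
    w_up: C rows of H floats, H = expansion * C; w_down: H rows of C floats.
    Bias-free projections are an assumption of this sketch.
    """
    C, H = len(x[0]), len(w_up[0])
    out = []
    for row_x, row_y in zip(x, y_norm):
        # Pointwise expansion C -> H, applied independently at every position.
        hidden = [gelu(sum(row_y[c] * w_up[c][h] for c in range(C))) for h in range(H)]
        # Pointwise contraction H -> C.
        mlp = [sum(hidden[h] * w_down[h][c] for h in range(H)) for c in range(C)]
        # Residual: the block input is added back, the MLP branch is scaled down.
        out.append([row_x[c] + residual_scale * mlp[c] for c in range(C)])
    return out
```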
Forward dispatch
The block selects one of three implementations per forward call:
CPU input → pure-PyTorch path (always; differentiable reference).
CUDA input + torch.is_grad_enabled() → training Triton path (FusedDilatedConvNormFunc for conv+norm, PyTorch MLP).
CUDA input + no_grad + expansion * n_filters % 16 == 0 → fully fused inference megakernel (~2x faster, bf16 MLP dot products).
A CUDA call under no_grad whose hidden width (expansion * n_filters) is not a multiple of 16 falls back to the training path, so existing model configurations keep working unchanged.
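The three dispatch rules can be summarized as a small decision function. This is a hypothetical sketch (select_cheri_path is not part of cherimoya); in the real block the two flags would presumably come from x.is_cuda and torch.is_grad_enabled().

```python
def select_cheri_path(is_cuda: bool, grad_enabled: bool, expansion: int, n_filters: int) -> str:
    """Hypothetical mirror of the CheriBlock dispatch rules, not the library's code."""
    if not is_cuda:
        # CPU input always takes the differentiable pure-PyTorch reference.
        return "pytorch"
    if not grad_enabled and (expansion * n_filters) % 16 == 0:
        # no_grad plus a 16-aligned hidden width unlocks the fused megakernel.
        return "inference_megakernel"
    # Every other CUDA call uses the training Triton path.
    return "training_triton"
```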
Existing trained checkpoints are bit-compatible: the parameter layout, init order, and forward semantics of the training path are unchanged. The inference megakernel produces outputs that differ from the training path by at most ~1e-5 max-abs at unit-scale outputs; this drift comes from bf16 weight casts in the MLP and is the precision/speed tradeoff documented in _cast_weights.
- Parameters:
n_filters (int) – The number of channels (the C dimension).
dilation (int) – Dilation rate for the depthwise convolution. At each output position i the kernel reads positions (i - dilation, i, i + dilation), with zero padding outside the sequence.
expansion (int, optional) – The factor by which the inner MLP expands the channel dimension. The first projection maps n_filters -> expansion * n_filters and the second projects back. Default is 2.
residual_scale (float, optional) – Fixed scalar applied to the MLP output before it is added back to the residual stream. Default is 0.15.
fused_dilated_conv_norm
- cherimoya.cheri.fused_dilated_conv_norm(x, w, dilation)
Fused 3-tap dilated depthwise conv plus per-example layer norm.
Dispatches to the Triton kernel when x is a CUDA tensor and Triton is available, otherwise to a pure-PyTorch fallback. The two implementations are numerically equivalent up to floating-point error.
- Parameters:
x (torch.Tensor, shape=(N, L, C)) – The input tensor.
w (torch.Tensor, shape=(3, C)) – Depthwise convolution weights for the left, center, and right taps.
dilation (int) – Spacing between the three taps.
- Returns:
y – The convolved and normalized output.
- Return type:
torch.Tensor, shape=(N, L, C)
Public dispatcher used inside CheriBlock. Routes to the
training Triton kernel on CUDA when gradients are enabled, otherwise
to the pure-PyTorch fallback (_cheri_conv_norm_cpu). Numerically
equivalent up to floating-point error.
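A pure-Python reference for what fused_dilated_conv_norm computes may help when checking shapes and tap order. It is a sketch, not the library's _cheri_conv_norm_cpu: it assumes no learnable affine parameters after normalization and borrows eps=1e-3 from the forward-kernel description of FusedDilatedConvNormFunc; dilated_conv_norm_ref is a hypothetical name.

```python
import math


def dilated_conv_norm_ref(x, w, dilation, eps=1e-3):
    """Hypothetical stand-in for fused_dilated_conv_norm.

    x: N examples, each L rows of C floats.
    w: 3 rows (left, center, right taps) of C floats.
    """
    out = []
    for ex in x:
        L, C = len(ex), len(ex[0])

        def tap(i, c):
            # Zero padding outside the sequence.
            return ex[i][c] if 0 <= i < L else 0.0

        # Depthwise: channel c only ever sees w[*][c].
        conv = [[tap(i - dilation, c) * w[0][c] + tap(i, c) * w[1][c] + tap(i + dilation, c) * w[2][c]
                 for c in range(C)] for i in range(L)]
        # Per-example statistics over the whole (L, C) plane.
        count = L * C
        mean = sum(map(sum, conv)) / count
        var = sum((v - mean) ** 2 for row in conv for v in row) / count
        rstd = 1.0 / math.sqrt(var + eps)
        out.append([[(v - mean) * rstd for v in row] for row in conv])
    return out
```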
FusedDilatedConvNormFunc
- class cherimoya.cheri.FusedDilatedConvNormFunc(*args, **kwargs)
Triton-backed fused dilated convolution + per-example layer norm.
Implements the same operation as _cheri_conv_norm_cpu() but uses custom Triton kernels to fuse the convolution, statistics reduction, and normalization steps into a small number of GPU passes. Only callable on CUDA tensors.
Note: the first call on a given (C, L) shape triggers Triton autotune, and the user-visible backward output from that very first call is contaminated by atomic-add residue from the benchmarking trials (on one shape we measured a deviation of ~7e-2 from CPU autograd). Every subsequent call uses the locked-in best config and agrees with CPU autograd at fp32 precision. Training is unaffected in practice because iteration 2 onward is clean. For single-batch debugging or short pipelines, warm up the kernel before reading gradients.
- Parameters:
args (Any)
kwargs (Any)
- Return type:
Any
A custom torch.autograd.Function that fuses the 3-tap dilated
depthwise convolution and per-example layer normalization. User code
interacts with it through fused_dilated_conv_norm() and
CheriBlock.
Forward kernels
The forward pass runs three Triton kernels in sequence:
_fwd_stats_kernel: runs over (N, num_chunks_of_L) blocks. For each block it loads the three dilated taps, computes the convolved output, stores it into the output buffer y, and atomically accumulates the per-example sum and sq_sum of the convolved values into (N,) float32 buffers.
_fwd_finalize_kernel: runs over (N,) and turns the per-example sum/sq_sum into mean/rstd, using the eps=1e-3 numerical-stability constant.
_fwd_apply_kernel: runs over (N, num_chunks_of_L), loads y in place, subtracts the mean, multiplies by rstd, and writes back the normalized value.
Layer-norm statistics are always accumulated and stored in fp32 even
when y is bf16, which keeps the normalization numerically robust
under autocast.
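The three-kernel forward sequence can be mimicked on the CPU to show that only two scalars per example need to cross between passes. In this sketch (hypothetical name conv_norm_three_pass, one example, toy chunk size) the chunk loop plays the role of the (N, num_chunks_of_L) grid and plain additions stand in for atomics. The variance is recovered as sq_sum / count - mean**2, which is what accumulating sum and sq_sum implies, although the kernel's exact finalize arithmetic is not shown on this page.

```python
import math


def conv_norm_three_pass(ex, w, dilation, eps=1e-3, block_l=2):
    """Hypothetical single-example sketch of stats -> finalize -> apply.

    block_l=2 is a toy chunk size; the real kernels use BLOCK_L = 64.
    """
    L, C = len(ex), len(ex[0])
    y = [[0.0] * C for _ in range(L)]
    acc_sum, acc_sq = 0.0, 0.0  # stand-ins for the (N,) float32 accumulators

    # Pass 1 (_fwd_stats_kernel): each L-chunk writes conv output and adds partial sums.
    for start in range(0, L, block_l):
        part_sum, part_sq = 0.0, 0.0
        for i in range(start, min(start + block_l, L)):
            for c in range(C):
                left = ex[i - dilation][c] if i - dilation >= 0 else 0.0
                right = ex[i + dilation][c] if i + dilation < L else 0.0
                v = left * w[0][c] + ex[i][c] * w[1][c] + right * w[2][c]
                y[i][c] = v
                part_sum += v
                part_sq += v * v
        acc_sum += part_sum  # an atomic add on the GPU
        acc_sq += part_sq

    # Pass 2 (_fwd_finalize_kernel): sum / sq_sum -> mean / rstd.
    count = L * C
    mean = acc_sum / count
    var = acc_sq / count - mean * mean
    rstd = 1.0 / math.sqrt(var + eps)

    # Pass 3 (_fwd_apply_kernel): normalize y in place, chunk by chunk.
    for start in range(0, L, block_l):
        for i in range(start, min(start + block_l, L)):
            for c in range(C):
                y[i][c] = (y[i][c] - mean) * rstd
    return y, mean, rstd
```

Because sq_sum / count - mean**2 subtracts two nearly equal quantities when the mean is large relative to the spread, keeping these accumulators in fp32 (as the kernels do even when y is bf16) matters more here than it would for a two-pass variance.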
Backward kernels
The backward pass also runs three Triton kernels in sequence and
recomputes the convolved output rather than caching it from the
forward pass, trading a small amount of extra FLOPs for the
activation memory of an extra (N, L, C) tensor:
_bwd_stats_kernel: recomputes the conv output conv = x0*w0 + x1*w1 + x2*w2 (x0, x1, x2 being the inputs at positions i - dilation, i, i + dilation), stores it into a scratch buffer, and atomically accumulates two per-example reductions, sum_dy and sum_dy_xhat (where xhat = (conv - mean) * rstd). These are the two scalars the layer-norm backward needs.
_bwd_apply_kernel: uses sum_dy and sum_dy_xhat to compute d_conv = (rstd/count) * (count*dy - sum_dy - xhat * sum_dy_xhat) at every position, where count = L * C is the number of elements normalized per example, and accumulates the depthwise weight gradient dw by computing d_conv * x{0,1,2} and reducing along the length axis. Per-block dw partials are written into an (N * num_chunks, 3 * C) buffer; the final dw is the sum over all partials (handled outside the kernel).
_bwd_dx_kernel: applies the dilated transpose convolution to d_conv to produce dx. At each position p, with d = dilation, dx[p] = d_conv[p] * w1 + d_conv[p + d] * w0 + d_conv[p - d] * w2, with each term masked at the sequence ends.
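The backward formulas above can be checked numerically. The sketch below (hypothetical helper names, one example, pure Python) recomputes the convolution, applies the d_conv expression, reduces dw per tap, and builds dx with the transpose convolution; comparing against central finite differences of sum(dy * y) is a quick way to confirm the signs and tap order. It assumes count = L * C and a normalization without affine parameters.

```python
import math


def conv3(ex, w, d):
    """3-tap dilated depthwise conv with zero padding; ex is L rows of C floats."""
    L, C = len(ex), len(ex[0])

    def at(i, c):
        return ex[i][c] if 0 <= i < L else 0.0

    return [[at(i - d, c) * w[0][c] + at(i, c) * w[1][c] + at(i + d, c) * w[2][c]
             for c in range(C)] for i in range(L)]


def norm_stats(conv, eps=1e-3):
    flat = [v for row in conv for v in row]
    mean = sum(flat) / len(flat)
    var = sum((v - mean) ** 2 for v in flat) / len(flat)
    return mean, 1.0 / math.sqrt(var + eps)


def conv_norm_backward(ex, w, d, dy, eps=1e-3):
    """Return (dx, dw) for y = layer_norm(conv3(ex, w, d)) given upstream gradient dy."""
    L, C = len(ex), len(ex[0])
    count = L * C
    # _bwd_stats_kernel: recompute conv instead of reading a cached copy.
    conv = conv3(ex, w, d)
    mean, rstd = norm_stats(conv, eps)
    xhat = [[(v - mean) * rstd for v in row] for row in conv]
    sum_dy = sum(dy[i][c] for i in range(L) for c in range(C))
    sum_dy_xhat = sum(dy[i][c] * xhat[i][c] for i in range(L) for c in range(C))
    # _bwd_apply_kernel: layer-norm backward, then dw as d_conv times each input tap.
    d_conv = [[(rstd / count) * (count * dy[i][c] - sum_dy - xhat[i][c] * sum_dy_xhat)
               for c in range(C)] for i in range(L)]

    def at(grid, i, c):
        return grid[i][c] if 0 <= i < L else 0.0

    # Rows of dw: left tap (w0), center (w1), right (w2), each reduced along L.
    dw = [[sum(d_conv[i][c] * at(ex, i + off, c) for i in range(L)) for c in range(C)]
          for off in (-d, 0, d)]
    # _bwd_dx_kernel: transpose conv, with out-of-range terms masked to zero.
    dx = [[d_conv[p][c] * w[1][c] + at(d_conv, p + d, c) * w[0][c] + at(d_conv, p - d, c) * w[2][c]
           for c in range(C)] for p in range(L)]
    return dx, dw
```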
The first call on a given (C, L) shape triggers Triton autotune,
and the user-visible gradient output from that very first call is
contaminated by atomic-add residue from the benchmarking trials.
Subsequent calls use the locked-in best config and agree with CPU
autograd at fp32 precision. The test suite warms up the kernel
before any gradient is read; training is unaffected in practice
because step 2 onward is clean.
Autotune configuration space
Both the forward stats kernel and the two backward kernels are autotuned over the Cartesian product of:
num_warps: 4, 8, 16
num_stages: 2, 3, 4, 5
with the autotune key set to (C, L). BLOCK_C is set to
triton.next_power_of_2(C) per call (so it adapts to the
actual channel width); BLOCK_L is fixed at 64.
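The autotune search space is small enough to enumerate. The following hypothetical sketch (not Triton's autotuner) builds the 12 warp/stage candidates, derives BLOCK_C the way triton.next_power_of_2 does, and caches the winner per (C, L) key, which is also why only the first call on a new shape pays the benchmarking cost. The benchmark callback is an assumption standing in for Triton's timing runs.

```python
import itertools

NUM_WARPS = (4, 8, 16)
NUM_STAGES = (2, 3, 4, 5)
BLOCK_L = 64  # fixed, per the text above

_best_config = {}  # hypothetical cache keyed like the autotuner: (C, L)


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1 (same contract as triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def candidate_configs():
    # Cartesian product: 3 warp counts x 4 stage counts = 12 configs.
    return [{"num_warps": nw, "num_stages": ns} for nw, ns in itertools.product(NUM_WARPS, NUM_STAGES)]


def pick_config(C, L, benchmark):
    """Benchmark all candidates the first time (C, L) is seen, then reuse the winner.

    benchmark(config, block_c, block_l) -> runtime is a hypothetical callback.
    """
    key = (C, L)
    if key not in _best_config:
        block_c = next_power_of_2(C)  # re-derived from the actual channel width
        _best_config[key] = min(candidate_configs(), key=lambda cfg: benchmark(cfg, block_c, BLOCK_L))
    return _best_config[key]
```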
The inference megakernel autotunes its own normalization-plus-MLP
kernel over the same warp/stage grid, with an additional
BLOCK_HK (hidden width tile) constraint applied via the
prune_configs_by={'early_config_prune': ...} callback to keep
the hidden tile a divisor of expansion * n_filters.
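The BLOCK_HK pruning amounts to a divisibility filter. In Triton, an early_config_prune hook receives the candidate configs along with the kernel arguments and returns the subset to benchmark; the sketch below drops the Triton plumbing, passes the hidden width directly, and assumes candidate tiles of 16, 32, 64, and 128 (the real candidate list is not documented here). Both function names are hypothetical.

```python
import itertools


def make_megakernel_configs(block_hk_options=(16, 32, 64, 128)):
    # Assumed BLOCK_HK candidates crossed with the same warp/stage grid.
    return [{"num_warps": nw, "num_stages": ns, "BLOCK_HK": hk}
            for nw, ns, hk in itertools.product((4, 8, 16), (2, 3, 4, 5), block_hk_options)]


def prune_block_hk(configs, hidden):
    """Keep only configs whose hidden tile evenly divides hidden = expansion * n_filters."""
    return [cfg for cfg in configs if hidden % cfg["BLOCK_HK"] == 0]
```

With these assumed candidates, any hidden width that passes the % 16 dispatch check keeps at least the 16-wide tile, so the pruned list is never empty.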
Inference megakernel
When torch.is_grad_enabled() is False and the MLP hidden
width is a multiple of 16, the Cheri Block dispatches to a fused
inference megakernel that performs conv + norm + MLP + residual in
two GPU passes rather than four separate ops. The implementation
lives under the _fwd_inf_ prefix in
cherimoya/cheri.py. It is not part of the public API, but the
behavior is observable:
The cached bf16 cast of the linear weights is keyed by (id, _version) on both parameters, so any in-place update (including load_state_dict) invalidates the cache automatically.
The cache lives on the CheriBlock instance (CheriBlock._w_cache) and is not part of the state dict.
The path produces outputs that differ from the training Triton path by at most ~1e-5 max-abs at unit-scale outputs.
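The cache invalidation rule can be sketched without PyTorch. PyTorch tensors carry an internal _version counter that in-place operations (such as the copy_ calls made by load_state_dict) increment; FakeParam imitates that counter, and WeightCache is a hypothetical stand-in for the state kept in CheriBlock._w_cache. to_bf16 is a standard round-to-nearest-even fp32-to-bf16 conversion, which may differ in edge cases from the cast the megakernel actually performs.

```python
import struct


def to_bf16(x: float) -> float:
    """Round an fp32 value to bf16 (ties to even); NaN handling is omitted."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add half an ulp of bf16 (plus the lsb for ties-to-even), then drop the low 16 bits.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


class FakeParam:
    """Hypothetical parameter stand-in: in-place writes bump _version like PyTorch's counter."""

    def __init__(self, values):
        self.values = list(values)
        self._version = 0

    def copy_(self, new_values):
        self.values[:] = new_values
        self._version += 1


class WeightCache:
    """Hypothetical holder for cached bf16 weights; never part of any state dict."""

    def __init__(self):
        self._w_cache = None  # (key, (up_bf16, down_bf16))

    def get(self, w_up, w_down):
        key = (id(w_up), w_up._version, id(w_down), w_down._version)
        if self._w_cache is None or self._w_cache[0] != key:
            casted = ([to_bf16(v) for v in w_up.values], [to_bf16(v) for v in w_down.values])
            self._w_cache = (key, casted)
        return self._w_cache[1]
```

Keying on id alone would miss in-place updates to the same parameter object; including _version is what lets load_state_dict invalidate the cache without any explicit hook.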