cherimoya.cheri

The Cheri Block — Cherimoya’s adaptation of the ConvNeXt block — plus the fused dilated-conv-plus-layer-norm dispatcher used inside it. For the high-level model see cherimoya.cherimoya; for the architectural story behind these primitives see Architecture.

CheriBlock

class cherimoya.cheri.CheriBlock(*args, **kwargs)

Bases: Module

A single Cheri Block.

The Cheri Block is the core building block of the Cherimoya model. It adapts the ConvNeXt block to noisy genomics data, with the goal of mixing spatial and channel information cheaply while remaining stable to train.

The block performs the following operations on an input of shape (N, L, C):

  1. A 3-tap depthwise dilated convolution that mixes spatial information independently for each channel.

  2. A per-example layer normalization across the (length, channel) plane. The convolution and normalization are fused into one kernel.

  3. A pointwise expansion linear projection from C to expansion * C channels.

  4. A GELU non-linearity.

  5. A pointwise contraction linear projection back to C channels.

  6. A residual connection where the MLP output is scaled by a fixed constant (residual_scale) before being added to the input. The small constant keeps the residual path near-identity at initialization, which stabilizes training of deep stacks.
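The six steps above can be sketched in plain PyTorch as follows. This is an illustrative reading of the description, not the library's implementation: the class name CheriBlockSketch, the bias terms on the linear layers, the weight initialization, and the absence of affine parameters in the normalization are all assumptions, and eps=1e-3 is taken from the forward-kernel description further down this page.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CheriBlockSketch(nn.Module):
    """Hypothetical pure-PyTorch reading of the six steps; not the library class."""

    def __init__(self, n_filters, dilation, expansion=2, residual_scale=0.15, eps=1e-3):
        super().__init__()
        self.dilation = dilation
        self.residual_scale = residual_scale
        self.eps = eps  # assumed to match the eps used by the fused kernel
        # One row per tap (left, center, right), one column per channel.
        self.conv_weight = nn.Parameter(torch.randn(3, n_filters) * 0.1)
        # Bias terms and default init are assumptions of this sketch.
        self.expand = nn.Linear(n_filters, expansion * n_filters)
        self.contract = nn.Linear(expansion * n_filters, n_filters)

    def forward(self, X):
        N, L, C = X.shape
        # Step 1: depthwise 3-tap dilated conv. conv1d expects (N, C, L) input
        # and a (C, 1, 3) weight when groups=C.
        w = self.conv_weight.t().unsqueeze(1).contiguous()
        h = F.conv1d(X.transpose(1, 2), w, padding=self.dilation,
                     dilation=self.dilation, groups=C).transpose(1, 2)
        # Step 2: per-example layer norm over the whole (L, C) plane, no affine.
        h = F.layer_norm(h, (L, C), eps=self.eps)
        # Steps 3-5: pointwise expansion, GELU, pointwise contraction.
        h = self.contract(F.gelu(self.expand(h)))
        # Step 6: residual connection with a fixed scale on the MLP branch.
        return X + self.residual_scale * h
```

With padding equal to the dilation, the output length matches the input length, and tap k of the kernel reads position i + (k - 1) * dilation, which is the (i - dilation, i, i + dilation) pattern described for the dilation parameter.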

Forward dispatch

The block selects one of three implementations per forward call:

  • CPU input → pure-PyTorch path (always; differentiable reference).

  • CUDA input + torch.is_grad_enabled() → training Triton path (FusedDilatedConvNormFunc for conv+norm, PyTorch MLP).

  • CUDA input + no_grad + (expansion * n_filters) % 16 == 0 → fully fused inference megakernel (~2x faster, with the MLP dot products computed in bf16). Any CUDA call that does not meet these conditions falls back to the training path, so existing model configurations keep working unchanged.
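The dispatch rules above amount to a small decision function. The sketch below is a hypothetical helper written only to make the precedence of the conditions explicit; the function name and the returned labels are not library identifiers.

```python
def select_forward_path(is_cuda, grad_enabled, n_filters, expansion=2):
    """Return a label for the implementation the documented rules would pick.

    Hypothetical helper for illustration; the labels are not library names.
    """
    if not is_cuda:
        # CPU input always takes the pure-PyTorch reference path.
        return "pytorch_reference"
    hidden = expansion * n_filters
    if not grad_enabled and hidden % 16 == 0:
        return "inference_megakernel"
    # CUDA with gradients enabled, or a hidden width that is not a multiple of 16.
    return "training_triton"
```

For example, n_filters=12 with expansion=2 gives a hidden width of 24, which is not a multiple of 16, so a CUDA call under no_grad would still use the training path.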

Existing trained checkpoints are bit-compatible: the parameter layout, init order, and forward semantics of the training path are unchanged. The inference megakernel produces outputs that differ from the training path by at most ~1e-5 max-abs at unit-scale outputs; this drift comes from bf16 weight casts in the MLP and is the precision/speed tradeoff documented in _cast_weights.

param n_filters:

The number of channels (the C dimension).

type n_filters:

int

param dilation:

Dilation rate for the depthwise convolution. The kernel reads from positions (i - dilation, i, i + dilation) at each output position i, with zero padding outside the sequence.

type dilation:

int

param expansion:

The factor by which the inner MLP expands the channel dimension. The first projection maps n_filters -> expansion * n_filters and the second projects back. Default is 2.

type expansion:

int, optional

param residual_scale:

Fixed scalar applied to the MLP output before it is added back to the residual stream. Default is 0.15.

type residual_scale:

float, optional

Constructor

__init__(n_filters, dilation, expansion=2, residual_scale=0.15)

forward(X)

Run the block on an input of shape (N, L, C).

fused_dilated_conv_norm

cherimoya.cheri.fused_dilated_conv_norm(x, w, dilation)

Fused 3-tap dilated depthwise conv plus per-example layer norm.

Dispatches to the Triton kernel when x is a CUDA tensor and Triton is available, otherwise to a pure-PyTorch fallback. The two implementations are numerically equivalent up to floating-point error.

Parameters:
  • x (torch.Tensor, shape=(N, L, C)) – The input tensor.

  • w (torch.Tensor, shape=(3, C)) – Depthwise convolution weights for the left, center, and right taps.

  • dilation (int) – Spacing between the three taps.

Returns:

y – The convolved and normalized output.

Return type:

torch.Tensor, shape=(N, L, C)

This is the public dispatcher used inside CheriBlock. It routes to the training Triton kernel on CUDA when gradients are enabled, and otherwise to the pure-PyTorch fallback (_cheri_conv_norm_cpu). The two paths are numerically equivalent up to floating-point error.

FusedDilatedConvNormFunc

class cherimoya.cheri.FusedDilatedConvNormFunc(*args, **kwargs)

Triton-backed fused dilated convolution + per-example layer norm.

Implements the same operation as _cheri_conv_norm_cpu() but uses a custom Triton kernel to fuse the convolution, statistics reduction, and normalization steps into a small number of GPU passes. Only callable on CUDA tensors.

Note: the first call on a given (C, L) shape triggers Triton autotune, and the user-visible backward output from that very first call is contaminated by atomic-add residue from the benchmarking trials (we measured ~7e-2 vs CPU autograd on one shape). Every subsequent call uses the locked-in best config and agrees with CPU autograd at fp32 precision. Training is unaffected in practice because iteration 2 onward is clean. Single-batch debugging or short pipelines should warm up the kernel before reading gradients.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

Any

static forward(ctx, x, w, dilation)
static backward(ctx, dy)

A custom torch.autograd.Function that fuses the 3-tap dilated depthwise convolution and per-example layer normalization. User code interacts with it through fused_dilated_conv_norm() and CheriBlock.

Forward kernels

The forward pass runs three Triton kernels in sequence:

  • _fwd_stats_kernel — runs over (N, num_chunks_of_L) blocks. For each block it loads the three dilated taps, computes the convolved output and stores it into the output buffer y, and atomically accumulates per-example sum and sq_sum of the convolved values into (N,) float32 buffers.

  • _fwd_finalize_kernel — runs over (N,): turns the per-example sum/sq_sum into mean/rstd with the eps=1e-3 numerical stability constant.

  • _fwd_apply_kernel — runs over (N, num_chunks_of_L): loads y in-place, subtracts the mean, multiplies by rstd, writes back the normalized value.

Layer-norm statistics are always accumulated and stored in fp32 even when y is bf16, which keeps the normalization numerically robust under autocast.
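The stats / finalize / apply structure can be emulated in plain PyTorch to see how the statistics are formed. The sketch below takes an already-convolved tensor as input (the real stats kernel also computes the convolution), uses a Python loop in place of atomics, and clamps the variance at zero to guard against round-off; the function name and the clamp are assumptions of this sketch.

```python
import torch


def conv_norm_three_pass(conv, block_l=64, eps=1e-3):
    """Emulate the stats / finalize / apply passes on a convolved (N, L, C) tensor.

    Returns the normalized tensor in the same dtype as `conv`.
    """
    N, L, C = conv.shape
    total = torch.zeros(N, dtype=torch.float32, device=conv.device)
    sq_total = torch.zeros(N, dtype=torch.float32, device=conv.device)
    # Pass 1 (stats): accumulate per-example sums chunk by chunk, always in fp32.
    for start in range(0, L, block_l):
        chunk = conv[:, start:start + block_l].float()
        total += chunk.sum(dim=(1, 2))
        sq_total += (chunk * chunk).sum(dim=(1, 2))
    # Pass 2 (finalize): var = E[x^2] - E[x]^2, then rstd = 1 / sqrt(var + eps).
    count = L * C
    mean = total / count
    var = (sq_total / count - mean * mean).clamp_min(0.0)  # clamp is an assumption
    rstd = torch.rsqrt(var + eps)
    # Pass 3 (apply): normalize in fp32 and cast back to the storage dtype.
    y = (conv.float() - mean[:, None, None]) * rstd[:, None, None]
    return y.to(conv.dtype)
```

Keeping the accumulators in fp32 matters most for bf16 inputs, where summing L * C values directly in bf16 would lose most of the low-order bits.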

Backward kernels

The backward pass also runs three Triton kernels in sequence. It recomputes the convolved output rather than caching it from the forward pass, spending a small number of extra FLOPs to avoid holding an additional (N, L, C) activation tensor in memory:

  • _bwd_stats_kernel — recomputes the conv output conv = x0*w0 + x1*w1 + x2*w2, stores it into a scratch buffer, and atomically accumulates two per-example reductions: sum_dy and sum_dy_xhat (where xhat = (conv - mean) * rstd). These are the two scalars the layer-norm backward needs.

  • _bwd_apply_kernel — uses sum_dy and sum_dy_xhat to compute d_conv = (rstd/count) * (count*dy - sum_dy - xhat * sum_dy_xhat) for every position, and accumulates the depthwise weight gradient dw by computing d_conv * x{0,1,2} and reducing along the length axis. Per-block dw partials are written into a (N * num_chunks, 3 * C) buffer; the final dw is the sum over all partials (handled outside the kernel).

  • _bwd_dx_kernel — applies the dilated transpose convolution to d_conv to produce dx: at each position p, dx[p] = d_conv[p] * w1 + d_conv[p + d] * w0 + d_conv[p - d] * w2 (the three terms each masked at the sequence ends).
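The three backward formulas above can be written out in plain PyTorch and compared with autograd. The sketch below follows the stated expressions term by term; the helper names (shift, conv_norm_forward, conv_norm_backward) are hypothetical, and the code makes no attempt to mirror the chunking or atomics of the Triton kernels.

```python
import torch
import torch.nn.functional as F


def shift(t, k):
    """Return s with s[:, i] = t[:, i + k], zero where i + k leaves the sequence."""
    L = t.shape[1]
    if abs(k) >= L:
        return torch.zeros_like(t)
    if k >= 0:
        return F.pad(t[:, k:], (0, 0, 0, k))   # pad the end of the length axis
    return F.pad(t[:, :L + k], (0, 0, -k, 0))  # pad the start of the length axis


def conv_norm_forward(x, w, d, eps=1e-3):
    # x0, x1, x2 are the left, center, and right taps.
    conv = shift(x, -d) * w[0] + x * w[1] + shift(x, d) * w[2]
    mean = conv.mean(dim=(1, 2), keepdim=True)
    var = ((conv - mean) ** 2).mean(dim=(1, 2), keepdim=True)
    rstd = torch.rsqrt(var + eps)
    return (conv - mean) * rstd, rstd


def conv_norm_backward(x, w, d, dy, eps=1e-3):
    # Recompute xhat instead of caching it, as the backward pass is described to do.
    xhat, rstd = conv_norm_forward(x, w, d, eps)
    count = x.shape[1] * x.shape[2]
    sum_dy = dy.sum(dim=(1, 2), keepdim=True)
    sum_dy_xhat = (dy * xhat).sum(dim=(1, 2), keepdim=True)
    d_conv = (rstd / count) * (count * dy - sum_dy - xhat * sum_dy_xhat)
    # dw: one row per tap, reduced over batch and length.
    dw = torch.stack([
        (d_conv * shift(x, -d)).sum(dim=(0, 1)),
        (d_conv * x).sum(dim=(0, 1)),
        (d_conv * shift(x, d)).sum(dim=(0, 1)),
    ])
    # dx: transpose of the dilated convolution.
    dx = d_conv * w[1] + shift(d_conv, d) * w[0] + shift(d_conv, -d) * w[2]
    return dx, dw
```

Running conv_norm_forward in float64 with requires_grad=True and comparing torch.autograd.grad against conv_norm_backward is a convenient way to confirm the tap ordering: w0 pairs with d_conv[p + d] in dx because x[p] is the left tap of the output at position p + d.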

The first call on a given (C, L) shape triggers Triton autotune, and the user-visible gradient output from that very first call is contaminated by atomic-add residue from the benchmarking trials. Subsequent calls use the locked-in best config and agree with CPU autograd at fp32 precision. The test suite warms up the kernel before any gradient is read; training is unaffected in practice because step 2 onward is clean.
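A warm-up step for single-batch debugging might look like the sketch below. The helper name warm_up and its arguments are hypothetical; the idea is only to run a throwaway forward and backward on inputs of the same shape, dtype, and device as the real workload, so that autotuning has finished before any gradient is inspected.

```python
import torch


def warm_up(fn, *example_inputs, n_iters=2):
    """Run forward and backward a few times and discard the results.

    `fn` and `example_inputs` are placeholders: pass tensors with the same
    (C, L) shape, dtype, and device as the real workload so that the same
    autotune key is exercised.
    """
    for _ in range(n_iters):
        # Work on detached clones so no gradient lands on the caller's tensors.
        inputs = [t.detach().clone().requires_grad_(t.is_floating_point())
                  for t in example_inputs]
        out = fn(*inputs)
        out.sum().backward()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # make sure the warm-up kernels have finished
```

Assuming the signature documented on this page, a call on CUDA tensors x and w could be written as warm_up(lambda a, b: fused_dilated_conv_norm(a, b, dilation), x, w) before the first gradient is read.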

Autotune configuration space

The forward stats kernel and the two backward kernels are autotuned over the Cartesian product of:

  • num_warps: 4, 8, 16

  • num_stages: 2, 3, 4, 5

with the autotune key set to (C, L). BLOCK_C is set to triton.next_power_of_2(C) per call (so it adapts to the actual channel width); BLOCK_L is fixed at 64.

The inference megakernel autotunes its own normalization-plus-MLP kernel over the same warp/stage grid, with an additional BLOCK_HK (hidden width tile) constraint applied via the prune_configs_by={'early_config_prune': ...} callback to keep the hidden tile a divisor of expansion * n_filters.
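The size of this search space is easy to enumerate. The sketch below uses only the Python standard library and plain dictionaries rather than Triton config objects; the function names and the candidate_tiles list for BLOCK_HK are made up for illustration, since the actual tile candidates are not listed here.

```python
from itertools import product

NUM_WARPS = (4, 8, 16)
NUM_STAGES = (2, 3, 4, 5)
BLOCK_L = 64  # fixed length tile


def next_power_of_2(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def autotune_grid():
    """All (num_warps, num_stages) combinations: 3 * 4 = 12 entries."""
    return [{"num_warps": nw, "num_stages": ns}
            for nw, ns in product(NUM_WARPS, NUM_STAGES)]


def block_sizes(n_filters):
    """Tile sizes chosen per call rather than autotuned."""
    return {"BLOCK_C": next_power_of_2(n_filters), "BLOCK_L": BLOCK_L}


def prune_hidden_tiles(configs, hidden_width, candidate_tiles=(16, 32, 64, 128)):
    """Keep only (config, BLOCK_HK) pairs whose tile divides hidden_width.

    candidate_tiles is a hypothetical list used only for this illustration.
    """
    return [{**cfg, "BLOCK_HK": hk}
            for cfg in configs
            for hk in candidate_tiles
            if hidden_width % hk == 0]
```

With n_filters=48, BLOCK_C rounds up to 64, and for a hidden width of 96 only the 16 and 32 tiles from the illustrative candidate list divide it evenly.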

Inference megakernel

When torch.is_grad_enabled() is False and the MLP hidden width is a multiple of 16, the Cheri Block dispatches to a fused inference megakernel that performs conv + norm + MLP + residual in two GPU passes rather than four separate ops. The implementation lives under the _fwd_inf_ prefix in cherimoya/cheri.py. It is not part of the public API, but the behavior is observable:

  • The bf16 copies of the two linear weights are cached under a key built from (id, _version) of both parameters, so any in-place update (including load_state_dict) invalidates the cache automatically.

  • The cache lives on the CheriBlock instance (CheriBlock._w_cache) and is not part of the state dict.

  • The path produces outputs that differ from the training Triton path by at most ~1e-5 max-abs at unit-scale outputs.
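The cache-keying behavior in the first bullet can be sketched as follows. The class name Bf16WeightCache and its get method are hypothetical and do not describe the layout of CheriBlock._w_cache; the sketch relies only on the fact that PyTorch increments a tensor's _version counter on every in-place modification.

```python
import torch


class Bf16WeightCache:
    """Hypothetical cache of bf16 weight copies keyed by (id, _version)."""

    def __init__(self):
        self._key = None
        self._value = None

    def get(self, w_expand, w_contract):
        params = (w_expand, w_contract)
        key = tuple((id(p), p._version) for p in params)
        if key != self._key:
            # A parameter was replaced or modified in place: re-cast both.
            self._value = tuple(p.detach().to(torch.bfloat16) for p in params)
            self._key = key
        return self._value
```

An in-place update such as an optimizer step or the copy performed by load_state_dict bumps _version, so the next lookup misses and the bf16 copies are rebuilt. Because the cache is ordinary Python state on the object, it never appears in the state dict.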