Source code for cherimoya.cheri

# cheri.py
# Author: Jacob Schreiber <jmschreiber91@gmail.com>

"""
The Cheri block — Cherimoya's adaptation of the ConvNeXt block to genomics.

Each block performs a 3-tap dilated depthwise convolution fused with a
per-example layer normalization, followed by an MLP expansion path with a
learnable per-channel residual scale (gamma). The fused conv+norm operation
has two implementations: a Triton GPU kernel for CUDA tensors, and a pure
PyTorch fallback that is used on CPU and on platforms without Triton. Both
paths are numerically equivalent up to floating-point error.
"""

import itertools

import torch
import torch.nn as nn
import torch.nn.functional as F


try:
	import triton
	import triton.language as tl
	HAS_TRITON = True
except ImportError:
	HAS_TRITON = False


CONV_NORM_EPS = 1e-3


def _cheri_conv_norm_cpu(x, w, dilation, eps=CONV_NORM_EPS):
	"""Pure PyTorch implementation of the fused dilated conv + norm.

	Performs a 3-tap depthwise dilated convolution followed by a per-example
	layer normalization across the (length, channels) plane. This is the
	fallback used on CPU and on platforms where Triton is unavailable. It
	participates in autograd automatically because it is built from
	standard differentiable PyTorch ops.

	Parameters
	----------
	x: torch.Tensor, shape=(N, L, C)
		The input tensor.

	w: torch.Tensor, shape=(3, C)
		Depthwise convolution weights for the left, center, and right taps.

	dilation: int
		Spacing between the three taps. The kernel reads from positions
		(i - dilation, i, i + dilation) for each output position i, with
		zero padding outside the sequence.

	eps: float, optional
		Numerical-stability constant added to the variance before taking
		the reciprocal square root. Default is 1e-3.

	Returns
	-------
	y: torch.Tensor, shape=(N, L, C)
		The convolved and normalized output, in the same dtype as `x`.
	"""

	N, L, C = x.shape

	# F.conv1d expects (N, C, L) input and (C, 1, 3) weight for depthwise.
	weight = w.t().unsqueeze(1).contiguous()
	x_t = x.transpose(1, 2).contiguous()
	y_t = F.conv1d(x_t, weight, padding=dilation, dilation=dilation, groups=C)
	y = y_t.transpose(1, 2).contiguous()

	# Per-example layer norm across all (L, C) positions.
	flat = y.reshape(N, -1).float()
	mean = flat.mean(dim=1, keepdim=True)
	var = flat.var(dim=1, keepdim=True, unbiased=False)
	rstd = (var + eps).rsqrt()
	flat = (flat - mean) * rstd
	return flat.reshape(N, L, C).to(y.dtype)


if HAS_TRITON:

	def _autotune_configs():
		num_warps = [4, 8, 16]
		num_stages = [2, 3, 4, 5]
		configs = []

		for num_warp, num_stage in itertools.product(num_warps, num_stages):
			configs.append(triton.Config({
				'num_warps': num_warp,
				'num_stages': num_stage,
			}))

		return configs


	@triton.autotune(
		configs=_autotune_configs(),
		key=['C', 'L'],
		reset_to_zero=['Sum_ptr', 'Sq_sum_ptr']
	)
	@triton.jit
	def _fwd_stats_kernel(
		X_ptr, W_ptr, Y_ptr, Sum_ptr, Sq_sum_ptr,
		stride_xn, dilation,
		L: tl.constexpr,
		C: tl.constexpr,
		BLOCK_C: tl.constexpr,
		BLOCK_L: tl.constexpr
	):
		pid_n = tl.program_id(0)
		pid_l = tl.program_id(1)
		offs_c = tl.arange(0, BLOCK_C)[None, :]
		mask_c = offs_c < C

		w_idx = W_ptr + offs_c
		w0 = tl.load(w_idx,       mask=mask_c, other=0.0)
		w1 = tl.load(w_idx + C,   mask=mask_c, other=0.0)
		w2 = tl.load(w_idx + C*2, mask=mask_c, other=0.0)

		l_start = pid_l * BLOCK_L
		offs = l_start + tl.arange(0, BLOCK_L)[:, None]
		offs_l = offs - dilation
		offs_r = offs + dilation

		mask = (offs < L) & mask_c
		mask_l = (offs_l >= 0) & mask
		mask_r = (offs_r < L) & mask

		x_idx = X_ptr + pid_n * stride_xn + offs_c
		x1 = tl.load(x_idx + offs*C,   mask=mask,   other=0.0)
		x0 = tl.load(x_idx + offs_l*C, mask=mask_l, other=0.0)
		x2 = tl.load(x_idx + offs_r*C, mask=mask_r, other=0.0)

		conv = x0*w0 + x1*w1 + x2*w2

		y_idx = Y_ptr + pid_n * stride_xn + offs * C + offs_c
		tl.store(y_idx, conv, mask=mask)

		conv = conv.to(tl.float32)

		block_sum = tl.sum(conv)
		block_sq_sum = tl.sum(conv * conv)

		tl.atomic_add(Sum_ptr + pid_n, block_sum, sem='relaxed')
		tl.atomic_add(Sq_sum_ptr + pid_n, block_sq_sum, sem='relaxed')


	@triton.jit
	def _fwd_finalize_kernel(
		Sum_ptr, Sq_sum_ptr, Mean_ptr, Rstd_ptr,
		eps,
		L: tl.constexpr,
		C: tl.constexpr,
	):
		pid_n = tl.program_id(0)

		running_sum = tl.load(Sum_ptr + pid_n)
		running_sq_sum = tl.load(Sq_sum_ptr + pid_n)

		count = L * C
		mean = running_sum / count
		var = (running_sq_sum / count) - (mean * mean)
		rstd = 1.0 / tl.sqrt(var + eps)

		tl.store(Mean_ptr + pid_n, mean)
		tl.store(Rstd_ptr + pid_n, rstd)


	@triton.jit
	def _fwd_apply_kernel(
		Y_ptr, Mean_ptr, Rstd_ptr,
		stride_yn,
		L: tl.constexpr,
		C: tl.constexpr,
		BLOCK_C: tl.constexpr,
		BLOCK_L: tl.constexpr,
	):
		pid_n = tl.program_id(0)
		pid_l = tl.program_id(1)
		offs_c = tl.arange(0, BLOCK_C)[None, :]
		mask_c = offs_c < C

		mean = tl.load(Mean_ptr + pid_n)
		rstd = tl.load(Rstd_ptr + pid_n)

		l_start = pid_l * BLOCK_L
		offs = l_start + tl.arange(0, BLOCK_L)[:, None]
		mask = (offs < L) & mask_c

		y_idx = Y_ptr + pid_n * stride_yn + offs * C + offs_c
		conv = tl.load(y_idx, mask=mask, other=0.0).to(tl.float32)
		x_hat = (conv - mean) * rstd

		tl.store(y_idx, x_hat, mask=mask)


	@triton.autotune(
		configs=_autotune_configs(),
		key=['C', 'L'],
		reset_to_zero=['Sum_dy_ptr', 'Sum_dy_xhat_ptr']
	)
	@triton.jit
	def _bwd_stats_kernel(
		dY_ptr, X_ptr, W_ptr, Mean_ptr, Rstd_ptr,
		Conv_ptr, Sum_dy_ptr, Sum_dy_xhat_ptr,
		stride_xn, dilation,
		L: tl.constexpr,
		C: tl.constexpr,
		BLOCK_C: tl.constexpr,
		BLOCK_L: tl.constexpr
	):
		pid_n = tl.program_id(0)
		pid_l = tl.program_id(1)
		offs_c = tl.arange(0, BLOCK_C)[None, :]
		mask_c = offs_c < C

		w_idx = W_ptr + offs_c
		w0 = tl.load(w_idx,       mask=mask_c, other=0.0)
		w1 = tl.load(w_idx + C,   mask=mask_c, other=0.0)
		w2 = tl.load(w_idx + C*2, mask=mask_c, other=0.0)

		mean = tl.load(Mean_ptr + pid_n)
		rstd = tl.load(Rstd_ptr + pid_n)

		l_start = pid_l * BLOCK_L
		offs = l_start + tl.arange(0, BLOCK_L)[:, None]
		offs_l = offs - dilation
		offs_r = offs + dilation

		mask = (offs < L) & mask_c
		mask_l = (offs_l >= 0) & mask
		mask_r = (offs_r < L) & mask

		x_idx = X_ptr + pid_n * stride_xn + offs_c
		x0 = tl.load(x_idx + offs_l*C, mask=mask_l, other=0.0)
		x1 = tl.load(x_idx + offs*C,   mask=mask,   other=0.0)
		x2 = tl.load(x_idx + offs_r*C, mask=mask_r, other=0.0)

		conv = x0*w0 + x1*w1 + x2*w2

		conv_idx = Conv_ptr + pid_n * stride_xn + offs * C + offs_c
		tl.store(conv_idx, conv, mask=mask)

		x_hat = (conv.to(tl.float32) - mean) * rstd

		dy_idx = dY_ptr + pid_n * stride_xn + offs * C + offs_c
		dy = tl.load(dy_idx, mask=mask, other=0.0).to(tl.float32)

		tl.atomic_add(Sum_dy_ptr + pid_n, tl.sum(dy), sem='relaxed')
		tl.atomic_add(Sum_dy_xhat_ptr + pid_n, tl.sum(dy * x_hat), sem='relaxed')


	@triton.autotune(
		configs=_autotune_configs(),
		key=['C', 'L']
	)
	@triton.jit
	def _bwd_apply_kernel(
		dY_ptr, X_ptr, Mean_ptr, Rstd_ptr,
		Sum_dy_ptr, Sum_dy_xhat_ptr,
		Conv_ptr, dW_ptr,
		stride_xn, num_chunks, dilation,
		L: tl.constexpr,
		C: tl.constexpr,
		BLOCK_C: tl.constexpr,
		BLOCK_L: tl.constexpr
	):
		pid_n = tl.program_id(0)
		pid_l = tl.program_id(1)
		offs_c = tl.arange(0, BLOCK_C)[None, :]
		mask_c = offs_c < C

		mean = tl.load(Mean_ptr + pid_n)
		rstd = tl.load(Rstd_ptr + pid_n)

		sum_dy_val = tl.load(Sum_dy_ptr + pid_n)
		sum_dy_xhat_val = tl.load(Sum_dy_xhat_ptr + pid_n)
		count = L * C

		l_start = pid_l * BLOCK_L
		offs = l_start + tl.arange(0, BLOCK_L)[:, None]
		offs_l = offs - dilation
		offs_r = offs + dilation
		mask = (offs < L) & mask_c
		mask_l = (offs_l >= 0) & mask
		mask_r = (offs_r < L) & mask

		buf_idx = Conv_ptr + pid_n * stride_xn + offs * C + offs_c
		conv = tl.load(buf_idx, mask=mask, other=0.0).to(tl.float32)
		x_hat = (conv - mean) * rstd

		dy = tl.load(dY_ptr + pid_n * stride_xn + offs * C + offs_c, mask=mask, other=0.0).to(tl.float32)

		d_conv = (rstd / count) * (count * dy - sum_dy_val - x_hat * sum_dy_xhat_val)
		tl.store(buf_idx, d_conv, mask=mask)

		x_idx = X_ptr + pid_n * stride_xn + offs_c
		x0 = tl.load(x_idx + offs_l*C, mask=mask_l, other=0.0).to(tl.float32)
		x1 = tl.load(x_idx + offs*C,   mask=mask,   other=0.0).to(tl.float32)
		x2 = tl.load(x_idx + offs_r*C, mask=mask_r, other=0.0).to(tl.float32)

		dw0 = tl.sum(d_conv * x0, axis=0)[None, :]
		dw1 = tl.sum(d_conv * x1, axis=0)[None, :]
		dw2 = tl.sum(d_conv * x2, axis=0)[None, :]

		dw_idx = dW_ptr + (pid_n * num_chunks + pid_l) * (3 * C) + offs_c
		tl.store(dw_idx,         dw0, mask=mask_c)
		tl.store(dw_idx + C,     dw1, mask=mask_c)
		tl.store(dw_idx + 2 * C, dw2, mask=mask_c)


	@triton.autotune(
		configs=_autotune_configs(),
		key=['C', 'L']
	)
	@triton.jit
	def _bwd_dx_kernel(
		dConv_ptr, W_ptr, dX_ptr,
		stride_xn, dilation,
		L: tl.constexpr,
		C: tl.constexpr,
		BLOCK_C: tl.constexpr,
		BLOCK_L: tl.constexpr
	):
		pid_n = tl.program_id(0)
		pid_l = tl.program_id(1)
		offs_c = tl.arange(0, BLOCK_C)[None, :]
		mask_c = offs_c < C

		w_idx = W_ptr + offs_c
		w0 = tl.load(w_idx,       mask=mask_c, other=0.0).to(tl.float32)
		w1 = tl.load(w_idx + C,   mask=mask_c, other=0.0).to(tl.float32)
		w2 = tl.load(w_idx + C*2, mask=mask_c, other=0.0).to(tl.float32)

		l_start = pid_l * BLOCK_L
		offs = l_start + tl.arange(0, BLOCK_L)[:, None]
		offs_p = offs + dilation
		offs_m = offs - dilation
		mask = (offs < L) & mask_c

		dc_base = dConv_ptr + pid_n * stride_xn + offs_c
		dc_c = tl.load(dc_base + offs * C,   mask=mask,                 other=0.0).to(tl.float32)
		dc_p = tl.load(dc_base + offs_p * C, mask=(offs_p < L) & mask,  other=0.0).to(tl.float32)
		dc_m = tl.load(dc_base + offs_m * C, mask=(offs_m >= 0) & mask, other=0.0).to(tl.float32)

		# dx[p] = d_conv[p]*w1 + d_conv[p+d]*w0 + d_conv[p-d]*w2
		dx = dc_c * w1 + dc_p * w0 + dc_m * w2
		tl.store(dX_ptr + pid_n * stride_xn + offs * C + offs_c, dx, mask=mask)


	class FusedDilatedConvNormFunc(torch.autograd.Function):
		"""Triton-backed fused dilated convolution + per-example layer norm.

		Implements the same operation as :func:`_cheri_conv_norm_cpu` but
		uses a custom Triton kernel to fuse the convolution, statistics
		reduction, and normalization steps into a small number of GPU
		passes. Only callable on CUDA tensors.
		"""

		@staticmethod
		def forward(ctx, x, w, dilation):
			N, L, C = x.shape
			BLOCK_C = triton.next_power_of_2(C)
			BLOCK_L = 64
			eps = CONV_NORM_EPS

			NUM_PARTIALS = triton.cdiv(L, BLOCK_L)

			sum_buf = torch.zeros((N,), dtype=torch.float32, device=x.device)
			sq_sum_buf = torch.zeros((N,), dtype=torch.float32, device=x.device)

			mean = torch.empty((N,), dtype=torch.float32, device=x.device)
			rstd = torch.empty((N,), dtype=torch.float32, device=x.device)

			y = torch.empty_like(x)

			_fwd_stats_kernel[(N, NUM_PARTIALS)](
				x, w, y, sum_buf, sq_sum_buf,
				x.stride(0), dilation,
				L, C, BLOCK_C=BLOCK_C, BLOCK_L=BLOCK_L
			)

			_fwd_finalize_kernel[(N,)](
				sum_buf, sq_sum_buf, mean, rstd,
				eps,
				L, C,
			)

			_fwd_apply_kernel[(N, NUM_PARTIALS)](
				y, mean, rstd,
				y.stride(0),
				L, C, BLOCK_C=BLOCK_C, BLOCK_L=BLOCK_L
			)

			ctx.save_for_backward(x, w, mean, rstd)
			ctx.dilation = dilation
			return y

		@staticmethod
		def backward(ctx, dy):
			x, w, mean, rstd = ctx.saved_tensors
			N, L, C = x.shape
			BLOCK_C = triton.next_power_of_2(C)
			BLOCK_L = 64

			dy = dy.contiguous()

			NUM_CHUNKS = triton.cdiv(L, BLOCK_L)

			sum_dy = torch.zeros((N,), dtype=torch.float32, device=x.device)
			sum_dy_xhat = torch.zeros((N,), dtype=torch.float32, device=x.device)

			buf = torch.empty((N, L, C), dtype=torch.float32, device=x.device)
			dx = torch.empty_like(x)
			dw = torch.empty((N * NUM_CHUNKS, 3 * C), dtype=torch.float32, device=x.device)

			_bwd_stats_kernel[(N, NUM_CHUNKS)](
				dy, x, w, mean, rstd,
				buf, sum_dy, sum_dy_xhat,
				x.stride(0), ctx.dilation,
				L, C, BLOCK_C=BLOCK_C, BLOCK_L=BLOCK_L
			)

			_bwd_apply_kernel[(N, NUM_CHUNKS)](
				dy, x, mean, rstd,
				sum_dy, sum_dy_xhat,
				buf, dw,
				x.stride(0), NUM_CHUNKS, ctx.dilation,
				L, C, BLOCK_C=BLOCK_C, BLOCK_L=BLOCK_L
			)

			_bwd_dx_kernel[(N, NUM_CHUNKS)](
				buf, w, dx,
				x.stride(0), ctx.dilation,
				L, C, BLOCK_C=BLOCK_C, BLOCK_L=BLOCK_L
			)

			dw = dw.view(N * NUM_CHUNKS, 3, C).sum(dim=0)
			return dx.to(x.dtype), dw.to(x.dtype), None


def fused_dilated_conv_norm(x, w, dilation):
	"""Fused 3-tap dilated depthwise conv plus per-example layer norm.

	Dispatches to the Triton kernel when `x` is a CUDA tensor and Triton
	is available, otherwise to a pure-PyTorch fallback. The two
	implementations are numerically equivalent up to floating-point error.

	Parameters
	----------
	x: torch.Tensor, shape=(N, L, C)
		The input tensor.

	w: torch.Tensor, shape=(3, C)
		Depthwise convolution weights for the left, center, and right taps.

	dilation: int
		Spacing between the three taps.

	Returns
	-------
	y: torch.Tensor, shape=(N, L, C)
		The convolved and normalized output.
	"""

	if HAS_TRITON and x.is_cuda:
		return FusedDilatedConvNormFunc.apply(x, w, dilation)
	return _cheri_conv_norm_cpu(x, w, dilation)


[docs] class CheriBlock(torch.nn.Module): """A single Cheri Block. The Cheri Block is the core building block of the Cherimoya model. It adapts the ConvNeXt block to noisy genomics data, with the goal of mixing spatial and channel information cheaply while remaining stable to train. The block performs the following operations on an input of shape (N, L, C): 1. A 3-tap depthwise dilated convolution that mixes spatial information independently for each channel. 2. A per-example layer normalization across the (length, channel) plane. The convolution and normalization are fused into one kernel. 3. A pointwise expansion linear projection from C to ``expansion * C`` channels. 4. A GELU non-linearity. 5. A pointwise contraction linear projection back to C channels. 6. A residual connection where the MLP output is scaled by a fixed constant (``residual_scale``) before being added to the input. The small constant keeps the residual path near-identity at initialization, which stabilizes training of deep stacks. Parameters ---------- n_filters: int The number of channels (the C dimension). dilation: int Dilation rate for the depthwise convolution. The kernel reads from positions (i - dilation, i, i + dilation) at each output position i, with zero padding outside the sequence. expansion: int, optional The factor by which the inner MLP expands the channel dimension. The first projection maps ``n_filters -> expansion * n_filters`` and the second projects back. Default is 2. residual_scale: float, optional Fixed scalar applied to the MLP output before it is added back to the residual stream. Default is 0.15. """ def __init__(self, n_filters, dilation, expansion=2, residual_scale=0.15): super().__init__() self.n_filters = n_filters self.dilation = dilation self.expansion = expansion self.residual_scale = residual_scale hidden = expansion * n_filters self.conv_weight = torch.nn.Parameter(torch.randn(3, n_filters)) self.linear1 = torch.nn.Linear(n_filters, hidden, bias=False) self.linear2 = torch.nn.Linear(hidden, n_filters, bias=False) self.activation = torch.nn.GELU(approximate='tanh') torch.nn.init.trunc_normal_(self.conv_weight, std=0.02) torch.nn.init.trunc_normal_(self.linear1.weight, std=0.02) torch.nn.init.trunc_normal_(self.linear2.weight, std=0.02)
[docs] def forward(self, X): """Run the block on an input of shape (N, L, C).""" X_conv = fused_dilated_conv_norm(X, self.conv_weight, self.dilation) X_mlp = self.linear2(self.activation(self.linear1(X_conv))) return X + X_mlp * self.residual_scale