# Source code for cherimoya.losses
# losses.py
# Authors: Jacob Schreiber <jmschreiber91@gmail.com>
"""
This module contains the mixture loss function used for training Cherimoya
models, which is composed of a multinomial log likelihood component and a
mean-squared error component. These losses are returned separately, so that
other code can implement different ways of combining them into a single loss.
"""
import torch
from bpnetlite.losses import MNLLLoss
from bpnetlite.losses import log1pMSELoss
# [docs]
def _mixture_loss(y, y_hat_logits, y_hat_logcounts, labels=None):
    """Compute the profile and count losses for a batch of predictions.

    This function takes in the observed integer read counts, the predicted
    logits, and the predicted logcounts, and returns the profile loss and the
    count loss as two separate values so that the caller can decide how to
    combine them. Importantly, this calculates a single multinomial over all
    strands in the tracks and a single count loss across all tracks.


    Parameters
    ----------
    y: torch.Tensor, shape=(n, n_outputs, length)
        The observed counts for each example across each strand/output and at
        each position. This should likely be sparse integers.

    y_hat_logits: torch.Tensor, shape=(n, n_outputs, length)
        The predicted *logits* for each example across each strand/output and
        at each position. This will be normalized internally, so DO NOT run a
        softmax on your model.

    y_hat_logcounts: torch.Tensor, shape=(n, n_outputs)
        The predicted *log counts* for each example across each strand/output.
        The true log counts will be derived automatically from `y`.
        NOTE(review): the true counts are summed over every output and
        position into a tensor of shape (n, 1) before being compared with
        this argument -- confirm the shape callers actually pass here.

    labels: torch.Tensor, shape=(n,), optional
        Whether the example is from a peak (1) or a non-peak (0). If provided,
        the profile loss will only be calculated on the peak examples. The
        count loss will always be calculated on the entire set of examples.
        If not provided, the profile loss will also be calculated on the
        entire set of examples. Default is None.


    Returns
    -------
    profile_loss: torch.Tensor, scalar
        The multinomial log likelihood loss averaged across the examples it
        was calculated on. When `labels` is provided and no example is
        labeled as a peak, this is zero rather than NaN.

    count_loss: torch.Tensor, scalar
        The mean-squared error loss, averaged across examples and outputs.
    """

    n_examples = y.shape[0]

    # Flatten outputs and positions into one axis so that the softmax, and
    # therefore the multinomial, spans every strand/output of an example.
    logps = y_hat_logits.reshape(y_hat_logits.shape[0], -1)
    logps = torch.nn.functional.log_softmax(logps, dim=-1)

    y = y.reshape(n_examples, -1)

    # Total observed counts per example, summed over all outputs and positions.
    total_counts = y.sum(dim=-1).reshape(n_examples, 1)

    # Calculate the profile and count losses
    if labels is None:
        profile_loss = MNLLLoss(logps, y).mean()
    else:
        is_peak = labels == 1
        peak_losses = MNLLLoss(logps[is_peak], y[is_peak])

        if peak_losses.numel() > 0:
            profile_loss = peak_losses.mean()
        else:
            # The mean of an empty tensor is NaN, which would poison the
            # combined loss and the gradients for a batch without any peaks.
            # The sum of an empty tensor is zero and stays attached to the
            # autograd graph, so a backward pass still works.
            profile_loss = peak_losses.sum()

    count_loss = log1pMSELoss(y_hat_logcounts, total_counts).mean()
    return profile_loss, count_loss