Losses

_mixture_loss

cherimoya.losses._mixture_loss(y, y_hat_logits, y_hat_logcounts, labels=None)

Computes the profile and count losses from model predictions and observed counts.

This function takes the observed integer read counts, the predicted logits, and the predicted log counts, and returns the profile loss and the count loss as two separate values. Importantly, it computes a single multinomial over all strands in the tracks and a single count loss across all tracks.
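As an illustration of what "a single multinomial over all strands" means, the sketch below (not cherimoya's code; the tensor sizes are arbitrary placeholders) flattens the strand and position axes before a log-softmax, so the probabilities sum to 1 over both strands jointly. Normalizing each strand separately is shown for contrast.

```python
import torch

# Arbitrary example sizes: 3 examples, 2 strands/outputs, 5 positions.
n, n_outputs, length = 3, 2, 5
y_hat_logits = torch.randn(n, n_outputs, length)

# Joint normalization: one distribution over all n_outputs * length bins.
joint_logps = torch.log_softmax(y_hat_logits.reshape(n, -1), dim=-1)

# Per-strand normalization, for contrast: one distribution per strand.
per_strand_logps = torch.log_softmax(y_hat_logits, dim=-1)

# Joint probabilities sum to 1 per example; per-strand ones sum to 1 per strand.
print(joint_logps.exp().sum(dim=-1))       # one value per example
print(per_strand_logps.exp().sum(dim=-1))  # one value per example and strand
```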

Parameters:
  • y (torch.Tensor, shape=(n, n_outputs, length)) – The observed counts for each example across each strand/output and at each position. These are typically sparse, non-negative integers.

  • y_hat_logits (torch.Tensor, shape=(n, n_outputs, length)) – The predicted logits for each example across each strand/output and at each position. These are normalized internally, so do not apply a softmax to the model outputs.

  • y_hat_logcounts (torch.Tensor, shape=(n, n_outputs)) – The predicted log counts for each example across each strand/output. The true log counts will be derived automatically from y.

  • labels (torch.Tensor, shape=(n,), optional) – Whether each example is from a peak (1) or a non-peak (0). If provided, the profile loss is calculated only on the peak examples, while the count loss is always calculated on the entire set of examples. If not provided, the profile loss is also calculated on the entire set of examples. Default is None.

Returns:

  • profile_loss (torch.Tensor, shape=(1,)) – The multinomial log likelihood loss averaged across examples and outputs.

  • count_loss (torch.Tensor, shape=(1,)) – The mean squared error between the predicted log counts and the log-transformed observed counts, averaged across examples and outputs.
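The behaviour described above can be sketched as follows. This is a hypothetical reimplementation (`mixture_loss_sketch`), not cherimoya's source. It assumes the profile loss is one multinomial over the flattened strand and position axes, that `y_hat_logcounts` holds one predicted log count per example (shape `(n, 1)`, following the "single count loss across all tracks" reading) compared against log(total + 1), and that `labels` is a 0/1 tensor used as a boolean mask. The MNLL and log1p-MSE helpers are written inline so the block runs without bpnetlite.

```python
import torch


def mnll(logps, true_counts):
    # Multinomial negative log-likelihood per example (last axis = bins).
    n_total = true_counts.sum(dim=-1)
    return (
        -torch.lgamma(n_total + 1)
        + torch.lgamma(true_counts + 1).sum(dim=-1)
        - (true_counts * logps).sum(dim=-1)
    )


def log1p_mse(pred_log_counts, true_counts):
    # Per-example MSE between predictions and log(true + 1).
    return torch.square(pred_log_counts - torch.log1p(true_counts)).mean(dim=-1)


def mixture_loss_sketch(y, y_hat_logits, y_hat_logcounts, labels=None):
    """Hypothetical sketch of the documented behaviour, not cherimoya's code."""
    n = y.shape[0]
    # One multinomial over every strand and position: flatten, then log-softmax.
    logps = torch.log_softmax(y_hat_logits.reshape(n, -1), dim=-1)
    y_flat = y.reshape(n, -1).float()  # lgamma needs a floating-point tensor

    # Profile loss: restricted to peak examples when labels are given.
    # If no example has label 1, the mean over an empty selection is NaN.
    if labels is not None:
        mask = labels == 1
        profile_loss = mnll(logps[mask], y_flat[mask]).mean()
    else:
        profile_loss = mnll(logps, y_flat).mean()

    # Count loss: always on every example, against the total observed count.
    total_counts = y_flat.sum(dim=-1, keepdim=True)  # shape (n, 1)
    count_loss = log1p_mse(y_hat_logcounts, total_counts).mean()
    return profile_loss, count_loss


# Usage with random placeholder data: 4 examples, 2 strands, 10 positions.
torch.manual_seed(0)
y = torch.randint(0, 5, (4, 2, 10))
logits = torch.randn(4, 2, 10)
logcounts = torch.randn(4, 1)
labels = torch.tensor([1, 0, 1, 0])
profile_loss, count_loss = mixture_loss_sketch(y, logits, logcounts, labels)
```

Keeping the two losses separate, rather than summing them inside the function, leaves the caller free to weight them.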

Imported Losses

The following loss functions are imported from bpnetlite and used internally:

cherimoya.losses.MNLLLoss(logps, true_counts)

Multinomial negative log-likelihood loss. Computes the negative log probability of the observed counts under a multinomial distribution parameterized by the predicted log probabilities.
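For one example with observed counts k_1, ..., k_L summing to N and predicted log probabilities log p_i, the standard multinomial negative log-likelihood is given below. The symbols are introduced here for illustration; the log-gamma terms supply the multinomial coefficient, since Γ(k + 1) = k!.

```latex
-\log P(\mathbf{k} \mid \mathbf{p})
  = -\log \Gamma(N + 1) + \sum_{i=1}^{L} \log \Gamma(k_i + 1) - \sum_{i=1}^{L} k_i \log p_i,
\qquad N = \sum_{i=1}^{L} k_i
```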

Parameters:
  • logps – Predicted log probabilities, shape (n, length)

  • true_counts – Observed integer counts, shape (n, length)

Returns:

Loss per example, shape (n,)
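A minimal sketch of this loss, assuming it follows the standard multinomial formula (the function name `mnll_sketch` is hypothetical, not bpnetlite's implementation), can be cross-checked against PyTorch's own `torch.distributions.Multinomial.log_prob`:

```python
import torch


def mnll_sketch(logps, true_counts):
    # Negative multinomial log-likelihood per example (last axis = bins).
    # lgamma(k + 1) == log(k!), so these terms form the multinomial coefficient.
    true_counts = true_counts.float()
    return (
        -torch.lgamma(true_counts.sum(dim=-1) + 1)
        + torch.lgamma(true_counts + 1).sum(dim=-1)
        - (true_counts * logps).sum(dim=-1)
    )


torch.manual_seed(0)
logps = torch.log_softmax(torch.randn(3, 6), dim=-1)  # (n, length)
true_counts = torch.randint(0, 4, (3, 6))             # (n, length)

loss = mnll_sketch(logps, true_counts)  # shape (3,), one loss per example

# Cross-check against PyTorch's multinomial log-probability. Validation is
# switched off so that examples with different total counts are accepted.
reference = -torch.distributions.Multinomial(
    total_count=int(true_counts.sum(dim=-1).max()),
    logits=logps,
    validate_args=False,
).log_prob(true_counts.float())
```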

cherimoya.losses.log1pMSELoss(pred_log_counts, true_counts)

Mean squared error in log space. Computes MSE(pred_log_counts, log(true_counts + 1)), averaged over the outputs of each example.

Parameters:
  • pred_log_counts – Predicted log counts, shape (n, n_outputs)

  • true_counts – True counts (not in log space), shape (n, n_outputs)

Returns:

Loss per example, shape (n,)
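A minimal sketch of this computation, with a small worked example, is shown below. The function name `log1p_mse_sketch` is hypothetical; it assumes the per-example loss is the mean of the squared errors over the output axis, which matches the (n,) return shape above.

```python
import math

import torch


def log1p_mse_sketch(pred_log_counts, true_counts):
    # Squared error against log(true + 1), averaged over the output axis,
    # giving one loss value per example.
    return torch.square(pred_log_counts - torch.log1p(true_counts.float())).mean(dim=-1)


# Two examples with two outputs each. The first example is predicted exactly.
true_counts = torch.tensor([[9, 99], [0, 0]])
pred_log_counts = torch.tensor([[math.log(10), math.log(100)], [1.0, 3.0]])

loss = log1p_mse_sketch(pred_log_counts, true_counts)
# Example 1: log(9 + 1) and log(99 + 1) match the predictions, so its loss is ~0.
# Example 2: log(0 + 1) = 0, so its loss is (1**2 + 3**2) / 2 = 5.
```

Adding 1 before taking the log keeps zero counts, which are common in sparse tracks, finite in log space.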