Losses#

• `convex_kl_divergence(log_predictions, targets)` – Computes a convex version of the Kullback-Leibler divergence loss.
• `cosine_distance(predictions, targets[, epsilon])` – Computes the cosine distance between targets and predictions.
• `cosine_similarity(predictions, targets[, ...])` – Computes the cosine similarity between targets and predictions.
• `ctc_loss(logits, logit_paddings, labels, ...)` – Computes CTC loss.
• `ctc_loss_with_forward_probs(logits, ...[, ...])` – Computes CTC loss and CTC forward-probabilities.
• `hinge_loss(predictor_outputs, targets)` – Computes the hinge loss for binary classification.
• `huber_loss(predictions[, targets, delta])` – Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
• `kl_divergence(log_predictions, targets)` – Computes the Kullback-Leibler divergence (relative entropy) loss.
• `l2_loss(predictions[, targets])` – Calculates the L2 loss for a set of predictions.
• `log_cosh(predictions[, targets])` – Calculates the log-cosh loss for a set of predictions.
• `ntxent(embeddings, labels[, temperature])` – Normalized temperature scaled cross entropy loss (NT-Xent).
• `sigmoid_binary_cross_entropy(logits, labels)` – Computes element-wise sigmoid cross entropy given logits and labels.
• `sigmoid_focal_loss(logits, labels[, alpha, ...])` – Sigmoid focal loss.
• `smooth_labels(labels, alpha)` – Apply label smoothing.
• `softmax_cross_entropy(logits, labels)` – Computes the softmax cross entropy between sets of logits and labels.
• `softmax_cross_entropy_with_integer_labels(...)` – Computes softmax cross entropy between sets of logits and integer labels.
• `squared_error(predictions[, targets])` – Calculates the squared error for a set of predictions.

Convex Kullback Leibler divergence#

optax.convex_kl_divergence(log_predictions, targets)[source]#

Computes a convex version of the Kullback-Leibler divergence loss.

Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution. This version is jointly convex in p (targets) and q (log_predictions).

References

[Kullback, Leibler, 1951](https://www.jstor.org/stable/2236703)

Parameters:
  • log_predictions (Union[Array, ndarray, bool_, number]) – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.

  • targets (Union[Array, ndarray, bool_, number]) – Probabilities of target distribution with shape […, dim]. Expected to be strictly positive.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

Kullback-Leibler divergence of predicted distribution from target distribution with shape […].
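
A minimal usage sketch (the probability values below are purely illustrative):

```python
import jax.numpy as jnp
import optax

# Predicted distribution, given in log-space (e.g. the output of jax.nn.log_softmax).
log_predictions = jnp.log(jnp.array([[0.5, 0.3, 0.2]]))
# Target distribution; strictly positive probabilities.
targets = jnp.array([[0.6, 0.3, 0.1]])

loss = optax.convex_kl_divergence(log_predictions, targets)
print(loss.shape)  # (1,) -- one divergence per batch element
```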

Cosine distance#

optax.cosine_distance(predictions, targets, epsilon=0.0)[source]#

Computes the cosine distance between targets and predictions.

The cosine distance, implemented here, measures the dissimilarity of two vectors as one minus their cosine similarity: 1 - cos(theta).

References

[Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

Parameters:
  • predictions (Union[Array, ndarray, bool_, number]) – The predicted vectors, with shape […, dim].

  • targets (Union[Array, ndarray, bool_, number]) – Ground truth target vectors, with shape […, dim].

  • epsilon (float) – minimum norm for terms in the denominator of the cosine similarity.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

cosine distances, with shape […].

Cosine similarity#

optax.cosine_similarity(predictions, targets, epsilon=0.0)[source]#

Computes the cosine similarity between targets and predictions.

The cosine similarity is a measure of similarity between vectors defined as the cosine of the angle between them, which is also the inner product of those vectors normalized to have unit norm.

References

[Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

Parameters:
  • predictions (Union[Array, ndarray, bool_, number]) – The predicted vectors, with shape […, dim].

  • targets (Union[Array, ndarray, bool_, number]) – Ground truth target vectors, with shape […, dim].

  • epsilon (float) – minimum norm for terms in the denominator of the cosine similarity.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

cosine similarity measures, with shape […].
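
A small sketch illustrating the relationship noted above, that the distance is one minus the similarity (the example vectors are illustrative):

```python
import jax.numpy as jnp
import optax

predictions = jnp.array([[1.0, 0.0], [0.0, 1.0]])
targets = jnp.array([[1.0, 0.0], [1.0, 0.0]])

sim = optax.cosine_similarity(predictions, targets)   # [1., 0.]
dist = optax.cosine_distance(predictions, targets)    # [0., 1.]
assert jnp.allclose(dist, 1.0 - sim)
```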

Connectionist temporal classification loss#

optax.ctc_loss(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)[source]#

Computes CTC loss.

See docstring for ctc_loss_with_forward_probs for details.

Parameters:
  • logits (Union[Array, ndarray, bool_, number]) – (B, T, K)-array containing logits of each class where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.

  • logit_paddings (Union[Array, ndarray, bool_, number]) – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.

  • labels (Union[Array, ndarray, bool_, number]) – (B, N)-array containing reference integer labels, where N denotes the maximum label sequence length.

  • label_paddings (Union[Array, ndarray, bool_, number]) – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.

  • blank_id (int) – Id for blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.

  • log_epsilon (float) – Numerically-stable approximation of log(+0).

Return type:

Union[Array, ndarray, bool_, number]

Returns:

(B,)-array containing loss values for each sequence in the batch.
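
A minimal usage sketch with toy shapes: random logits, no frame padding, and right-padded labels as required (all values are illustrative):

```python
import jax
import jax.numpy as jnp
import optax

B, T, K, N = 2, 10, 5, 4  # batch, time frames, classes (incl. blank), max label length
logits = jax.random.normal(jax.random.PRNGKey(0), (B, T, K))
logit_paddings = jnp.zeros((B, T))              # no padded time frames
labels = jnp.array([[1, 2, 3, 4],
                    [2, 3, 1, 1]])              # padded entries are ignored
label_paddings = jnp.array([[0., 0., 0., 0.],
                            [0., 0., 1., 1.]])  # second sequence has length 2

per_sequence_loss = optax.ctc_loss(logits, logit_paddings, labels, label_paddings)
print(per_sequence_loss.shape)  # (B,)
```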

optax.ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)[source]#

Computes CTC loss and CTC forward-probabilities.

The CTC loss is a loss function based on the log-likelihood of the model; it introduces a special blank symbol \(\phi\) to represent variable-length output sequences.

Forward probabilities returned by this function, as auxiliary results, are grouped into two parts: the blank alpha-probability and the non-blank alpha-probability. These are defined as follows:

\[\alpha_{\mathrm{BLANK}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ \alpha_{\mathrm{LABEL}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). \]

Here, \(\pi\) denotes an alignment sequence in the sense of the reference [Graves et al, 2006], i.e. a blank-inserted representation of the labels. The return values are the logarithms of the above probabilities.

References

[Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891)

Parameters:
  • logits (Union[Array, ndarray, bool_, number]) – (B, T, K)-array containing logits of each class where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.

  • logit_paddings (Union[Array, ndarray, bool_, number]) – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.

  • labels (Union[Array, ndarray, bool_, number]) – (B, N)-array containing reference integer labels, where N denotes the maximum label sequence length.

  • label_paddings (Union[Array, ndarray, bool_, number]) – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.

  • blank_id (int) – Id for blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.

  • log_epsilon (float) – Numerically-stable approximation of log(+0).

Return type:

tuple[Union[Array, ndarray, bool_, number], Union[Array, ndarray, bool_, number], Union[Array, ndarray, bool_, number]]

Returns:

A tuple (loss_value, logalpha_blank, logalpha_nonblank). Here, loss_value is a (B,)-array containing the loss values for each sequence in the batch; logalpha_blank and logalpha_nonblank are (T, B, N+1)-arrays whose (t, b, n)-th elements denote \(\log \alpha_{\mathrm{BLANK}}(t, n)\) and \(\log \alpha_{\mathrm{LABEL}}(t, n)\), respectively, for the b-th sequence in the batch.
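
A brief self-contained sketch of the returned shapes (toy values, unpadded inputs):

```python
import jax
import jax.numpy as jnp
import optax

B, T, K, N = 2, 10, 5, 4
logits = jax.random.normal(jax.random.PRNGKey(0), (B, T, K))
logit_paddings = jnp.zeros((B, T))
labels = jnp.ones((B, N), dtype=jnp.int32)
label_paddings = jnp.zeros((B, N))

loss, logalpha_blank, logalpha_nonblank = optax.ctc_loss_with_forward_probs(
    logits, logit_paddings, labels, label_paddings)
# Shapes: (B,), (T, B, N+1), (T, B, N+1)
print(loss.shape, logalpha_blank.shape, logalpha_nonblank.shape)
```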

Hinge loss#

optax.hinge_loss(predictor_outputs, targets)[source]#

Computes the hinge loss for binary classification.

Parameters:
  • predictor_outputs (Union[Array, ndarray, bool_, number]) – Outputs of the decision function.

  • targets (Union[Array, ndarray, bool_, number]) – Target values. Target values should be strictly in the set {-1, 1}.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

loss value.
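
A minimal sketch. The manual check assumes the textbook hinge formulation max(0, 1 - target * score), which is not spelled out above:

```python
import jax.numpy as jnp
import optax

scores = jnp.array([0.8, -1.5, 0.3])     # raw decision-function outputs
targets = jnp.array([1.0, -1.0, -1.0])   # targets must be in {-1, 1}

loss = optax.hinge_loss(scores, targets)
# Textbook hinge loss (assumed formulation): max(0, 1 - target * score).
assert jnp.allclose(loss, jnp.maximum(0.0, 1.0 - targets * scores))
```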

Huber loss#

optax.huber_loss(predictions, targets=None, delta=1.0)[source]#

Huber loss, similar to L2 loss close to zero, L1 loss away from zero.

Applying gradient descent to the Huber loss is equivalent to clipping the gradients of an l2_loss to [-delta, delta] in the backward pass.

References

[Huber, 1964](https://www.projecteuclid.org/download/pdf_1/euclid.aoms/1177703732)

Parameters:
  • predictions (Union[Array, ndarray, bool_, number]) – a vector of arbitrary shape […].

  • targets (Union[Array, ndarray, bool_, number, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

  • delta (float) – the bounds for the Huber loss transformation; defaults to 1.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

elementwise Huber losses, with the same shape as predictions.
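
A sketch of the gradient-clipping equivalence described above, using jax.grad to compare the two backward passes (example values are illustrative):

```python
import jax
import jax.numpy as jnp
import optax

delta = 1.0
predictions = jnp.array([0.3, -2.5, 4.0])  # targets default to zeros

# Gradient of the summed Huber loss w.r.t. the predictions.
huber_grad = jax.grad(lambda p: optax.huber_loss(p, delta=delta).sum())(predictions)
# Gradient of the summed L2 loss, clipped to [-delta, delta].
l2_grad = jax.grad(lambda p: optax.l2_loss(p).sum())(predictions)
assert jnp.allclose(huber_grad, jnp.clip(l2_grad, -delta, delta))
```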

Kullback-Leibler divergence#

optax.kl_divergence(log_predictions, targets)[source]#

Computes the Kullback-Leibler divergence (relative entropy) loss.

Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution.

References

[Kullback, Leibler, 1951](https://www.jstor.org/stable/2236703)

Parameters:
  • log_predictions (Union[Array, ndarray, bool_, number]) – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.

  • targets (Union[Array, ndarray, bool_, number]) – Probabilities of target distribution with shape […, dim]. Expected to be strictly positive.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

Kullback-Leibler divergence of predicted distribution from target distribution with shape […].
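
A small sanity-check sketch against the textbook definition \(\sum_i p_i (\log p_i - \log q_i)\), with p the targets and q the predictions; the formula is an assumption here, and the example values are illustrative:

```python
import jax.numpy as jnp
import optax

targets = jnp.array([[0.6, 0.3, 0.1]])       # p, strictly positive
predictions = jnp.array([[0.5, 0.3, 0.2]])   # q

loss = optax.kl_divergence(jnp.log(predictions), targets)
manual = jnp.sum(targets * (jnp.log(targets) - jnp.log(predictions)), axis=-1)
assert jnp.allclose(loss, manual)
```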

L2 Squared loss#

optax.squared_error(predictions, targets=None)[source]#

Calculates the squared error for a set of predictions.

Mean Squared Error can be computed as squared_error(a, b).mean().

Note: l2_loss = 0.5 * squared_error, where the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not in “The Elements of Statistical Learning” by Hastie, Tibshirani, and Friedman.

References

[Chris Bishop, 2006](https://bit.ly/3eeP0ga)

Parameters:
  • predictions (Union[Array, ndarray, bool_, number]) – a vector of arbitrary shape […].

  • targets (Union[Array, ndarray, bool_, number, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

elementwise squared differences, with same shape as predictions.

optax.l2_loss(predictions, targets=None)[source]#

Calculates the L2 loss for a set of predictions.

Note: the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not in “The Elements of Statistical Learning” by Hastie, Tibshirani, and Friedman.

References

[Chris Bishop, 2006](https://bit.ly/3eeP0ga)

Parameters:
  • predictions (Union[Array, ndarray, bool_, number]) – a vector of arbitrary shape […].

  • targets (Union[Array, ndarray, bool_, number, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

elementwise L2 losses (0.5 * squared differences), with the same shape as predictions.
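
A short sketch of the relationship noted above between squared_error and l2_loss, and of computing a mean squared error (example values are illustrative):

```python
import jax.numpy as jnp
import optax

predictions = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.5, 2.0, 2.0])

se = optax.squared_error(predictions, targets)   # elementwise (a - b) ** 2
mse = se.mean()                                  # mean squared error
l2 = optax.l2_loss(predictions, targets)
assert jnp.allclose(l2, 0.5 * se)                # l2_loss = 0.5 * squared_error
```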

Log hyperbolic cosine loss#

optax.log_cosh(predictions, targets=None)[source]#

Calculates the log-cosh loss for a set of predictions.

log(cosh(x)) is approximately (x**2) / 2 for small x and abs(x) - log(2) for large x. It is a twice differentiable alternative to the Huber loss.

References

[Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym)

Parameters:
  • predictions (Union[Array, ndarray, bool_, number]) – a vector of arbitrary shape […].

  • targets (Union[Array, ndarray, bool_, number, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

the log-cosh loss, with same shape as predictions.
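
A small sketch of the two limiting behaviours mentioned above; the tolerances are loose to absorb float32 rounding:

```python
import jax.numpy as jnp
import optax

small = jnp.array([1e-2])
large = jnp.array([20.0])

# Near zero the loss behaves like x**2 / 2 ...
assert jnp.allclose(optax.log_cosh(small), small**2 / 2, atol=1e-6)
# ... and for large |x| like |x| - log(2).
assert jnp.allclose(optax.log_cosh(large), jnp.abs(large) - jnp.log(2.0), atol=1e-4)
```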

Normalized temperature scaled cross-entropy (NT-Xent) loss#

optax.ntxent(embeddings, labels, temperature=0.07)[source]#

Normalized temperature scaled cross entropy loss (NT-Xent).

References

T. Chen et al, A Simple Framework for Contrastive Learning of Visual Representations, 2020. [NT-Xent in pytorch-metric-learning](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss)

Parameters:
  • embeddings (Union[Array, ndarray, bool_, number]) – batch of embeddings, with shape [batch, feature_length]

  • labels (Union[Array, ndarray, bool_, number]) – labels for groups that are positive pairs, with shape [batch]. For example, with a batch of 4 embeddings in which the first two and the last two form positive pairs, the labels should be [0, 0, 1, 1]. The labels SHOULD NOT all be identical (e.g. [0, 0, 0, 0]); that would produce a NaN result.

  • temperature (Union[Array, ndarray, bool_, number, float, int]) – temperature scaling parameter.

Return type:

Union[Array, ndarray, bool_, number, float, int]

Returns:

A scalar NT-Xent loss value, averaged over all positive pairs.

Added in version 0.2.3.
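
A minimal usage sketch with four embeddings forming two positive pairs (the random embeddings are illustrative):

```python
import jax
import jax.numpy as jnp
import optax

# Four embeddings; the first two and the last two are positive pairs.
embeddings = jax.random.normal(jax.random.PRNGKey(0), (4, 8))
labels = jnp.array([0, 0, 1, 1])

loss = optax.ntxent(embeddings, labels, temperature=0.07)
print(loss)  # scalar loss averaged over all positive pairs
```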

Sigmoid binary cross-entropy#

optax.sigmoid_binary_cross_entropy(logits, labels)[source]#

Computes element-wise sigmoid cross entropy given logits and labels.

This function can be used for binary or multiclass classification (where each class is an independent binary prediction and different classes are not mutually exclusive, e.g. predicting that an image contains both a cat and a dog).

Because this function is overloaded, please ensure your logits and labels are compatible with each other. If you’re passing in binary labels (values in {0, 1}), ensure your logits correspond to class 1 only. If you’re passing in per-class target probabilities or one-hot labels, please ensure your logits are also multiclass. Be particularly careful if you’re relying on implicit broadcasting to reshape logits or labels.

References

[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

Parameters:
  • logits – Each element is the unnormalized log probability of a binary prediction. See note about compatibility with labels above.

  • labels – Binary labels whose values are {0,1} or multi-class target probabilities. See note about compatibility with logits above.

Returns:

cross entropy for each binary prediction, same shape as logits.
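
Two small sketches of the compatible shapes discussed above, one with binary labels and one with independent per-class labels (all values are illustrative):

```python
import jax.numpy as jnp
import optax

# Binary case: logits correspond to class 1 only, labels take values in {0, 1}.
logits = jnp.array([0.5, -1.2, 3.0])
labels = jnp.array([1.0, 0.0, 1.0])
binary_loss = optax.sigmoid_binary_cross_entropy(logits, labels)  # shape (3,)

# Multi-label case: each class is an independent binary prediction.
multi_logits = jnp.array([[1.0, -2.0], [0.5, 0.5]])
multi_labels = jnp.array([[1.0, 0.0], [0.0, 1.0]])
multilabel_loss = optax.sigmoid_binary_cross_entropy(multi_logits, multi_labels)  # shape (2, 2)
```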

Sigmoid focal loss#

optax.sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0)[source]#

Sigmoid focal loss.

The focal loss is a re-weighted cross entropy for unbalanced problems. Use this loss function if classes are not mutually exclusive. See sigmoid_binary_cross_entropy for more information.

References

[Lin et al, 2018](https://arxiv.org/pdf/1708.02002.pdf)

Parameters:
  • logits (Union[Array, ndarray, bool_, number]) – Array of floats. The predictions for each example.

  • labels (Union[Array, ndarray, bool_, number]) – Array of floats. Labels and logits must have the same shape. The label array must contain the binary classification labels for each element in the data set (0 for out-of-class, 1 for in-class).

  • alpha (Optional[float]) – (optional) Weighting factor in range (0,1) to balance positive vs negative examples. Default None (no weighting).

  • gamma (float) – Exponent of the modulating factor (1 - p_t). Balances easy vs hard examples.

Return type:

Union[Array, ndarray, bool_, number]

Returns:

A loss value array with a shape identical to the logits and target arrays.
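
A small sketch; the check assumes the usual focal-loss formulation, in which gamma=0 and alpha=None reduce it to plain sigmoid cross entropy:

```python
import jax.numpy as jnp
import optax

logits = jnp.array([2.0, -1.0, 0.3])
labels = jnp.array([1.0, 0.0, 1.0])

# With no modulation (gamma=0) and no alpha weighting, the focal loss
# is expected to match the element-wise sigmoid cross entropy.
focal = optax.sigmoid_focal_loss(logits, labels, gamma=0.0)
bce = optax.sigmoid_binary_cross_entropy(logits, labels)
assert jnp.allclose(focal, bce, atol=1e-6)
```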

Smoothing labels#

optax.smooth_labels(labels, alpha)[source]#

Apply label smoothing.

Label smoothing is often used in combination with a cross-entropy loss. Smoothed labels favour small logit gaps, and it has been shown that this can provide better model calibration by preventing overconfident predictions.

References

[Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf)

Parameters:
  • labels (Union[Array, ndarray, bool_, number]) – One hot labels to be smoothed.

  • alpha (float) – The smoothing factor.

Return type:

Array

Returns:

a smoothed version of the one hot input labels.
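
A minimal sketch, assuming the common smoothing rule (1 - alpha) * labels + alpha / num_classes; the exact rule is an assumption here:

```python
import jax.numpy as jnp
import optax

one_hot = jnp.array([[0.0, 1.0, 0.0]])
smoothed = optax.smooth_labels(one_hot, alpha=0.1)
# With K = 3 classes and the assumed rule, this gives roughly [[0.0333, 0.9333, 0.0333]].
print(smoothed)
assert jnp.allclose(smoothed.sum(axis=-1), 1.0)  # still a valid distribution
```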

Soft-max cross-entropy#

optax.softmax_cross_entropy(logits, labels)[source]#

Computes the softmax cross entropy between sets of logits and labels.

Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.

References

[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

Parameters:
  • logits (Union[Array, ndarray, bool_, number]) – Unnormalized log probabilities, with shape […, num_classes].

  • labels (Union[Array, ndarray, bool_, number]) – Valid probability distributions (non-negative, sum to 1), e.g. a one-hot encoding specifying the correct class for each input; must have a shape broadcastable to […, num_classes].

Return type:

Union[Array, ndarray, bool_, number]

Returns:

cross entropy between each prediction and the corresponding target distributions, with shape […].

See also

optax.safe_softmax_cross_entropy()
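
A short sketch; the manual check assumes the standard definition -sum(labels * log_softmax(logits)) along the class axis:

```python
import jax
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, 1.0, 0.1]])
labels = jax.nn.one_hot(jnp.array([0]), num_classes=3)

loss = optax.softmax_cross_entropy(logits, labels)
manual = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
assert jnp.allclose(loss, manual)
```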

optax.softmax_cross_entropy_with_integer_labels(logits, labels)[source]#

Computes softmax cross entropy between sets of logits and integer labels.

Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.

References

[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

Parameters:
  • logits (Union[Array, ndarray, bool_, number]) – Unnormalized log probabilities, with shape […, num_classes].

  • labels (Union[Array, ndarray, bool_, number]) – Integers specifying the correct class for each input, with shape […].

Return type:

Union[Array, ndarray, bool_, number]

Returns:

Cross entropy between each prediction and the corresponding target distributions, with shape […].
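
A short sketch checking that the integer-label variant agrees with optax.softmax_cross_entropy applied to the corresponding one-hot labels (example values are illustrative):

```python
import jax
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.5, 2.5, 0.0]])
int_labels = jnp.array([0, 1])

loss_int = optax.softmax_cross_entropy_with_integer_labels(logits, int_labels)
loss_onehot = optax.softmax_cross_entropy(logits, jax.nn.one_hot(int_labels, num_classes=3))
assert jnp.allclose(loss_int, loss_onehot, atol=1e-6)
```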