Losses

Computes a convex version of the Kullback-Leibler divergence loss.

Computes the cosine distance between targets and predictions. 

Computes the cosine similarity between targets and predictions. 

Computes CTC loss. 

Computes CTC loss and CTC forward probabilities.

Computes the hinge loss for binary classification. 

Huber loss, similar to L2 loss close to zero, L1 loss away from zero. 

Computes the Kullback-Leibler divergence (relative entropy) loss.

Calculates the L2 loss for a set of predictions. 

Calculates the log-cosh loss for a set of predictions.

Normalized temperature-scaled cross-entropy loss (NT-Xent).

Computes elementwise sigmoid cross entropy given logits and labels. 

Sigmoid focal loss. 

Apply label smoothing. 

Computes the softmax cross entropy between sets of logits and labels. 
Computes softmax cross entropy between sets of logits and integer labels. 


Calculates the squared error for a set of predictions. 
Convex Kullback-Leibler divergence
optax.convex_kl_divergence(log_predictions, targets)
Computes a convex version of the Kullback-Leibler divergence loss.
Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution. This version is jointly convex in p (targets) and q (log_predictions).
References
[Kullback, Leibler, 1951](https://www.jstor.org/stable/2236703)
Parameters:
log_predictions (Union[Array, ndarray, bool_, number]) – Probabilities of the predicted distribution, with shape […, dim]. Expected to be in log-space to avoid underflow.
targets (Union[Array, ndarray, bool_, number]) – Probabilities of the target distribution, with shape […, dim]. Expected to be strictly positive.
Returns:
Kullback-Leibler divergence of the predicted distribution from the target distribution, with shape […].
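The generalized (unnormalized) KL divergence, sum(p·log(p/q) - p + q), is the standard jointly convex extension of KL. The NumPy sketch below assumes this is what the convex version computes: ordinary KL plus sum(q - p) over the last axis. `convex_kl_sketch` is a hypothetical helper, not the optax implementation. The extra term cancels when both distributions are normalized. When q is not normalized, it keeps the loss non-negative.

```python
import numpy as np

def convex_kl_sketch(log_predictions: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Hypothetical helper: generalized KL over the last axis."""
    q = np.exp(log_predictions)
    # Ordinary KL term sum p * (log p - log q); targets assumed strictly positive.
    kl = np.sum(targets * (np.log(targets) - log_predictions), axis=-1)
    # Sketch assumption: the convex version adds sum(q - p), which is zero
    # when both p and q are normalized along the last axis.
    return kl + np.sum(q - targets, axis=-1)

p = np.array([0.2, 0.3, 0.5])
q_normalized = np.array([0.25, 0.25, 0.5])
q_unnormalized = 2.0 * q_normalized           # sums to 2, not a distribution

print(convex_kl_sketch(np.log(q_normalized), p))    # same as ordinary KL(p || q)
print(convex_kl_sketch(np.log(q_unnormalized), p))  # still non-negative
```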
Cosine distance
optax.cosine_distance(predictions, targets, epsilon=0.0)
Computes the cosine distance between targets and predictions.
The cosine distance, implemented here, measures the dissimilarity of two vectors as the complement of cosine similarity: 1 - cos(theta).
References
[Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)
Returns:
cosine distances, with shape […].
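The NumPy sketch below illustrates the relation 1 - cos(theta) on three vector pairs. `cosine_distance_sketch` is a hypothetical name. Treating `epsilon` as a lower bound on each norm is an assumption about how that parameter is used. Because cosine similarity lies in [-1, 1], the distance lies in [0, 2].

```python
import numpy as np

def cosine_distance_sketch(predictions, targets, epsilon=0.0):
    # Sketch assumption: epsilon clamps each vector norm from below.
    p_norm = np.maximum(np.linalg.norm(predictions, axis=-1, keepdims=True), epsilon)
    t_norm = np.maximum(np.linalg.norm(targets, axis=-1, keepdims=True), epsilon)
    cos = np.sum((predictions / p_norm) * (targets / t_norm), axis=-1)
    # Distance is the complement of similarity.
    return 1.0 - cos

a = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
b = np.array([[2.0, 0.0], [0.0, 3.0], [-1.0, 0.0]])  # same, orthogonal, opposite
print(cosine_distance_sketch(a, b))  # -> [0. 1. 2.]
```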
Cosine similarity
optax.cosine_similarity(predictions, targets, epsilon=0.0)
Computes the cosine similarity between targets and predictions.
The cosine similarity is a measure of similarity between vectors defined as the cosine of the angle between them, which is also the inner product of those vectors normalized to have unit norm.
References
[Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)
Returns:
cosine similarity measures, with shape […].
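To make "inner product of unit-norm vectors" concrete, here is a small NumPy sketch that works over the leading […] axes. `cosine_similarity_sketch` is a hypothetical name. Clamping each norm from below by `epsilon` is an assumption about that parameter's role. In this sketch, a positive epsilon turns the 0/0 case for a zero vector into a similarity of 0 instead of NaN.

```python
import numpy as np

def cosine_similarity_sketch(predictions, targets, epsilon=0.0):
    def unit(x):
        # Sketch assumption: norms are clamped from below by epsilon, so a
        # zero vector maps to the zero vector rather than to NaN.
        norm = np.linalg.norm(x, axis=-1, keepdims=True)
        return x / np.maximum(norm, epsilon)
    # Inner product of unit vectors = cosine of the angle between them.
    return np.sum(unit(predictions) * unit(targets), axis=-1)

preds = np.array([[3.0, 4.0], [0.0, 0.0]])
targs = np.array([[4.0, 3.0], [1.0, 1.0]])
print(cosine_similarity_sketch(preds, targs, epsilon=1e-8))  # 24/25 and 0
```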
Connectionist temporal classification loss
optax.ctc_loss(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)
Computes CTC loss.
See the docstring of ctc_loss_with_forward_probs for details.
Parameters:
logits (Union[Array, ndarray, bool_, number]) – (B, T, K)-array containing logits of each class, where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.
logit_paddings (Union[Array, ndarray, bool_, number]) – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.
labels (Union[Array, ndarray, bool_, number]) – (B, N)-array containing reference integer labels, where N denotes the max time frames in the label sequence.
label_paddings (Union[Array, ndarray, bool_, number]) – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.
blank_id (int) – Id of the blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.
log_epsilon (float) – Numerically-stable approximation of log(+0).
Returns:
(B,)-array containing loss values for each sequence in the batch.
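The padding arguments are easiest to build from per-sequence lengths. The NumPy sketch below uses a hypothetical helper, `lengths_to_paddings` (not part of optax), to produce right-padded 1.0/0.0 indicator arrays in the format the parameter descriptions above require.

```python
import numpy as np

def lengths_to_paddings(lengths, max_len):
    """Hypothetical helper: (B, max_len) float array, 1.0 at padded positions."""
    positions = np.arange(max_len)[None, :]               # shape (1, max_len)
    return (positions >= np.asarray(lengths)[:, None]).astype(np.float32)

# Two sequences: 4 and 2 valid logit frames (T = 4); 3 and 1 valid labels (N = 3).
logit_paddings = lengths_to_paddings([4, 2], max_len=4)
label_paddings = lengths_to_paddings([3, 1], max_len=3)
print(label_paddings)
# Each row is a run of zeros followed by a run of ones, i.e. right-padded.
```

Together with a (B, T, K) logits array and (B, N) integer labels, arrays like these are what the logit_paddings and label_paddings arguments expect.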
optax.ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)
Computes CTC loss and CTC forward probabilities.
The CTC loss is a loss function based on log-likelihoods of the model that introduces a special blank symbol \(\phi\) to represent variable-length output sequences.
Forward probabilities returned by this function, as auxiliary results, are grouped into two parts: the blank alpha-probability and the non-blank alpha-probability. They are defined as follows:
\[\alpha_{\mathrm{BLANK}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = \phi \mid \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ \alpha_{\mathrm{LABEL}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = y_n \mid \pi_{1:t-1}, y_{1:n-1}, \cdots). \]
Here, \(\pi\) denotes the alignment sequence in the reference [Graves et al, 2006], i.e. the blank-inserted representation of labels. The return values are the logarithms of the above probabilities.
References
[Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891)
Parameters:
logits (Union[Array, ndarray, bool_, number]) – (B, T, K)-array containing logits of each class, where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.
logit_paddings (Union[Array, ndarray, bool_, number]) – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.
labels (Union[Array, ndarray, bool_, number]) – (B, N)-array containing reference integer labels, where N denotes the max time frames in the label sequence.
label_paddings (Union[Array, ndarray, bool_, number]) – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.
blank_id (int) – Id of the blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.
log_epsilon (float) – Numerically-stable approximation of log(+0).
Return type:
tuple[Union[Array, ndarray, bool_, number], Union[Array, ndarray, bool_, number], Union[Array, ndarray, bool_, number]]
Returns:
A tuple (loss_value, logalpha_blank, logalpha_nonblank). Here, loss_value is a (B,)-array containing the loss values for each sequence in the batch, and logalpha_blank and logalpha_nonblank are (T, B, N+1)-arrays whose (t, b, n)-th elements denote \(\log \alpha_{\mathrm{BLANK}}(t, n)\) and \(\log \alpha_{\mathrm{LABEL}}(t, n)\), respectively, for the b-th sequence in the batch.
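The sketch below makes the quantity behind the loss concrete for a single unpadded sequence. It computes the CTC negative log-likelihood by brute force. It enumerates every length-T path over the K classes and collapses each path (merge repeats, then drop blanks). It then sums the probabilities of the paths that collapse to the label sequence. This follows the textbook definition in Graves et al. (2006). Its cost is exponential in T, so it only serves as a reference for tiny inputs; the forward recursion above exists to avoid that cost. `ctc_nll_brute_force` is a hypothetical name. That per-frame probabilities are a softmax of the logits over the class axis is an assumption of this sketch.

```python
import itertools
import numpy as np

def log_softmax(x):
    # Stable log-softmax over the class axis (last axis).
    shifted = x - np.max(x, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

def collapse(path, blank_id):
    """CTC many-to-one map: merge repeated symbols, then drop blanks."""
    merged = [s for i, s in enumerate(path) if i == 0 or s != path[i - 1]]
    return tuple(s for s in merged if s != blank_id)

def ctc_nll_brute_force(logits, labels, blank_id=0):
    """Hypothetical reference: -log p(labels | logits) for one (T, K) array.

    Sketch assumption: per-frame class probabilities are softmax(logits).
    Enumerates all K**T paths, so only usable for tiny T and K.
    """
    logprobs = log_softmax(logits)
    num_frames, num_classes = logprobs.shape
    target = tuple(labels)
    total = 0.0
    for path in itertools.product(range(num_classes), repeat=num_frames):
        if collapse(path, blank_id) == target:
            # Path probability is the product of per-frame probabilities.
            total += np.exp(sum(logprobs[t, s] for t, s in enumerate(path)))
    return -np.log(total)

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 3))  # T = 4 frames, K = 3 classes, class 0 is blank
print(ctc_nll_brute_force(logits, labels=[1, 2]))
```

On the same tiny, unpadded input, one could compare this value against ctc_loss, allowing for the log_epsilon approximation.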
Hinge loss
Huber loss
optax.huber_loss(predictions, targets=None, delta=1.0)
Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
If gradient descent is applied to the Huber loss, it is equivalent to clipping gradients of an l2_loss to [-delta, delta] in the backward pass.
References
[Huber, 1964](www.projecteuclid.org/download/pdf_1/euclid.aoms/1177703732)
Parameters:
predictions (Union[Array, ndarray, bool_, number]) – a vector of arbitrary shape […].
targets (Union[Array, ndarray, bool_, number, None]) – a vector with shape broadcastable to that of predictions; if not provided, it is assumed to be a vector of zeros.
delta (float) – the bounds for the Huber loss transformation; defaults to 1.0.
Returns:
elementwise Huber losses, with the same shape as predictions.
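Below is a NumPy sketch of the usual piecewise Huber definition: 0.5·x² for |x| ≤ delta and delta·(|x| - 0.5·delta) otherwise. A finite-difference check then confirms the gradient-clipping statement above. `huber_sketch` is a hypothetical name. Splitting the error into a quadratic part and a linear part is this sketch's way of writing the formula, not necessarily optax's.

```python
import numpy as np

def huber_sketch(predictions, targets=None, delta=1.0):
    # Missing targets are treated as zeros, as the parameter description says.
    errors = predictions if targets is None else predictions - targets
    abs_errors = np.abs(errors)
    quadratic = np.minimum(abs_errors, delta)  # portion of |error| inside delta
    linear = abs_errors - quadratic            # portion of |error| beyond delta
    return 0.5 * quadratic**2 + delta * linear

x = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
h = 1e-6
# Central finite differences of the loss with respect to the prediction.
grad = (huber_sketch(x + h) - huber_sketch(x - h)) / (2 * h)
print(grad)                   # approximately clip(x, -1, 1)
print(np.clip(x, -1.0, 1.0))  # gradient of 0.5 * x**2, clipped to [-delta, delta]
```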
Kullback-Leibler divergence
optax.kl_divergence(log_predictions, targets)
Computes the Kullback-Leibler divergence (relative entropy) loss.
Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution.
References
[Kullback, Leibler, 1951](https://www.jstor.org/stable/2236703)
Parameters:
log_predictions (Union[Array, ndarray, bool_, number]) – Probabilities of the predicted distribution, with shape […, dim]. Expected to be in log-space to avoid underflow.
targets (Union[Array, ndarray, bool_, number]) – Probabilities of the target distribution, with shape […, dim]. Expected to be strictly positive.
Returns:
Kullback-Leibler divergence of the predicted distribution from the target distribution, with shape […].
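The sketch below is a NumPy version of this loss. It assumes the usual definition sum(p·(log p - log q)) over the last axis, with p = targets and log q = log_predictions. `kl_sketch` and its `log_softmax` helper are written here for illustration. Producing log_predictions with a log-softmax keeps the computation in log-space, which addresses the underflow concern in the parameter description. The last line shows that KL is not symmetric.

```python
import numpy as np

def log_softmax(logits):
    shifted = logits - np.max(logits, axis=-1, keepdims=True)  # avoid overflow in exp
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

def kl_sketch(log_predictions, targets):
    # Sketch assumption: KL(p || q) = sum_i p_i * (log p_i - log q_i),
    # with targets strictly positive as the docs require.
    return np.sum(targets * (np.log(targets) - log_predictions), axis=-1)

logits = np.array([2.0, 0.5, -1.0])
p = np.array([0.7, 0.2, 0.1])
log_q = log_softmax(logits)
print(kl_sketch(log_q, p))                  # non-negative
print(kl_sketch(np.log(p), p))              # 0 when the distributions match
print(kl_sketch(np.log(p), np.exp(log_q)))  # KL(q || p) differs from KL(p || q)
```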
L2 Squared loss
optax.squared_error(predictions, targets=None)
Calculates the squared error for a set of predictions.
Mean Squared Error can be computed as squared_error(a, b).mean().
Note: l2_loss = 0.5 * squared_error, where the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not in “The Elements of Statistical Learning” by Hastie, Tibshirani and Friedman.
References
[Chris Bishop, 2006](https://bit.ly/3eeP0ga)
Returns:
elementwise squared differences, with the same shape as predictions.
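A minimal NumPy sketch of the elementwise definition and the MSE reduction mentioned above follows. `squared_error_sketch` is a hypothetical name. Treating `targets=None` as zero targets mirrors the huber_loss parameter description in this document and is assumed to apply here too.

```python
import numpy as np

def squared_error_sketch(predictions, targets=None):
    # Assumption: targets=None means comparing against zeros.
    errors = predictions if targets is None else predictions - targets
    return errors**2

a = np.array([1.0, 2.0, 4.0])
b = np.array([1.0, 0.0, 1.0])
se = squared_error_sketch(a, b)  # elementwise: [0., 4., 9.]
mse = se.mean()                  # mean squared error: 13 / 3
print(se, mse)
```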
optax.l2_loss(predictions, targets=None)
Calculates the L2 loss for a set of predictions.
Note: the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not in “The Elements of Statistical Learning” by Hastie, Tibshirani and Friedman.
References
[Chris Bishop, 2006](https://bit.ly/3eeP0ga)
Returns:
elementwise squared differences, with the same shape as predictions.
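The 0.5 term has a practical effect: it cancels the factor of 2 produced by differentiation. As a result, the gradient of the elementwise L2 loss with respect to a prediction is simply the residual predictions - targets. The NumPy sketch below checks this with finite differences. `l2_loss_sketch` is a hypothetical helper that follows the relation l2_loss = 0.5 * squared_error stated above.

```python
import numpy as np

def l2_loss_sketch(predictions, targets):
    # Hypothetical helper: 0.5 * squared error, elementwise.
    return 0.5 * (predictions - targets) ** 2

preds = np.array([0.5, -2.0, 3.0])
targs = np.array([1.0, 1.0, 1.0])
h = 1e-6
# Central finite differences of the elementwise loss w.r.t. each prediction.
grad = (l2_loss_sketch(preds + h, targs) - l2_loss_sketch(preds - h, targs)) / (2 * h)
print(grad)  # approximately preds - targs = [-0.5, -3.0, 2.0]
```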
Log hyperbolic cosine loss
optax.log_cosh(predictions, targets=None)
Calculates the log-cosh loss for a set of predictions.
log(cosh(x)) is approximately (x**2) / 2 for small x and abs(x) - log(2) for large x. It is a twice-differentiable alternative to the Huber loss.
References
[Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym)
Returns:
the log-cosh loss, with the same shape as predictions.
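Evaluating log(cosh(x)) directly overflows for large |x| because cosh does. A common stable rewrite is log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2). The NumPy sketch below uses that identity, which is an implementation choice made here and not necessarily optax's. It then checks both approximations quoted above. `log_cosh_sketch` is a hypothetical name.

```python
import numpy as np

def log_cosh_sketch(predictions, targets=None):
    errors = predictions if targets is None else predictions - targets
    a = np.abs(errors)
    # Sketch choice: log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2),
    # which never evaluates cosh itself and so cannot overflow.
    return a + np.log1p(np.exp(-2.0 * a)) - np.log(2.0)

small = np.array([1e-3, -1e-2])
large = np.array([50.0, -1000.0])  # np.cosh(1000.0) would overflow to inf
print(log_cosh_sketch(small), small**2 / 2)
print(log_cosh_sketch(large), np.abs(large) - np.log(2.0))
```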
Normalized temperature-scaled cross-entropy (NT-Xent) loss
optax.ntxent(embeddings, labels, temperature=0.07)
Normalized temperature-scaled cross-entropy loss (NT-Xent).
References
T. Chen et al., A Simple Framework for Contrastive Learning of Visual Representations, 2020; kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss
Parameters:
embeddings (Union[Array, ndarray, bool_, number]) – batch of embeddings, with shape [batch, feature_length].
labels (Union[Array, ndarray, bool_, number]) – labels for groups that are positive pairs, with shape [batch]. E.g. if you have a batch of 4 embeddings and the first two and the last two are positive pairs, your labels should look like [0, 0, 1, 1]. Labels SHOULD NOT all be the same (e.g. [0, 0, 0, 0]); otherwise you will get a NaN result.
temperature (Union[Array, ndarray, bool_, number, float, int]) – temperature scaling parameter.
Returns:
A scalar loss value of NT-Xent values averaged over all positive pairs.
Added in version 0.2.3.
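The sketch below spells out the loss for small batches with explicit loops. It assumes the common NT-Xent form. For each ordered positive pair (i, j) with j ≠ i, the term is -log(exp(s_ij/τ) / (exp(s_ij/τ) + Σ_k exp(s_ik/τ))). Here s is cosine similarity, τ is the temperature, and k ranges over embeddings whose label differs from i's. Terms are averaged over all positive pairs, as the return description above states. The exact denominator is an assumption, since other conventions exist when a group has more than two members. `ntxent_sketch` is a hypothetical name. The sketch does not reproduce the NaN case warned about above.

```python
import numpy as np

def ntxent_sketch(embeddings, labels, temperature=0.07):
    # Pairwise cosine similarities scaled by the temperature.
    unit = embeddings / np.linalg.norm(embeddings, axis=-1, keepdims=True)
    sims = unit @ unit.T / temperature
    labels = np.asarray(labels)
    terms = []
    for i in range(len(labels)):
        negatives = sims[i, labels != labels[i]]
        for j in range(len(labels)):
            if j == i or labels[j] != labels[i]:
                continue  # keep only ordered positive pairs (i, j), j != i
            # Sketch assumption: denominator = this positive + all negatives of i.
            logits = np.concatenate(([sims[i, j]], negatives))
            m = logits.max()  # shift for numerical stability
            terms.append(-(logits[0] - m - np.log(np.sum(np.exp(logits - m)))))
    return float(np.mean(terms))

emb = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
print(ntxent_sketch(emb, [0, 0, 1, 1], temperature=1.0))  # log(e + 2) - 1
```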
Sigmoid binary cross-entropy
optax.sigmoid_binary_cross_entropy(logits, labels)
Computes element-wise sigmoid cross entropy given logits and labels.
This function can be used for binary or multi-class classification, where each class is an independent binary prediction and different classes are not mutually exclusive (e.g. predicting that an image contains both a cat and a dog).
Because this function is overloaded, please ensure your logits and labels are compatible with each other. If you're passing in binary labels (values in {0, 1}), ensure your logits correspond to class 1 only. If you're passing in per-class target probabilities or one-hot labels, please ensure your logits are also multi-class. Be particularly careful if you're relying on implicit broadcasting to reshape logits or labels.
References
[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)
 Parameters:
logits â€“ Each element is the unnormalized log probability of a binary prediction. See note about compatibility with labels above.
labels â€“ Binary labels whose values are {0,1} or multiclass target probabilities. See note about compatibility with logits above.
 Returns:
cross entropy for each binary prediction, same shape as logits.
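For logit x and label z, the element-wise loss is -z·log σ(x) - (1 - z)·log(1 - σ(x)). The NumPy sketch below uses the algebraically equivalent form max(x, 0) - x·z + log1p(exp(-|x|)), which avoids taking log(0) for large |x|. That rearrangement is a standard trick chosen here, not a claim about optax's internals. `sigmoid_bce_sketch` is a hypothetical name. The example is a multi-label case in which logits and labels have the same per-class shape, as the compatibility note above requires.

```python
import numpy as np

def sigmoid_bce_sketch(logits, labels):
    # Stable rewrite of -z*log(sigmoid(x)) - (1 - z)*log(1 - sigmoid(x)).
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

# Multi-label example: one image, three independent classes (cat, dog, car).
logits = np.array([2.0, 1.0, -3.0])
labels = np.array([1.0, 1.0, 0.0])  # cat and dog present, no car
stable = sigmoid_bce_sketch(logits, labels)

sig = 1.0 / (1.0 + np.exp(-logits))
naive = -labels * np.log(sig) - (1 - labels) * np.log(1 - sig)
print(stable, naive)  # the two agree for moderate logits
```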
Sigmoid focal loss
optax.sigmoid_focal_loss(logits, labels, alpha=None, gamma=2.0)
Sigmoid focal loss.
The focal loss is a re-weighted cross entropy for unbalanced problems. Use this loss function if classes are not mutually exclusive. See sigmoid_binary_cross_entropy for more information.
References
[Lin et al, 2018](https://arxiv.org/pdf/1708.02002.pdf)
Parameters:
logits (Union[Array, ndarray, bool_, number]) – Array of floats. The predictions for each example.
labels (Union[Array, ndarray, bool_, number]) – Array of floats. Labels and logits must have the same shape. The label array must contain the binary classification labels for each element in the data set (0 for out-of-class and 1 for in-class).
alpha (Optional[float]) – (optional) Weighting factor in range (0, 1) to balance positive vs negative examples. Default None (no weighting).
gamma (float) – Exponent of the modulating factor (1 - p_t). Balances easy vs hard examples.
Returns:
A loss value array with a shape identical to the logits and target arrays.
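In Lin et al., the focal loss multiplies the binary cross entropy by the modulating factor (1 - p_t)^gamma, where p_t is the predicted probability of the true label. It can optionally also multiply by alpha_t, which equals alpha for positives and 1 - alpha for negatives. The NumPy sketch below writes that out. The name `sigmoid_focal_sketch` is hypothetical, and the placement of alpha follows the paper rather than optax's source. With gamma = 0 and alpha = None, it reduces to plain sigmoid binary cross entropy.

```python
import numpy as np

def sigmoid_focal_sketch(logits, labels, alpha=None, gamma=2.0):
    p = 1.0 / (1.0 + np.exp(-logits))
    # Binary cross entropy in a numerically stable form.
    ce = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    p_t = p * labels + (1 - p) * (1 - labels)  # probability of the true label
    loss = ce * (1 - p_t) ** gamma             # down-weights easy examples (p_t near 1)
    if alpha is not None:
        # Sketch assumption: alpha weights positives, 1 - alpha weights negatives.
        alpha_t = alpha * labels + (1 - alpha) * (1 - labels)
        loss = alpha_t * loss
    return loss

logits = np.array([4.0, 0.0, -4.0])
labels = np.array([1.0, 1.0, 1.0])  # easy, uncertain and hard positives
print(sigmoid_focal_sketch(logits, labels))             # gamma = 2
print(sigmoid_focal_sketch(logits, labels, gamma=0.0))  # plain BCE
```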
Smoothing labels
optax.smooth_labels(labels, alpha)
Apply label smoothing.
Label smoothing is often used in combination with a cross-entropy loss. Smoothed labels favour small logit gaps, and it has been shown that this can provide better model calibration by preventing overconfident predictions.
References
[Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf)
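Assuming the common uniform-mixing form, smoothing replaces a label vector y with (1 - alpha)·y + alpha / K, where K is the size of the last axis. The result is still a valid probability distribution, so it can serve as the labels argument of a cross-entropy loss. The NumPy sketch below uses the hypothetical name `smooth_labels_sketch`.

```python
import numpy as np

def smooth_labels_sketch(labels, alpha):
    num_categories = labels.shape[-1]
    # Sketch assumption: mix the labels with a uniform distribution over classes.
    return (1.0 - alpha) * labels + alpha / num_categories

one_hot = np.array([[0.0, 0.0, 1.0, 0.0]])
smoothed = smooth_labels_sketch(one_hot, alpha=0.2)
print(smoothed)  # target class gets 0.85, the others 0.05 each; rows sum to 1
```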
Softmax cross-entropy
optax.softmax_cross_entropy(logits, labels)
Computes the softmax cross entropy between sets of logits and labels.
Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.
References
[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)
Parameters:
logits (Union[Array, ndarray, bool_, number]) – Unnormalized log probabilities, with shape […, num_classes].
labels (Union[Array, ndarray, bool_, number]) – Valid probability distributions (non-negative, sum to 1), e.g. a one-hot encoding specifying the correct class for each input; must have a shape broadcastable to […, num_classes].
Returns:
cross entropy between each prediction and the corresponding target distributions, with shape […].
See also
optax.safe_softmax_cross_entropy()
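Written out, the loss is -Σ_c labels_c · log softmax(logits)_c over the class axis. The NumPy sketch below computes it with hypothetical helper names. The log-softmax subtracts the per-row maximum first, so very large logits do not overflow. With uniform logits over three classes, the loss equals log 3.

```python
import numpy as np

def log_softmax(logits):
    # Subtracting the max leaves the result unchanged but avoids overflow in exp.
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=-1, keepdims=True))

def softmax_ce_sketch(logits, labels):
    # Sum over classes of -p(c) * log q(c); the output has shape [...].
    return -np.sum(labels * log_softmax(logits), axis=-1)

logits = np.array([[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]])
labels = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # one-hot targets
print(softmax_ce_sketch(logits, labels))
```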
optax.softmax_cross_entropy_with_integer_labels(logits, labels)
Computes softmax cross entropy between sets of logits and integer labels.
Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.
References
[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)
Returns:
Cross entropy between each prediction and the corresponding target distributions, with shape […].
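With integer labels, the sum over classes collapses to a single term: logsumexp(logits) - logits[label]. The NumPy sketch below computes that, using the hypothetical name `softmax_ce_int_sketch`. It then checks the result against the equivalent one-hot computation.

```python
import numpy as np

def softmax_ce_int_sketch(logits, labels):
    # Stable logsumexp over classes: shift by the max, then add it back.
    m = np.max(logits, axis=-1, keepdims=True)
    lse = np.squeeze(m, -1) + np.log(np.sum(np.exp(logits - m), axis=-1))
    # Pick out the logit of the correct class for each example.
    picked = np.take_along_axis(logits, labels[..., None], axis=-1)[..., 0]
    return lse - picked

logits = np.array([[2.0, 1.0, 0.1], [0.3, 2.5, -1.0]])
labels = np.array([0, 1])                   # integer class ids, shape [...]
one_hot = np.eye(logits.shape[-1])[labels]  # equivalent one-hot targets
log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
print(softmax_ce_int_sketch(logits, labels))
print(-np.sum(one_hot * log_probs, axis=-1))  # same values
```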