optax.losses.safe_softmax_cross_entropy
- optax.losses.safe_softmax_cross_entropy(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike) → Array
Computes the softmax cross entropy between sets of logits and labels.
Unlike optax.softmax_cross_entropy(), this function treats labels*logsoftmax(logits) as 0 when logits=-inf and labels=0, following the convention that 0 log 0 = 0.
- Parameters:
logits – Unnormalized log probabilities, with shape […, num_classes].
labels – Valid probability distributions (non-negative, summing to 1), e.g. a one-hot encoding specifying the correct class for each input; must have a shape broadcastable to […, num_classes].
- Returns:
Cross entropy between each prediction and the corresponding target distribution, with shape […].
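To illustrate the 0 log 0 = 0 convention, here is a minimal sketch in plain JAX. It is not the optax source: the names safe_xent_sketch and naive_xent are hypothetical, and it assumes the loss is -sum(labels * log_softmax(logits)) over the last axis, with terms whose label is 0 forced to 0. The naive version shows the 0 * (-inf) = nan case that the masking avoids.

```python
import jax
import jax.numpy as jnp


def naive_xent(logits, labels):
    # Plain formula: where a logit is -inf and its label is 0,
    # the product 0 * (-inf) evaluates to nan and poisons the sum.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(labels * log_probs, axis=-1)


def safe_xent_sketch(logits, labels):
    # Hypothetical sketch (not optax's implementation) of the
    # 0 log 0 = 0 convention: zero-label terms are replaced by 0
    # before summing, so -inf logits with label 0 contribute nothing.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    terms = jnp.where(labels == 0, 0.0, labels * log_probs)
    return -jnp.sum(terms, axis=-1)


# One example, three classes; class 1 is masked out with a -inf logit.
logits = jnp.array([[0.0, -jnp.inf, 1.0]])
labels = jnp.array([[0.0, 0.0, 1.0]])  # one-hot target: class 2

print(naive_xent(logits, labels))        # nan
print(safe_xent_sketch(logits, labels))  # finite, about 0.3133
```

On inputs without -inf logits the two functions agree; they differ only where a zero label meets a -inf logit. Per the description above, that is the case in which optax.losses.safe_softmax_cross_entropy() returns 0 for the offending term while optax.softmax_cross_entropy() does not.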