optax.losses.safe_softmax_cross_entropy

optax.losses.safe_softmax_cross_entropy(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike) → Array

Computes the softmax cross entropy between sets of logits and labels.

Unlike optax.softmax_cross_entropy(), this function treats the term labels * log_softmax(logits) as 0 when logits = -inf and labels = 0, following the convention that 0 log 0 = 0. This avoids the NaN that the product 0 * (-inf) would otherwise produce.
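The convention can be illustrated with plain JAX. The snippet below is a minimal sketch, not necessarily how the library implements it: it compares a naive sum of labels * log_softmax(logits) with a version that masks entries where the label is 0. The variable names (naive, masked) are illustrative only.

```python
import jax
import jax.numpy as jnp

# One example with three classes; the second class is ruled out (logit = -inf).
logits = jnp.array([2.0, -jnp.inf, 0.5])
labels = jnp.array([1.0, 0.0, 0.0])

log_probs = jax.nn.log_softmax(logits, axis=-1)  # contains -inf at index 1

# Naive cross entropy: 0 * (-inf) evaluates to NaN and poisons the sum.
naive = -jnp.sum(labels * log_probs, axis=-1)

# Masked cross entropy (assumed sketch of the 0 log 0 = 0 convention):
# terms with a zero label are replaced by 0 before summing.
masked = -jnp.sum(jnp.where(labels == 0, 0.0, labels * log_probs), axis=-1)

print(naive)   # NaN
print(masked)  # finite
```

Masking on labels == 0 rather than on the logits means that a non-zero label paired with a -inf logit still yields an infinite loss, which is the mathematically expected value.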

Parameters:
  • logits – Unnormalized log probabilities, with shape […, num_classes].

  • labels – Valid probability distributions (non-negative, sum to 1), e.g. a one-hot encoding specifying the correct class for each input; must have a shape broadcastable to […, num_classes].

Returns:

Cross entropy between each prediction and the corresponding target distribution, with shape […].