optax.losses.softmax_cross_entropy_with_integer_labels

optax.losses.softmax_cross_entropy_with_integer_labels(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, axis: int | tuple[int, ...] = -1, where: jax.typing.ArrayLike | None = None) -> Array

Computes softmax cross entropy between the logits and integer labels.

This loss is useful for classification problems where the labels are integer class indices rather than one-hot vectors. It is also known as categorical cross entropy.

Let \(x\) denote the logits array of size [batch_size, num_classes] and \(y\) denote the labels array of size [batch_size]. Then this function returns a vector \(\sigma\) of size [batch_size] defined as:

\[\sigma_i = -\log\left(\frac{\exp(x_{i y_i})}{\sum_j \exp(x_{i j})}\right)\,. \]

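The definition can be checked numerically: the loss for row \(i\) is the negative log-softmax of that row, read off at the label index \(y_i\). The sketch below mirrors the math using only JAX; the helper name `manual_cross_entropy` is hypothetical and this is not necessarily how optax implements the function internally.

```python
import jax
import jax.numpy as jnp


def manual_cross_entropy(logits, labels):
    """Hypothetical helper: -log softmax(logits)[i, labels[i]] for each row i."""
    # Log-probabilities over the class axis, shape [batch_size, num_classes].
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick the log-probability of the labelled class in each row.
    picked = jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]
    return -picked


logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
labels = jnp.array([0, 1])
loss = manual_cross_entropy(logits, labels)
print(loss)
```

With these inputs the result should agree with the first entry under Examples, up to floating-point rounding.
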
Parameters:
  • logits – Unnormalized log probabilities, with shape [batch_size, num_classes].

  • labels – Integers specifying the correct class for each input, with shape [batch_size]. Class labels are assumed to be between 0 and num_classes - 1 inclusive.

  • axis – Axis or axes along which to compute. If a tuple of axes is passed then num_classes must match the total number of elements in axis dimensions and a label is interpreted as a flat index in a logits slice of shape logits[axis].

  • where – Elements to include in the computation, with shape [batch_size] or logits.shape.

Returns:

Cross-entropy between each prediction and the corresponding target distributions, with shape [batch_size].

Examples

>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([0, 1])
>>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
[0.2761 2.9518]

>>> import numpy as np
>>> # example: batch_size = (1, 2), num_classes = 12 (i.e. 3 * 4)
>>> shape = (1, 2, 3, 4)
>>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
>>> # elements indices in slice of shape (3, 4)
>>> ix = jnp.array([[1, 2]])
>>> jx = jnp.array([[1, 3]])
>>> labels = jnp.ravel_multi_index((ix, jx), shape[2:])
>>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
...     logits, labels, axis=(2, 3))
>>> print(cross_entropy)
[[6.4587 0.4587]]
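To make the flat-index interpretation in the tuple-axis example concrete, the following sketch reproduces the same numbers by hand with plain JAX: the class axes are flattened into one axis of length 12, and each label indexes into that flattened axis. This only illustrates the semantics described for axis; the variable names are illustrative and the steps are an assumption about an equivalent computation, not a copy of the optax source.

```python
import jax
import jax.numpy as jnp

shape = (1, 2, 3, 4)
logits = jnp.arange(24, dtype=jnp.float32).reshape(shape)

# Positions inside each (3, 4) class slice, converted to flat indices.
ix = jnp.array([[1, 2]])
jx = jnp.array([[1, 3]])
labels = jnp.ravel_multi_index((ix, jx), shape[2:])  # shape (1, 2)

# Flatten the two class axes into a single axis of 3 * 4 = 12 classes.
flat_logits = logits.reshape(shape[:2] + (-1,))  # shape (1, 2, 12)

# Loss = log-sum-exp over classes minus the logit at the labelled class.
log_normalizers = jax.nn.logsumexp(flat_logits, axis=-1)
label_logits = jnp.take_along_axis(
    flat_logits, labels[..., None], axis=-1)[..., 0]
manual = log_normalizers - label_logits
print(manual)
```
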

References

Cross-entropy Loss, Wikipedia

Multinomial Logistic Regression, Wikipedia

See also

This function is similar to optax.losses.softmax_cross_entropy(), but accepts integer labels instead of one-hot labels.

Changed in version 0.2.4: Added axis and where arguments.