optax.losses.softmax_cross_entropy_with_integer_labels
- optax.losses.softmax_cross_entropy_with_integer_labels(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, axis: int | tuple[int, ...] = -1, where: jax.typing.ArrayLike | None = None) -> Array
Computes softmax cross entropy between the logits and integer labels.
This loss is useful for classification problems with integer labels that are not one-hot encoded. This loss is also known as categorical cross entropy.
Let \(x\) denote the logits array of size [batch_size, num_classes] and \(y\) denote the labels array of size [batch_size]. Then this function returns a vector \(\sigma\) of size [batch_size], whose entries are the negative log-probabilities that the softmax of each row of logits assigns to the correct class:
\[\sigma_i = -\log\left(\frac{\exp(x_{i y_i})}{\sum_j \exp(x_{i j})}\right)\,.\]
- Parameters:
logits – Unnormalized log probabilities, with shape [batch_size, num_classes].
labels – Integers specifying the correct class for each input, with shape [batch_size]. Class labels are assumed to be between 0 and num_classes - 1 inclusive.
axis – Axis or axes along which to compute. If a tuple of axes is passed, then num_classes must equal the total number of elements in the axis dimensions, and each label is interpreted as a flat index into the logits slice spanned by those dimensions.
where – Elements to include in the computation, with shape [batch_size] or logits.shape.
- Returns:
Cross-entropy between each prediction and the corresponding target distribution, with shape [batch_size].
Examples
>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([0, 1])
>>> print(optax.softmax_cross_entropy_with_integer_labels(logits, labels))
[0.2761 2.9518]
>>> import jax.numpy as jnp
>>> import numpy as np
>>> import optax
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = (1, 2), num_classes = 12 (i.e. 3 * 4)
>>> shape = (1, 2, 3, 4)
>>> logits = jnp.arange(np.prod(shape), dtype=jnp.float32).reshape(shape)
>>> # element indices in the slice of shape (3, 4)
>>> ix = jnp.array([[1, 2]])
>>> jx = jnp.array([[1, 3]])
>>> labels = jnp.ravel_multi_index((ix, jx), shape[2:])
>>> cross_entropy = optax.softmax_cross_entropy_with_integer_labels(
...     logits, labels, axis=(2, 3))
>>> print(cross_entropy)
[[6.4587 0.4587]]
See also
This function is similar to optax.losses.softmax_cross_entropy(), but accepts integer labels instead of one-hot labels.
Changed in version 0.2.4: Added the axis and where arguments.