optax.losses.softmax_cross_entropy
- optax.losses.softmax_cross_entropy(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) -> Array
Computes the softmax cross entropy between sets of logits and labels.
This loss function is commonly used for multi-class classification tasks. It measures the dissimilarity between the predicted probability distribution (obtained by applying the softmax function to the logits) and the true probability distribution (represented by the one-hot encoded labels). This loss is also known as categorical cross entropy.
Let \(x\) denote the logits array of size [batch_size, num_classes] and \(y\) denote the labels array of size [batch_size, num_classes]. Then this function returns a vector \(\sigma\) of size [batch_size] defined as:
\[\sigma_i = - \sum_j y_{i j} \log\left(\frac{\exp(x_{i j})}{\sum_k \exp(x_{i k})}\right) \,. \]
- Parameters:
logits – Unnormalized log probabilities, with shape [batch_size, num_classes].
labels – One-hot encoded labels, with shape [batch_size, num_classes]. Each row represents the true class distribution for a single example.
axis – Axis or axes along which to compute.
where – Elements to include in the computation, with shape [batch_size] or logits.shape.
- Returns:
Cross-entropy between each prediction and the corresponding target distributions, with shape [batch_size].
Examples
>>> import optax
>>> import jax.numpy as jnp
>>> jnp.set_printoptions(precision=4)
>>> # example: batch_size = 2, num_classes = 3
>>> logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
>>> labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
>>> print(optax.softmax_cross_entropy(logits, labels))
[0.2761 2.9518]
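To relate the example output to the formula for \(\sigma_i\), the same values can be reproduced by hand. This is a sketch of the definition using jax.nn.log_softmax, not necessarily how optax implements the function internally; the variable names are illustrative.

```python
import jax
import jax.numpy as jnp

logits = jnp.array([[1.2, -0.8, -0.5], [0.9, -1.2, 1.1]])
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

# log(exp(x_ij) / sum_k exp(x_ik)), computed stably along the class axis.
log_probs = jax.nn.log_softmax(logits, axis=-1)

# sigma_i = -sum_j y_ij * log_probs_ij, one value per example.
manual = -jnp.sum(labels * log_probs, axis=-1)
print(manual)
```

With one-hot labels, each entry reduces to the negative log probability of the true class, which is why the second example (true class has the lowest logit) has the larger loss.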
References
Cross-entropy Loss, Wikipedia
Multinomial Logistic Regression, Wikipedia
See also
This function is similar to optax.losses.softmax_cross_entropy_with_integer_labels(), but accepts one-hot labels instead of integer labels.
optax.losses.safe_softmax_cross_entropy() provides an alternative implementation that differs in how logits=-inf is handled.
Changed in version 0.2.4: Added axis and where arguments.