optax.losses.ranking_softmax_loss


optax.losses.ranking_softmax_loss(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, *, where: jax.typing.ArrayLike | None = None, weights: jax.typing.ArrayLike | None = None, reduce_fn: Callable[..., jax.typing.ArrayLike] | None = jax.numpy.mean) -> Array

Ranking softmax loss.

Definition:

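\[\ell(s, y) = -\sum_i y_i \log \frac{\exp(s_i)}{\sum_j \exp(s_j)} \]

where s_i is the score (logit) of item i, y_i is its relevance label, and both sums run over the items of one list.
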
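To make the definition concrete, the following is a plain-Python transcription of the formula for a single list. It is an illustrative sketch, not the optax implementation: the helper name ranking_softmax_loss_single_list is hypothetical, and it ignores where, weights, and batching.

```python
import math


def ranking_softmax_loss_single_list(scores, labels):
    """Evaluates -sum_i y_i * log(softmax(s)_i) for one list of items.

    Hypothetical helper for illustration only; optax computes the loss
    with JAX arrays and also supports masking, weights, and batching.
    """
    # Stable log-sum-exp: subtract the largest score before exponentiating.
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    # log softmax(s)_i = s_i - log(sum_j exp(s_j))
    return -sum(y * (s - log_z) for s, y in zip(scores, labels))


# Three items; only the first one is relevant.
loss = ranking_softmax_loss_single_list([2.0, 1.0, 0.0], [1.0, 0.0, 0.0])
print(round(loss, 4))  # 0.4076
```

Items with a label of 0 contribute nothing to the sum directly, but their scores still enter the normalizer, so raising the score of an irrelevant item increases the loss.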
Parameters:
  • logits – A [..., list_size]-Array, indicating the score of each item.

  • labels – A [..., list_size]-Array, indicating the relevance label for each item.

  • where – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False will be ignored when computing the loss.

  • weights – An optional [..., list_size]-Array, indicating the weight for each item.

  • reduce_fn – An optional function that reduces the loss values, for example jax.numpy.sum() or jax.numpy.mean() (the default). If None, no reduction is performed.

Returns:

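The ranking softmax loss, reduced with reduce_fn. If reduce_fn is None, the unreduced per-list losses are returned as an array of shape [...].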