optax.losses.ranking_softmax_loss
- optax.losses.ranking_softmax_loss(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, *, where: jax.typing.ArrayLike | None = None, weights: jax.typing.ArrayLike | None = None, reduce_fn: collections.abc.Callable[..., jax.typing.ArrayLike] | None = jax.numpy.mean) -> Array
Ranking softmax loss.
Definition:

\[\ell(s, y) = -\sum_i y_i \log \frac{\exp(s_i)}{\sum_j \exp(s_j)}\]

where s denotes the logits, y denotes the labels, and i and j index the items of a list.

- Parameters:
  - logits – A [..., list_size]-Array, indicating the score of each item.
  - labels – A [..., list_size]-Array, indicating the relevance label for each item.
  - where – An optional [..., list_size]-Array, indicating which items are valid for computing the loss. Items for which this is False are ignored when computing the loss.
  - weights – An optional [..., list_size]-Array, indicating the weight for each item.
  - reduce_fn – An optional function that reduces the loss values, for example jax.numpy.sum() or jax.numpy.mean() (the default). If None, no reduction is performed.
- Returns:
  The ranking softmax loss, reduced by reduce_fn. If reduce_fn is None, one loss value per list is returned.
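To make the definition concrete, here is a minimal from-scratch sketch in JAX. It is not optax's implementation: the name softmax_loss_sketch is hypothetical, and the way it applies where (zeroing the label and sending the logit to -inf) and weights (multiplying each label y_i) is an assumption about how such options are commonly handled. It follows the formula as written above; the library may differ in details.

```python
import jax.numpy as jnp
from jax.nn import log_softmax


def softmax_loss_sketch(logits, labels, where=None, weights=None,
                        reduce_fn=jnp.mean):
    """Hypothetical sketch of l(s, y) = -sum_i y_i * log softmax(s)_i.

    Not optax's source: the masking and weighting rules below are assumptions.
    """
    logits = jnp.asarray(logits, dtype=jnp.float32)
    labels = jnp.asarray(labels, dtype=logits.dtype)
    if where is not None:
        # Assumption: a masked item gets a zero label and a -inf score, so it
        # contributes neither to the sum over i nor to the normalizer over j.
        where = jnp.asarray(where, dtype=bool)
        labels = jnp.where(where, labels, 0.0)
        logits = jnp.where(where, logits, -jnp.inf)
    if weights is not None:
        # Assumption: per-item weights scale each item's label y_i.
        labels = labels * jnp.asarray(weights, dtype=logits.dtype)
    # Log softmax over the list_size axis: s_i - log sum_j exp(s_j).
    log_probs = log_softmax(logits, axis=-1)
    # Drop zero-label terms explicitly so 0 * (-inf) cannot produce NaN.
    per_item = jnp.where(labels != 0, labels * log_probs, 0.0)
    # One loss per list, with shape [...].
    loss = -jnp.sum(per_item, axis=-1)
    return loss if reduce_fn is None else reduce_fn(loss)
```

For example, softmax_loss_sketch([[1.0, 2.0, 3.0]], [[0.0, 0.0, 1.0]]) evaluates -log(e^3 / (e^1 + e^2 + e^3)), about 0.4076. Passing reduce_fn=None keeps one loss per list instead of averaging over the batch dimensions.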