optax.losses.multiclass_sparsemax_loss


optax.losses.multiclass_sparsemax_loss(scores: jax.typing.ArrayLike, labels: jax.typing.ArrayLike) → Array

Multiclass sparsemax loss between model scores and integer class labels.
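For reference, the loss as defined by Martins and Astudillo (2016) is restated below. The notation here is chosen for illustration: z ∈ R^K is the score vector of one example, y its integer label, z_(1) ≥ … ≥ z_(K) the sorted scores, and p = sparsemax(z). The second line is the equivalent Fenchel-Young form, obtained by substituting p_j = z_j − τ(z) on the support S(z).

```latex
\begin{aligned}
p &= \mathrm{sparsemax}(z) = \big[z - \tau(z)\,\mathbf{1}\big]_+, \qquad
S(z) = \{\, j : p_j > 0 \,\}, \\
\tau(z) &= \frac{\sum_{j \le \rho} z_{(j)} - 1}{\rho}, \qquad
\rho = \max\Big\{\, r \in \{1,\dots,K\} : 1 + r\, z_{(r)} > \sum_{j \le r} z_{(j)} \Big\}, \\
L(z; y) &= -z_y + \frac{1}{2} \sum_{j \in S(z)} \big(z_j^2 - \tau(z)^2\big) + \frac{1}{2}
         = \langle z, p \rangle - \frac{1}{2}\lVert p \rVert^2 + \frac{1}{2} - z_y .
\end{aligned}
```

The loss is nonnegative and equals zero exactly when p is the one-hot vector for y, that is, when z_y ≥ z_j + 1 for every j ≠ y.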

Parameters:
  • scores – scores produced by the model.

  • labels – ground-truth integer labels.

Returns:

Loss values, one per example.
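A minimal NumPy sketch of the same computation, written from the paper's definition rather than taken from optax. It assumes scores of shape (batch, num_classes) and labels of shape (batch,); the helpers sparsemax, sparsemax_loss_single and multiclass_sparsemax_loss_sketch are hypothetical names introduced only for illustration. It uses boolean-mask indexing and a Python loop for readability, so it is not a traceable JAX function.

```python
import numpy as np
import numpy.typing as npt


def sparsemax(z: np.ndarray) -> np.ndarray:
    """Euclidean projection of a 1-D score vector onto the probability simplex."""
    z_sorted = np.sort(z)[::-1]            # scores in decreasing order
    r = np.arange(1, z.size + 1)           # candidate support sizes 1..K
    cumsum = np.cumsum(z_sorted)
    in_support = 1 + r * z_sorted > cumsum # True exactly for r = 1..rho
    rho = r[in_support][-1]                # size of the support
    tau = (cumsum[rho - 1] - 1.0) / rho    # threshold tau(z)
    return np.maximum(z - tau, 0.0)


def sparsemax_loss_single(z: np.ndarray, label: int) -> float:
    """Loss for one example: <z, p> - ||p||^2 / 2 + 1/2 - z[label]."""
    p = sparsemax(z)
    return float(np.dot(z, p) - 0.5 * np.dot(p, p) + 0.5 - z[label])


def multiclass_sparsemax_loss_sketch(
    scores: npt.ArrayLike, labels: npt.ArrayLike
) -> np.ndarray:
    """Per-example losses (hypothetical helper, assumes 2-D scores and 1-D labels)."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels)
    return np.array(
        [sparsemax_loss_single(z, int(y)) for z, y in zip(scores, labels)]
    )


# Row 0: the true class leads by a margin of 2 >= 1, so its loss is zero.
# Row 1: sparsemax returns the scores unchanged (tau = 0), giving a positive loss.
scores = np.array([[3.0, 1.0, 0.2], [0.5, 0.4, 0.1]])
labels = np.array([0, 2])
print(multiclass_sparsemax_loss_sketch(scores, labels))  # ≈ [0.0, 0.61]
```

With optax installed, the library function is called with the same two arguments, optax.losses.multiclass_sparsemax_loss(scores, labels); the sketch above is only meant to make the definition concrete.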

References

Martins and Astudillo, From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification <https://arxiv.org/abs/1602.02068>, 2016.