optax.losses.multiclass_sparsemax_loss
- optax.losses.multiclass_sparsemax_loss(scores: jax.typing.ArrayLike, labels: jax.typing.ArrayLike) -> Array
Multiclass sparsemax loss, the sparsemax counterpart of the softmax cross-entropy (logistic) loss.
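For reference, Martins et al. (2016) define the sparsemax loss of a score vector z with ground-truth label k as below, where τ(z) is the threshold such that sparsemax(z)_j = max(z_j − τ(z), 0) sums to one, and S(z) is the set of classes with nonzero sparsemax probability. This restates the paper's definition; it is not a description of how optax computes it internally.

```latex
L_{\mathrm{sparsemax}}(z; k) = -z_k + \frac{1}{2} \sum_{j \in S(z)} \left( z_j^2 - \tau(z)^2 \right) + \frac{1}{2},
\qquad
\operatorname{sparsemax}(z)_j = \max\bigl(z_j - \tau(z),\, 0\bigr)
```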
- Parameters:
scores – unnormalized scores (logits) produced by the model.
labels – ground-truth integer class labels.
- Returns:
loss values.
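The sketch below spells out the computation from Martins et al. (2016) in plain JAX, without calling optax. The helper names `sparsemax`, `sparsemax_loss_single` and `batched_sparsemax_loss` are hypothetical, and the shapes are assumptions: scores of shape (batch, num_classes) and integer labels of shape (batch,), giving one loss value per example. It uses the equivalent Fenchel-Young form 1/2 + ⟨p, z − p/2⟩ − z_k with p = sparsemax(z); it is written for readability and is not optax's own implementation.

```python
import jax
import jax.numpy as jnp


def sparsemax(z: jax.Array) -> jax.Array:
    """Sort-based sparsemax of a 1-D score vector (projection onto the simplex)."""
    z_sorted = jnp.sort(z)[::-1]          # scores in decreasing order
    k = jnp.arange(1, z.shape[-1] + 1)    # candidate support sizes 1..K
    cumsum = jnp.cumsum(z_sorted)
    # The k-th largest score is in the support iff 1 + k * z_(k) > sum_{j<=k} z_(j).
    in_support = 1 + k * z_sorted > cumsum
    k_z = jnp.sum(in_support)             # support size |S(z)|
    tau = (cumsum[k_z - 1] - 1) / k_z     # threshold tau(z)
    return jnp.maximum(z - tau, 0.0)


def sparsemax_loss_single(scores: jax.Array, label: jax.Array) -> jax.Array:
    """Sparsemax loss for one example: 1/2 + <p, z - p/2> - z_label, p = sparsemax(z)."""
    p = sparsemax(scores)
    return 0.5 + jnp.dot(p, scores - 0.5 * p) - scores[label]


# Map over the leading batch axis: (batch, K) scores, (batch,) labels -> (batch,) losses.
batched_sparsemax_loss = jax.vmap(sparsemax_loss_single)

scores = jnp.array([[1.0, 2.0, 3.0],
                    [1.0, 2.0, 3.0],
                    [0.0, 0.0, 0.0]])
labels = jnp.array([2, 0, 0])
losses = batched_sparsemax_loss(scores, labels)
print(losses)  # approximately [0.0, 2.0, 0.333]
```

The first row shows the margin property of the loss: when the correct class wins by a margin of at least 1, sparsemax puts all mass on it and the loss is exactly zero, unlike softmax cross-entropy, which stays strictly positive.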
References
Martins et al., From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification <https://arxiv.org/abs/1602.02068>, 2016.