optax.losses.sparsemax_loss
- optax.losses.sparsemax_loss(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike) -> Array
Binary sparsemax loss.
This loss is zero if and only if jax.nn.sparse_sigmoid(logits) == labels.
- Parameters:
logits – score produced by the model (float).
labels – ground-truth integer label (0 or 1).
- Returns:
loss value
References
Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins, Vlad Niculae. JMLR 2020. (Sec. 4.4)
Added in version 0.2.3.