optax.losses.sparsemax_loss

optax.losses.sparsemax_loss(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike) → Array

Binary sparsemax loss.

This loss is zero if and only if jax.nn.sparse_sigmoid(logits) == labels.
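The zero condition is easiest to see from a closed form. The sketch below is a from-scratch reference in `jax.numpy`. It assumes the Fenchel-Young construction with the squared-norm regularizer from the cited paper and is not necessarily how Optax implements the function. The helpers `sparse_sigmoid_ref`, `sparse_plus_ref`, and `sparsemax_loss_ref` are hypothetical and are not part of Optax or JAX. The sketch checks that the loss vanishes exactly where the sparse sigmoid of the logit reproduces the label.

```python
import jax.numpy as jnp


def sparse_sigmoid_ref(x):
    # Hypothetical helper: piecewise-linear sigmoid that is
    # 0 for x <= -1, (x + 1) / 2 on (-1, 1), and 1 for x >= 1.
    return jnp.clip((x + 1.0) / 2.0, 0.0, 1.0)


def sparse_plus_ref(x):
    # Hypothetical helper: sparse counterpart of softplus that is
    # 0 for x <= -1, (x + 1)^2 / 4 on (-1, 1), and x for x >= 1.
    return jnp.where(x <= -1.0, 0.0, jnp.where(x >= 1.0, x, (x + 1.0) ** 2 / 4.0))


def sparsemax_loss_ref(logits, labels):
    # Hypothetical reference loss: flip the sign of the logit for positive
    # labels, so a score that is at least 1 on the correct side costs nothing.
    logits = jnp.asarray(logits, dtype=jnp.float32)
    labels = jnp.asarray(labels)
    return sparse_plus_ref(jnp.where(labels == 1, -logits, logits))


logits = jnp.array([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
for label in (0, 1):
    loss = sparsemax_loss_ref(logits, label)
    # Where the sparse sigmoid reproduces the label, the loss is zero.
    matches = sparse_sigmoid_ref(logits) == label
    print(label, loss, matches)
```

Inside the interval (-1, 1) the loss is quadratic. Outside it, the loss is either zero or grows linearly with the logit.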

Parameters:
  • logits – score produced by the model (float).

  • labels – ground-truth integer label (0 or 1).

Returns:

loss value

References

Learning with Fenchel-Young Losses. Mathieu Blondel, André F. T. Martins, Vlad Niculae. JMLR 2020. (Sec. 4.4)

Added in version 0.2.3.