optax.losses.smooth_labels

optax.losses.smooth_labels(labels: jax.typing.ArrayLike, alpha: jax.typing.ArrayLike, *, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) -> Array

Apply label smoothing.

Label smoothing is often used in combination with a cross-entropy loss. Smoothed labels favor small logit gaps, and it has been shown that this can provide better model calibration by preventing overconfident predictions.

Parameters:
  • labels – One-hot labels to be smoothed.

  • alpha – The smoothing factor.

  • axis – Axis or axes along which to compute.

  • where – Elements to include in the computation.

Returns:

A smoothed version of the one-hot input labels.
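
Example

The following is a minimal sketch of the usual uniform label-smoothing formula, written in plain JAX rather than by calling this function. The helper name `smooth_labels_sketch` is hypothetical, and the formula `(1 - alpha) * labels + alpha / num_classes` over the last axis is an assumption about the default behaviour (`axis=-1`, no `where` mask), not a statement of the library's exact implementation.

```python
import jax
import jax.numpy as jnp


def smooth_labels_sketch(labels, alpha):
    """Uniform label smoothing over the last axis (hypothetical helper).

    Assumption: mirrors the common formulation
    (1 - alpha) * labels + alpha / num_classes.
    """
    labels = jnp.asarray(labels, dtype=jnp.float32)
    num_classes = labels.shape[-1]
    return (1.0 - alpha) * labels + alpha / num_classes


# Two examples, four classes, true classes 0 and 2.
one_hot = jax.nn.one_hot(jnp.array([0, 2]), num_classes=4)
smoothed = smooth_labels_sketch(one_hot, alpha=0.1)
# With alpha=0.1 and 4 classes, the true class gets 0.925 and
# every other class gets 0.025, so each row still sums to 1.

# Cross-entropy against the smoothed targets.
logits = jnp.array([[2.0, 0.5, -1.0, 0.0],
                    [0.1, 0.2, 3.0, -0.5]])
log_probs = jax.nn.log_softmax(logits, axis=-1)
loss = -jnp.sum(smoothed * log_probs, axis=-1)  # one loss value per example
```

Because every class keeps a small amount of target mass, the loss penalizes driving any logit arbitrarily far below the others, which is the "small logit gaps" effect described above. If optax is installed, the sketch can be compared against `optax.losses.smooth_labels(one_hot, 0.1)`.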

References

Müller et al., When does label smoothing help?, 2019