optax.losses.sigmoid_focal_loss


optax.losses.sigmoid_focal_loss(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, *, alpha: jax.typing.ArrayLike | None = None, gamma: jax.typing.ArrayLike = 2.0) → Array

Sigmoid focal loss with numerical stability improvements.

The focal loss is a dynamically scaled cross-entropy loss, where the scaling factor decays to zero as confidence in the correct class increases. This addresses class imbalance by down-weighting easy (well-classified) examples so that training focuses on hard examples.
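As a rough illustration of the down-weighting (with alpha=None, so \(\alpha_t\) plays no role), the ratio of the focal loss to plain cross-entropy for a single example is just the modulating factor \((1-p_t)^\gamma\). The helper name below is hypothetical and only serves this illustration.

```python
def modulating_factor(p_t: float, gamma: float = 2.0) -> float:
    """Ratio FL / CE = (1 - p_t) ** gamma when alpha is None (hypothetical helper)."""
    return (1.0 - p_t) ** gamma


# Evaluate the factor for a hard, a fairly easy, and a very easy example.
for p_t in (0.5, 0.9, 0.968):
    factor = modulating_factor(p_t)
    print(f"p_t = {p_t}: focal loss is {factor:.4g} x cross-entropy")
```

With gamma = 2, an example already predicted at p_t = 0.9 keeps only 0.01 of its cross-entropy, while one at p_t = 0.5 keeps a quarter of it.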

This implementation uses log-space computation for the focal weight \((1-p_t)^\gamma\) to ensure numerical stability, especially for \(\gamma < 2\) and extreme logit values.
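The sketch below contrasts a direct evaluation of \((1-p_t)^\gamma\) with a log-space one for binary labels. It relies on the identity \(1 - \sigma(x) = \sigma(-x)\), so \(\log(1-p_t)\) is log_sigmoid(-x) when \(y = 1\) and log_sigmoid(x) when \(y = 0\). Both function names are hypothetical and this is not the optax source; it only illustrates why a log-space route keeps the gradient finite for a confident prediction when \(\gamma = 0.5\).

```python
import jax
import jax.numpy as jnp


def naive_focal_weight(logits, labels, gamma):
    """(1 - p_t) ** gamma evaluated directly (hypothetical helper)."""
    p = jax.nn.sigmoid(logits)
    p_t = labels * p + (1.0 - labels) * (1.0 - p)
    # For large |logits|, 1 - p_t rounds to exactly 0 in float32.
    return (1.0 - p_t) ** gamma


def log_space_focal_weight(logits, labels, gamma):
    """(1 - p_t) ** gamma via log_sigmoid, for binary labels (hypothetical helper)."""
    # 1 - p_t = sigmoid(-x) if y == 1, and sigmoid(x) if y == 0.
    log_one_minus_p_t = (labels * jax.nn.log_sigmoid(-logits)
                         + (1.0 - labels) * jax.nn.log_sigmoid(logits))
    return jnp.exp(gamma * log_one_minus_p_t)


# A very confident, correct prediction for a positive label.
x = jnp.float32(30.0)
naive_grad = jax.grad(lambda z: naive_focal_weight(z, 1.0, 0.5))(x)
stable_grad = jax.grad(lambda z: log_space_focal_weight(z, 1.0, 0.5))(x)
# With gamma < 1 the slope of z ** gamma is infinite at z = 0, so naive_grad may be nan.
print("naive:", naive_grad, "log-space:", stable_grad)
```

For moderate logits the two helpers agree; they only diverge once 1 - p_t underflows.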

The loss is defined as:

\[FL(p_t) = -\alpha_t (1-p_t)^\gamma \log(p_t) \]

where \(p_t\) is the predicted probability of the correct class:

\[p_t = \begin{cases} p & \text{if } y = 1 \\ 1-p & \text{if } y = 0 \end{cases} \]

and \(\alpha_t\) is the weighting factor:

\[\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1-\alpha & \text{if } y = 0 \end{cases} \]
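The three definitions above translate almost line for line into code. The following sketch (focal_loss_from_definition is a hypothetical name, not part of optax) evaluates them directly for binary labels. It is meant as a readable reference only, since it skips the log-space safeguards described earlier.

```python
import jax
import jax.numpy as jnp


def focal_loss_from_definition(logits, labels, alpha=None, gamma=2.0):
    """Literal transcription of FL(p_t) = -alpha_t (1 - p_t)^gamma log(p_t).

    Intended for binary labels and moderate logits: it returns inf once
    p_t rounds to 0, because log(p_t) becomes -inf.
    """
    p = jax.nn.sigmoid(logits)
    # p_t: probability assigned to the correct class.
    p_t = jnp.where(labels == 1, p, 1 - p)
    loss = -((1 - p_t) ** gamma) * jnp.log(p_t)
    if alpha is not None:
        # alpha_t: alpha for positives, 1 - alpha for negatives.
        alpha_t = jnp.where(labels == 1, alpha, 1 - alpha)
        loss = alpha_t * loss
    return loss
```

With gamma=0.0 and alpha=None the modulating factor is 1, and the expression reduces to the ordinary sigmoid binary cross-entropy -log(p_t).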

Parameters:
  • logits – Array of unnormalized log probabilities, with shape […]. The predictions for each example.

  • labels – Array of labels with shape broadcastable to logits. Can be:
      ◦ Binary labels in {0, 1} for binary classification.
      ◦ Continuous labels in [0, 1] for soft targets or label smoothing.

  • alpha – (optional) Weighting factor in range (0, 1) to balance positive vs negative examples. Default None (no weighting).

  • gamma – Exponent of the modulating factor (1 - p_t). Higher values focus more on hard examples. Default 2.0.
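The labels parameter also accepts continuous targets. The sketch below assumes the natural extension \(p_t = y p + (1-y)(1-p)\), \(\alpha_t = \alpha y + (1-\alpha)(1-y)\), with -log(p_t) replaced by the sigmoid binary cross-entropy; all three reduce to the formulas above when y is 0 or 1. This is an assumption for illustration, not a description of optax internals, and smooth_binary_labels and soft_label_focal_loss are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def smooth_binary_labels(labels, smoothing):
    # Hypothetical helper: pull hard 0/1 targets toward 0.5.
    return labels * (1.0 - smoothing) + 0.5 * smoothing


def soft_label_focal_loss(logits, labels, alpha=None, gamma=2.0):
    """Focal loss for labels in [0, 1], computed in log space (a sketch)."""
    log_p = jax.nn.log_sigmoid(logits)      # log p
    log_1mp = jax.nn.log_sigmoid(-logits)   # log (1 - p)
    # Sigmoid binary cross-entropy; equals -log(p_t) for hard labels.
    bce = -(labels * log_p + (1.0 - labels) * log_1mp)
    # 1 - p_t = y (1 - p) + (1 - y) p, combined with logaddexp to stay in log space.
    # log(0) = -inf for hard labels, and logaddexp(-inf, a) = a.
    log_1m_pt = jnp.logaddexp(jnp.log(labels) + log_1mp, jnp.log1p(-labels) + log_p)
    loss = jnp.exp(gamma * log_1m_pt) * bce
    if alpha is not None:
        loss = (alpha * labels + (1.0 - alpha) * (1.0 - labels)) * loss
    return loss


hard = jnp.array([0.0, 1.0, 1.0, 0.0])
soft = smooth_binary_labels(hard, 0.1)
print(soft_label_focal_loss(jnp.array([-3.0, 0.0, 2.0, 4.0]), soft, alpha=0.25))
```

Smoothing with 0.1 maps the hard targets 0 and 1 to 0.05 and 0.95.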

Returns:

Focal loss values with shape identical to logits.

References

Lin et al., Focal Loss for Dense Object Detection, 2017

Changed in version 0.2.5: Added numerical stability improvements using log-space computation. Added support for continuous labels in [0, 1].

Changed in version 0.2.9: Reduced peak memory usage of focal weight computation.