optax.losses.sigmoid_focal_loss
- optax.losses.sigmoid_focal_loss(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, *, alpha: jax.typing.ArrayLike | None = None, gamma: jax.typing.ArrayLike = 2.0) → Array
Sigmoid focal loss with numerical stability improvements.
The focal loss is a dynamically scaled cross entropy loss, where the scaling factor decays to zero as confidence in the correct class increases. This addresses class imbalance by down-weighting easy examples and focusing on hard examples.
This implementation uses log-space computation for the focal weight \((1-p_t)^\gamma\) to ensure numerical stability, especially for \(\gamma < 2\) and extreme logit values.
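To illustrate the idea of a log-space focal weight, here is a minimal sketch for binary labels. The helper name `focal_weight_log_space` is hypothetical and this is not the library's actual code; it relies only on the identity \(1 - \sigma(x) = \sigma(-x)\) and on `jax.nn.log_sigmoid`.

```python
import jax
import jax.numpy as jnp


def focal_weight_log_space(logits, labels, gamma=2.0):
    """Hypothetical helper: (1 - p_t)**gamma via logs, binary labels only."""
    # For y = 1: 1 - p_t = 1 - sigmoid(x) = sigmoid(-x).
    # For y = 0: 1 - p_t = sigmoid(x).
    signed_logits = jnp.where(labels == 1, -logits, logits)
    # log(1 - p_t), computed without ever forming 1 - p_t explicitly.
    log_one_minus_pt = jax.nn.log_sigmoid(signed_logits)
    # exp(gamma * log(1 - p_t)) equals (1 - p_t)**gamma.
    return jnp.exp(gamma * log_one_minus_pt)
```

Working with `log(1 - p_t)` avoids raising a value that may have rounded to exactly zero to a fractional power, which is where a direct `(1 - p_t) ** gamma` can run into trouble for small `gamma` and large-magnitude logits.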
The loss is defined as:
\[FL(p_t) = -\alpha_t (1-p_t)^\gamma \log(p_t)\]
where \(p_t\) is the predicted probability of the correct class:
\[p_t = \begin{cases} p & \text{if } y = 1 \\ 1-p & \text{if } y = 0 \end{cases}\]
and \(\alpha_t\) is the weighting factor:
\[\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1-\alpha & \text{if } y = 0 \end{cases}\]
Here \(p\) is the sigmoid of the logit and \(y\) is the label.
- Parameters:
logits – Array of unnormalized log probabilities, with shape […, ]. The predictions for each example.
labels – Array of labels with shape broadcastable to logits. Either binary labels in {0, 1} for binary classification, or continuous labels in [0, 1] for soft targets or label smoothing.
alpha – (optional) Weighting factor in range (0, 1) to balance positive vs negative examples. Default None (no weighting).
gamma – Exponent of the modulating factor (1 - p_t). Higher values focus more on hard examples. Default 2.0.
- Returns:
Focal loss values with shape identical to logits.
References
Lin et al., Focal Loss for Dense Object Detection, 2017
Changed in version 0.2.5: Added numerical stability improvements using log-space computation. Added support for continuous labels in [0, 1].
Changed in version 0.2.9: Reduced peak memory usage of focal weight computation.