optax.losses.multiclass_generalized_dice_loss
- optax.losses.multiclass_generalized_dice_loss(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, smooth: jax.typing.ArrayLike = 1.0, apply_softmax: bool = True, ignore_background: bool = False) → Array
Computes Multiclass Generalized Dice Loss with automatic class weighting.
Class weights are computed automatically as the inverse of the squared class frequencies. This down-weights frequent classes (such as background) relative to rare ones, which helps handle class imbalance in segmentation tasks.
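For reference, the generalized Dice loss of Sudre et al. (2017) can be written as below, where r_{nc} is the one-hot target and p_{nc} the predicted probability for element n (e.g. a pixel or voxel) and class c. The symbols are introduced here for illustration; this is the paper's formulation, and where optax adds `smooth` and how it reduces over the batch may differ.

```latex
w_c = \frac{1}{\left(\sum_{n} r_{nc}\right)^{2}}, \qquad
\mathrm{GDL} = 1 - 2\,\frac{\sum_{c} w_c \sum_{n} r_{nc}\, p_{nc}}{\sum_{c} w_c \sum_{n} \left(r_{nc} + p_{nc}\right)}
```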
- Parameters:
predictions – Logits of shape […, num_classes].
targets – One-hot encoded targets of shape […, num_classes].
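smooth – Smoothing constant for numerical stability (guards against division by zero).
apply_softmax – Whether to apply softmax to predictions. Set to False if predictions already contain class probabilities.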
ignore_background – If True, excludes the first class (index 0) from loss computation. Useful when class 0 represents background.
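The parameters above can be illustrated with a minimal JAX sketch of the loss. This is not the optax source: the function name `generalized_dice_loss_sketch` is hypothetical, and it assumes that the leading axis is the batch axis, that softmax is taken over the last (class) axis, that `smooth` is added to both the weighted numerator and denominator, and that `smooth` also keeps the class weights finite when a class is absent.

```python
import jax
import jax.numpy as jnp


def generalized_dice_loss_sketch(predictions, targets, smooth=1.0,
                                 apply_softmax=True, ignore_background=False):
    """Hypothetical sketch of generalized Dice loss (Sudre et al., 2017).

    Assumes predictions and targets have shape [batch, ..., num_classes].
    Not the optax implementation.
    """
    # Assumption: softmax is taken over the trailing class axis.
    probs = jax.nn.softmax(predictions, axis=-1) if apply_softmax else predictions
    targets = jnp.asarray(targets, dtype=probs.dtype)

    if ignore_background:
        # Exclude class 0 from both tensors.
        probs = probs[..., 1:]
        targets = targets[..., 1:]

    # Reduce over every axis between the batch axis and the class axis.
    spatial_axes = tuple(range(1, probs.ndim - 1))
    intersection = jnp.sum(probs * targets, axis=spatial_axes)  # [batch, C]
    cardinality = jnp.sum(probs + targets, axis=spatial_axes)   # [batch, C]
    volume = jnp.sum(targets, axis=spatial_axes)                # [batch, C]

    # Inverse squared class volume; adding smooth here (an assumption) keeps
    # the weight finite when a class is absent from a sample.
    weights = 1.0 / (volume ** 2 + smooth)

    # One weighted ratio per sample that combines all classes, then a batch mean.
    numerator = 2.0 * jnp.sum(weights * intersection, axis=-1) + smooth
    denominator = jnp.sum(weights * cardinality, axis=-1) + smooth
    return jnp.mean(1.0 - numerator / denominator)
```

With `apply_softmax=False` and predictions equal to the one-hot targets, the sketch returns a loss of approximately zero. In real code, call `optax.losses.multiclass_generalized_dice_loss(logits, targets)` directly, building `targets` with `jax.nn.one_hot(labels, num_classes)`.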
- Returns:
Scalar loss value, averaged across all classes and the batch.
References
Sudre et al. “Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations” (2017).