optax.losses.multiclass_generalized_dice_loss


optax.losses.multiclass_generalized_dice_loss(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, smooth: jax.typing.ArrayLike = 1.0, apply_softmax: bool = True, ignore_background: bool = False) → Array

Computes Multiclass Generalized Dice Loss with automatic class weighting.

Computes the Generalized Dice Loss, in which each class weight is computed automatically as the inverse of that class's squared frequency in the targets. Rare classes therefore receive larger weights, which counteracts class imbalance in segmentation tasks.
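Following Sudre et al. (2017), the loss can be written as below for predicted probabilities p_ic and one-hot targets t_ic (pixel i, class c). The symbols, and in particular where the smoothing constant s enters, are assumptions for illustration rather than a transcription of the optax source.

```latex
\mathcal{L}_{\mathrm{GDL}}
  = 1 - \frac{2 \sum_{c} w_c \sum_{i} p_{ic}\, t_{ic} + s}
             {\sum_{c} w_c \sum_{i} \left( p_{ic} + t_{ic} \right) + s},
\qquad
w_c = \frac{1}{\left( \sum_{i} t_{ic} \right)^{2}}
```

Because w_c falls with the square of a class's volume, a class covering 10 pixels receives 100 times the weight of a class covering 100 pixels, so small structures are not swamped by large ones.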

Parameters:
  • predictions – Logits of shape […, num_classes], or probabilities of the same shape when apply_softmax is False.

  • targets – One-hot encoded targets of shape […, num_classes].

  • smooth – Smoothing constant that keeps the Dice ratio numerically stable (avoids division by zero, e.g. when a class is absent).

  • apply_softmax – Whether to apply softmax to predictions.

  • ignore_background – If True, excludes the first class (index 0) from loss computation. Useful when class 0 represents background.
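The parameters above can be tied together in a short JAX sketch. `generalized_dice_loss_sketch` is a hypothetical name and not the optax source. It assumes inputs shaped [batch, ..., num_classes], computes the loss per sample and averages over the batch, gives weight 0 to classes absent from a sample, and adds `smooth` to both numerator and denominator. All of these are assumptions; optax's actual implementation may make different choices.

```python
import jax
import jax.numpy as jnp


def generalized_dice_loss_sketch(
    predictions, targets, *, smooth=1.0, apply_softmax=True,
    ignore_background=False,
):
    """Hypothetical generalized Dice loss after Sudre et al. (2017).

    Assumes arrays of shape [batch, ..., num_classes]; not the optax source.
    """
    predictions = jnp.asarray(predictions)
    targets = jnp.asarray(targets)
    if apply_softmax:
        # Logits -> probabilities over the last (class) axis.
        predictions = jax.nn.softmax(predictions, axis=-1)
    if ignore_background:
        # Drop class 0 from both arrays before any sums are taken.
        predictions = predictions[..., 1:]
        targets = targets[..., 1:]

    batch, num_classes = predictions.shape[0], predictions.shape[-1]
    # Flatten all axes between batch and class into one "pixel" axis.
    p = predictions.reshape(batch, -1, num_classes)
    t = targets.reshape(batch, -1, num_classes).astype(p.dtype)

    # Per-sample, per-class target volume: shape [batch, num_classes].
    volume = jnp.sum(t, axis=1)
    # Inverse-square weights; absent classes get weight 0 (an assumption,
    # implementations differ). safe_volume avoids dividing by zero.
    safe_volume = jnp.where(volume > 0, volume, 1.0)
    weights = jnp.where(volume > 0, 1.0 / safe_volume**2, 0.0)

    # Weighted overlap and weighted total volume, one value per sample.
    intersection = jnp.sum(weights * jnp.sum(p * t, axis=1), axis=-1)
    total = jnp.sum(weights * jnp.sum(p + t, axis=1), axis=-1)
    # Smoothing placement is an assumption.
    dice = (2.0 * intersection + smooth) / (total + smooth)
    # Scalar: per-sample losses averaged over the batch.
    return jnp.mean(1.0 - dice)
```

With one-hot predictions that match the targets, the sketch returns 0. Predicting the wrong class everywhere drives the weighted overlap to zero, but the loss stays noticeably below 1 because a `smooth` of 1.0 is large compared with the inverse-square-weighted volumes; this interaction is worth keeping in mind when choosing `smooth`.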

Returns:

Scalar loss value averaged across all classes and the batch.

References

Sudre et al. “Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations” (2017).