optax.losses.dice_loss

optax.losses.dice_loss(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, class_weights: jax.typing.ArrayLike | None = None, smooth: jax.typing.ArrayLike = 1.0, alpha: jax.typing.ArrayLike = 0.5, beta: jax.typing.ArrayLike = 0.5, apply_softmax: bool = True, reduction: str = 'mean', ignore_background: bool = False, axis: jax.typing.ArrayLike | None = None) → Array

Computes the Dice Loss for multi-class segmentation.

This is the soft Dice loss, with separate weights on false positives and false negatives that make it a generalization of the standard Dice loss. It works for both binary and multi-class segmentation.

The loss is computed per class and then averaged or summed across classes, depending on reduction. For class c:

\[
\begin{aligned}
intersection_c &= \sum_{i=1}^{N} p_{i,c} \cdot t_{i,c} \\
dice_c &= 1 - \frac{intersection_c + smooth}{intersection_c + \alpha \cdot (P_c - intersection_c) + \beta \cdot (T_c - intersection_c) + smooth}
\end{aligned}
\]
where:
  • \(p_{i,c}\): predicted probability for class c at pixel i (i = 1, …, N, with N the number of pixels summed over).

  • \(t_{i,c}\): target value (0 or 1) for class c at pixel i.

  • \(P_c = \sum_i p_{i,c}\) (sum of predicted probabilities for class c)

  • \(T_c = \sum_i t_{i,c}\) (sum of target values for class c)

  • \(\alpha\): weight for false positives (\(FP_c = P_c - intersection_c\)).

  • \(\beta\): weight for false negatives (\(FN_c = T_c - intersection_c\)).
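The per-class formula can be sketched directly in JAX. This is a minimal illustration under stated assumptions, not optax's implementation: the function name `tversky_dice_per_class` is hypothetical, the inputs are assumed to already be probabilities (as with `apply_softmax=False`), and the default summation axes follow the description of `axis` (every axis except the first and last). Class weights and reduction are left out.

```python
import jax
import jax.numpy as jnp


def tversky_dice_per_class(probs, targets, *, alpha=0.5, beta=0.5, smooth=1.0,
                           axis=None):
    """Hypothetical sketch of dice_c for every class at once.

    probs, targets: arrays of shape [batch, ..., num_classes] holding
    probabilities and one-hot targets. Returns shape [batch, num_classes].
    """
    if axis is None:
        # Assumption: sum over all axes except the first (batch) and last (class).
        axis = tuple(range(1, probs.ndim - 1))
    intersection = jnp.sum(probs * targets, axis=axis)  # intersection_c
    p_sum = jnp.sum(probs, axis=axis)                   # P_c
    t_sum = jnp.sum(targets, axis=axis)                 # T_c
    fp = p_sum - intersection                           # FP_c
    fn = t_sum - intersection                           # FN_c
    denom = intersection + alpha * fp + beta * fn + smooth
    return 1.0 - (intersection + smooth) / denom        # dice_c


key_logits, key_labels = jax.random.split(jax.random.PRNGKey(0))
logits = jax.random.normal(key_logits, (2, 4, 4, 3))
labels = jax.random.randint(key_labels, (2, 4, 4), 0, 3)
targets = jax.nn.one_hot(labels, 3)
probs = jax.nn.softmax(logits, axis=-1)  # plays the role of apply_softmax=True

per_class = tversky_dice_per_class(probs, targets)
print(per_class.shape)  # (2, 3)
```

With a perfect prediction (`probs` equal to `targets`), every false-positive and false-negative term vanishes, so each \(dice_c\) is exactly 0.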

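Note: With the default \(\alpha = \beta = 0.5\), the denominator becomes \(\frac{1}{2}(P_c + T_c) + smooth\), so the loss equals the standard Dice loss \(1 - \frac{2\,intersection_c + 2\,smooth}{P_c + T_c + 2\,smooth}\) (with the smoothing term effectively doubled). Other choices give the Tversky loss: setting \(\alpha > \beta\) penalizes false positives more, while \(\beta > \alpha\) penalizes false negatives more.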
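A quick numerical check of this note, using toy values chosen only for illustration. The helper `tversky` is hypothetical and simply evaluates \(dice_c\) from the per-class sums. With \(\alpha = \beta = 0.5\) it matches the classic Dice form with the smoothing term doubled, and because this prediction's errors are mostly false negatives, putting more weight on \(\beta\) increases the loss.

```python
import jax.numpy as jnp


def tversky(intersection, p_sum, t_sum, alpha, beta, smooth=1.0):
    # dice_c from the formula above, written in terms of the per-class sums.
    fp = p_sum - intersection  # FP_c
    fn = t_sum - intersection  # FN_c
    return 1.0 - (intersection + smooth) / (
        intersection + alpha * fp + beta * fn + smooth
    )


# Toy values for a single class (illustrative only).
p = jnp.array([0.9, 0.2, 0.7, 0.1])  # predicted probabilities p_{i,c}
t = jnp.array([1.0, 0.0, 1.0, 1.0])  # binary targets t_{i,c}
smooth = 1.0

intersection = jnp.sum(p * t)  # intersection_c, about 1.7
p_sum = jnp.sum(p)             # P_c, about 1.9
t_sum = jnp.sum(t)             # T_c = 3.0

# alpha = beta = 0.5 reproduces the classic Dice form with the smoothing doubled.
symmetric = tversky(intersection, p_sum, t_sum, 0.5, 0.5, smooth)
classic = 1.0 - (2 * intersection + 2 * smooth) / (p_sum + t_sum + 2 * smooth)
print(jnp.allclose(symmetric, classic))  # True

# Errors here are mostly false negatives (FN_c is about 1.3, FP_c about 0.2),
# so weighting false negatives more (beta > alpha) yields a larger loss.
fn_heavy = tversky(intersection, p_sum, t_sum, 0.3, 0.7, smooth)
fp_heavy = tversky(intersection, p_sum, t_sum, 0.7, 0.3, smooth)
print(fn_heavy > fp_heavy)  # True
```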

Parameters:
  • predictions – Logits of shape […, num_classes] for multi-class or […, 1] or […] for binary segmentation. If apply_softmax is False, these should be probabilities instead.

  • targets – One-hot encoded targets of shape […, num_classes] for multi-class or binary targets of shape […, 1] or […] for binary.

  • class_weights – Optional weights for each class, of shape [num_classes]. If None, all classes are weighted equally.

  • smooth – Smoothing parameter to avoid division by zero and improve gradient stability.

  • alpha – Weight for false positives. Defaults to 0.5 (standard Dice).

  • beta – Weight for false negatives. Defaults to 0.5 (standard Dice).

  • apply_softmax – Whether to apply softmax to predictions. Set to False if predictions are already probabilities.

  • reduction – How to reduce across classes: ‘mean’, ‘sum’, or ‘none’. ‘none’ returns per-class losses.

  • ignore_background – If True, excludes the first class (index 0) from loss computation. Useful when class 0 represents background.

  • axis – Axis or sequence of axes to sum over when computing the loss. If None, sums over all spatial dimensions (all except the first and last). For example, with input shape [batch, H, W, num_classes], the default is to sum over the H and W dimensions.
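The `axis` default and the role of `smooth` can be illustrated with two small hypothetical snippets. `default_sum_axes` only encodes the rule stated for `axis=None` (sum over every axis except the first and last); it is an assumption about how that rule is realized, not optax code. The second part shows why `smooth` matters: for a class absent from both prediction and target, every sum is zero, so the unsmoothed ratio is 0/0, whereas with `smooth > 0` the loss is exactly 0.

```python
import jax.numpy as jnp


def default_sum_axes(ndim):
    # Hypothetical helper: every axis except the first (batch) and last (class).
    return tuple(range(1, ndim - 1))


# For predictions of shape [batch, H, W, num_classes], H and W are summed.
preds = jnp.zeros((2, 4, 4, 3))
print(default_sum_axes(preds.ndim))  # (1, 2)


def dice_c(intersection, p_sum, t_sum, alpha=0.5, beta=0.5, smooth=1.0):
    # Per-class loss from the formula in this section.
    denom = (intersection + alpha * (p_sum - intersection)
             + beta * (t_sum - intersection) + smooth)
    return 1.0 - (intersection + smooth) / denom


# A class missing from both prediction and target: every sum is zero.
zero = jnp.float32(0.0)
print(dice_c(zero, zero, zero, smooth=1.0))  # 0.0
print(dice_c(zero, zero, zero, smooth=0.0))  # nan, from 0 / 0
```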

Returns:

Loss values. Shape depends on reduction:

  • ‘mean’ or ‘sum’: […] (batch dimensions only)

  • ‘none’: […, num_classes] (includes class dimension)

Return type:

Array
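How per-class losses might be collapsed into the documented return shapes can be sketched as follows. `reduce_over_classes` is a hypothetical helper. Dropping class 0 before the reduction is one reading of `ignore_background` (which would leave num_classes - 1 entries for `reduction='none'`), not something this page specifies, and `class_weights` is omitted because the page does not say how the weights are combined.

```python
import jax.numpy as jnp


def reduce_over_classes(per_class, reduction="mean", ignore_background=False):
    # Hypothetical helper; per_class has shape [..., num_classes].
    if ignore_background:
        per_class = per_class[..., 1:]  # assumption: drop class index 0
    if reduction == "mean":
        return jnp.mean(per_class, axis=-1)  # shape [...]
    if reduction == "sum":
        return jnp.sum(per_class, axis=-1)   # shape [...]
    if reduction == "none":
        return per_class                     # shape [..., classes kept]
    raise ValueError(f"Unknown reduction: {reduction!r}")


# Per-class losses for a batch of 2 examples and 3 classes.
per_class = jnp.array([[0.1, 0.4, 0.6],
                       [0.2, 0.3, 0.5]])

print(reduce_over_classes(per_class).shape)                    # (2,)
print(reduce_over_classes(per_class, reduction="none").shape)  # (2, 3)
# Mean over classes 1 and 2 only: approximately [0.5, 0.4].
print(reduce_over_classes(per_class, ignore_background=True))
```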

Examples

Binary segmentation (standard Dice):

>>> import jax.numpy as jnp
>>> from optax.losses import dice_loss
>>> logits = jnp.array([[1.0, -1.0], [0.5, 0.5]])  # Shape: [2, 2]
>>> targets = jnp.array([[1.0, 0.0], [1.0, 0.0]])  # Shape: [2, 2]
>>> loss = dice_loss(logits[..., None], targets[..., None])
>>> loss.shape
(2,)

Multi-class Dice with custom weighting for false positives/negatives:

>>> import jax
>>> key_logits, key_labels = jax.random.split(jax.random.PRNGKey(0))
>>> logits = jax.random.normal(key_logits, (2, 4, 4, 3))
>>> labels = jax.random.randint(key_labels, (2, 4, 4), 0, 3)
>>> targets = jax.nn.one_hot(labels, 3)
>>> loss = dice_loss(
...     logits, targets, alpha=0.3, beta=0.7
... )
>>> loss.shape
(2,)

References

Milletari et al. “V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation” (2016).