optax.losses.dice_loss
- optax.losses.dice_loss(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, class_weights: jax.typing.ArrayLike | None = None, smooth: jax.typing.ArrayLike = 1.0, alpha: jax.typing.ArrayLike = 0.5, beta: jax.typing.ArrayLike = 0.5, apply_softmax: bool = True, reduction: str = 'mean', ignore_background: bool = False, axis: jax.typing.ArrayLike | None = None) -> Array
Computes the soft Dice loss for segmentation.
This implementation weights false positives and false negatives separately, which makes it a generalization of the standard Dice loss. It works for both binary and multi-class segmentation.
The loss is computed per class and then averaged (or summed) across classes. For class c:
\[
\begin{aligned}
\mathrm{intersection}_c &= \sum_{i=1}^{N} p_{i,c} \cdot t_{i,c} \\
\mathrm{dice}_c &= 1 - \frac{\mathrm{intersection}_c + \mathrm{smooth}}{\mathrm{intersection}_c + \alpha \cdot (P_c - \mathrm{intersection}_c) + \beta \cdot (T_c - \mathrm{intersection}_c) + \mathrm{smooth}}
\end{aligned}
\]
where:
\(N\): number of pixels summed over.
\(p_{i,c}\): predicted probability for class c at pixel i.
\(t_{i,c}\): target value (0 or 1) for class c at pixel i.
\(P_c = \sum_i p_{i,c}\): sum of predicted probabilities for class c.
\(T_c = \sum_i t_{i,c}\): sum of target values for class c.
\(\alpha\): weight for false positives, \(FP_c = P_c - \mathrm{intersection}_c\).
\(\beta\): weight for false negatives, \(FN_c = T_c - \mathrm{intersection}_c\).
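The per-class formula can be traced on a toy input with a short pure-Python sketch. It handles a single class given as flat lists of probabilities and 0/1 targets, so it skips the softmax step, the `axis` handling, and the reduction across classes; `dice_term` is a hypothetical helper, not an Optax function.

```python
def dice_term(p, t, alpha=0.5, beta=0.5, smooth=1.0):
    """Per-class loss dice_c from the formula above (hypothetical helper).

    p: predicted probabilities p_{i,c} for one class, one value per pixel.
    t: target values t_{i,c} (0 or 1) for the same class and pixels.
    """
    intersection = sum(pi * ti for pi, ti in zip(p, t))
    p_sum = sum(p)  # P_c
    t_sum = sum(t)  # T_c
    fp = p_sum - intersection  # FP_c, weighted by alpha
    fn = t_sum - intersection  # FN_c, weighted by beta
    denom = intersection + alpha * fp + beta * fn + smooth
    return 1.0 - (intersection + smooth) / denom


# Four pixels of a single class.
p = [0.9, 0.8, 0.1, 0.3]
t = [1.0, 1.0, 0.0, 0.0]
print(dice_term(p, t))                       # standard Dice weighting
print(dice_term(p, t, alpha=0.3, beta=0.7))  # false negatives weighted more
```

Here \(FP_c = 0.4\) exceeds \(FN_c = 0.3\), so moving weight from \(\alpha\) to \(\beta\) shrinks the denominator's penalty term and lowers the loss slightly.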
Note: With the default \(\alpha = \beta = 0.5\), the loss is 1 minus the (smoothed) standard Dice coefficient. Unequal weights give the Tversky loss: setting \(\alpha > \beta\) penalizes false positives more, while \(\beta > \alpha\) penalizes false negatives more.
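Substituting \(\alpha = \beta = \tfrac{1}{2}\) into the denominator shows where the equivalence comes from, using only the symbols defined above:

```latex
\begin{aligned}
\mathrm{intersection}_c + \tfrac{1}{2}(P_c - \mathrm{intersection}_c) + \tfrac{1}{2}(T_c - \mathrm{intersection}_c)
  &= \tfrac{1}{2}(P_c + T_c), \\
\mathrm{dice}_c
  = 1 - \frac{\mathrm{intersection}_c + \mathrm{smooth}}{\tfrac{1}{2}(P_c + T_c) + \mathrm{smooth}}
  &= 1 - \frac{2\,\mathrm{intersection}_c + 2\,\mathrm{smooth}}{P_c + T_c + 2\,\mathrm{smooth}}.
\end{aligned}
```

With \(\mathrm{smooth} = 0\) the fraction is exactly the Dice coefficient \(2\,\mathrm{intersection}_c / (P_c + T_c)\); a nonzero smooth adds \(2\,\mathrm{smooth}\) to both the numerator and the denominator.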
- Parameters:
predictions – Logits of shape […, num_classes] for multi-class segmentation, or of shape […, 1] or […] for binary segmentation. If apply_softmax is False, pass probabilities instead.
targets – One-hot encoded targets of shape […, num_classes] for multi-class segmentation, or binary targets of shape […, 1] or […] for binary segmentation.
class_weights – Optional per-class weights of shape [num_classes]. If None, all classes are weighted equally.
smooth – Smoothing constant added to the numerator and denominator of the Dice ratio; it avoids division by zero and improves gradient stability.
alpha – Weight for false positives. Defaults to 0.5 (standard Dice).
beta – Weight for false negatives. Defaults to 0.5 (standard Dice).
apply_softmax – Whether to apply softmax to predictions. Set to False if predictions are already probabilities.
reduction – How to reduce across classes: ‘mean’, ‘sum’, or ‘none’. ‘none’ returns per-class losses.
ignore_background – If True, excludes the first class (index 0) from loss computation. Useful when class 0 represents background.
axis – Axis or sequence of axes to sum over when computing the loss. If None, sums over all spatial dimensions (all except the first and last). For example, for image inputs with the batch dimension first and classes last, the default is to sum over the H and W dimensions.
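The class-level options (`ignore_background`, `class_weights`, `reduction`) can be pictured as a post-processing step on one example's per-class losses. The sketch below rests on an assumption that this page does not confirm: it multiplies each class's loss by its weight and then takes a plain mean or sum over the kept classes. Optax may normalize the weights differently, and `reduce_classes` is a hypothetical name, so treat this as illustrative only.

```python
def reduce_classes(per_class, class_weights=None, reduction="mean",
                   ignore_background=False):
    """Combine one example's per-class losses (hypothetical sketch, not Optax).

    ASSUMPTION: each loss is multiplied by its class weight, then 'mean'
    takes a plain average over the kept classes. Optax may normalize the
    weights differently.
    """
    losses = list(per_class)
    if class_weights is None:
        weights = [1.0] * len(losses)  # all classes weighted equally
    else:
        weights = list(class_weights)
    if len(weights) != len(losses):
        raise ValueError("class_weights needs one entry per class")
    if ignore_background:
        # Drop class 0 (background) together with its weight.
        losses, weights = losses[1:], weights[1:]
    weighted = [w * loss for w, loss in zip(weights, losses)]
    if reduction == "mean":
        return sum(weighted) / len(weighted)
    if reduction == "sum":
        return sum(weighted)
    if reduction == "none":
        return weighted  # one value per kept class
    raise ValueError(f"unknown reduction: {reduction!r}")


per_class = [0.1, 0.4, 0.7]  # made-up per-class losses for classes 0, 1, 2
print(reduce_classes(per_class))
print(reduce_classes(per_class, ignore_background=True, reduction="sum"))
print(reduce_classes(per_class, class_weights=[0.0, 1.0, 2.0], reduction="none"))
```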
- Returns:
Loss values. Shape depends on reduction:
‘mean’/‘sum’: […] (batch dimensions only)
‘none’: […, num_classes] (includes class dimension)
- Return type:
Array
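The shape rules for `axis` and `reduction` can be checked with a small pure-Python helper. It encodes only what the parameter and return descriptions state (by default, sum over every axis except the first and last; drop the class axis for ‘mean’/‘sum’ and keep it for ‘none’). It ignores `ignore_background`, and `dice_output_shape` is a hypothetical name, not an Optax function.

```python
def dice_output_shape(shape, reduction="mean", axis=None):
    """Expected dice_loss output shape per the descriptions above (hypothetical).

    shape: input shape with classes on the last axis, e.g. (B, H, W, C).
    axis: axes summed when computing the loss; None means every axis
        except the first and the last.
    """
    ndim = len(shape)
    if axis is None:
        summed = set(range(1, ndim - 1))
    else:
        axes = (axis,) if isinstance(axis, int) else tuple(axis)
        summed = {a % ndim for a in axes}  # normalize negative axes
    kept = [d for i, d in enumerate(shape) if i not in summed]
    if reduction in ("mean", "sum"):
        return tuple(kept[:-1])  # class axis reduced away
    if reduction == "none":
        return tuple(kept)  # class axis kept
    raise ValueError(f"unknown reduction: {reduction!r}")


print(dice_output_shape((2, 4, 4, 3)))                    # batch only
print(dice_output_shape((2, 4, 4, 3), reduction="none"))  # batch and classes
```

For the input shapes used in the examples, (2, 2, 1) and (2, 4, 4, 3), the helper gives (2,) in both cases, matching the shapes the examples print.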
Examples
Binary segmentation (standard Dice):
>>> import jax.numpy as jnp
>>> from optax.losses import dice_loss
>>> logits = jnp.array([[1.0, -1.0], [0.5, 0.5]])  # Shape: [2, 2]
>>> targets = jnp.array([[1.0, 0.0], [1.0, 0.0]])  # Shape: [2, 2]
>>> loss = dice_loss(logits[..., None], targets[..., None])
>>> loss.shape
(2,)
Multi-class Dice with custom weighting for false positives/negatives:
>>> import jax
>>> key_logits, key_labels = jax.random.split(jax.random.PRNGKey(0))
>>> logits = jax.random.normal(key_logits, (2, 4, 4, 3))
>>> labels = jax.random.randint(key_labels, (2, 4, 4), 0, 3)
>>> targets = jax.nn.one_hot(labels, 3)
>>> loss = dice_loss(
...     logits, targets, alpha=0.3, beta=0.7
... )
>>> loss.shape
(2,)
References
Milletari et al. “V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation” (2016).