optax.losses.binary_dice_loss
- optax.losses.binary_dice_loss(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, smooth: jax.typing.ArrayLike = 1.0, apply_sigmoid: bool = True) → Array
Binary Dice Loss convenience function.
- Parameters:
predictions – Logits of shape [...] or [..., 1].
targets – Binary targets of shape [...] or [..., 1].
smooth – Smoothing parameter.
apply_sigmoid – Whether to apply sigmoid to predictions.
- Returns:
Loss values of shape [...] (batch dimensions only).
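To make the roles of smooth and apply_sigmoid concrete, here is a minimal sketch of one common soft Dice formulation written with jax.numpy. It is not the optax implementation: the function name dice_loss_sketch is hypothetical, and the choice to treat axis 0 as the batch axis and reduce over every other axis is an assumption made for illustration. It also does not handle the trailing [..., 1] form of the inputs.

```python
import jax
import jax.numpy as jnp


def dice_loss_sketch(predictions, targets, smooth=1.0, apply_sigmoid=True):
    """Illustrative soft Dice loss (hypothetical helper, NOT optax's code)."""
    predictions = jnp.asarray(predictions, dtype=jnp.float32)
    targets = jnp.asarray(targets, dtype=jnp.float32)

    # Map logits to probabilities in (0, 1) when requested.
    probs = jax.nn.sigmoid(predictions) if apply_sigmoid else predictions

    # Assumption: axis 0 is the batch axis; all remaining axes are reduced.
    reduce_axes = tuple(range(1, probs.ndim))
    intersection = jnp.sum(probs * targets, axis=reduce_axes)
    total = jnp.sum(probs, axis=reduce_axes) + jnp.sum(targets, axis=reduce_axes)

    # `smooth` keeps the ratio defined when both masks are empty.
    dice = (2.0 * intersection + smooth) / (total + smooth)
    return 1.0 - dice


targets = jnp.array([[1.0, 1.0, 0.0, 0.0]])

# Confident, correct logits: the loss is close to 0.
loss_good = dice_loss_sketch(jnp.array([[20.0, 20.0, -20.0, -20.0]]), targets)

# Confident, inverted logits: no overlap, so the loss is about 1 - 1/5 = 0.8.
loss_bad = dice_loss_sketch(jnp.array([[-20.0, -20.0, 20.0, 20.0]]), targets)
```

Under the signature documented above, the library function itself would be called as optax.losses.binary_dice_loss(predictions, targets, smooth=1.0, apply_sigmoid=True). Its exact reduction axes may differ from this sketch.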