optax.losses.binary_dice_loss

optax.losses.binary_dice_loss(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, smooth: jax.typing.ArrayLike = 1.0, apply_sigmoid: bool = True) -> Array

Binary Dice Loss convenience function.
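For reference, one common formulation of the soft binary Dice loss is shown below. It is given as an assumption about what this function computes, not a transcription of optax's source. Here $p_i$ are the predictions after the optional sigmoid, $t_i$ are the binary targets, $s$ is the `smooth` parameter, and the sums run over the non-batch elements of each example:

```latex
\mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum_i p_i t_i + s}{\sum_i p_i + \sum_i t_i + s}
```

When predictions and targets agree exactly, the fraction equals 1 and the loss is 0. When both sum to zero (an empty mask correctly predicted as empty), $s$ keeps the ratio defined and equal to 1.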

Parameters:
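  • predictions – Logits of shape [...] or [..., 1] (probabilities if apply_sigmoid is False).

  • targets – Binary targets of shape [...] or [..., 1].

  • smooth – Smoothing parameter (keeps the loss defined when predictions and targets are both empty).

  • apply_sigmoid – Whether to apply a sigmoid to predictions before computing the loss.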

Returns:

Loss values of shape [...] (batch dimensions only).
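The following minimal sketch illustrates how `smooth` and `apply_sigmoid` affect a per-example soft Dice loss. It is a stand-alone reimplementation, not optax's code: the function name `soft_binary_dice_loss` is hypothetical, and it assumes axis 0 is the only batch axis and that all remaining axes are summed over, which may differ from how optax decides which dimensions count as batch dimensions.

```python
import jax
import jax.numpy as jnp


def soft_binary_dice_loss(predictions, targets, *, smooth=1.0, apply_sigmoid=True):
    """Hypothetical sketch of a soft binary Dice loss (not optax's implementation).

    Assumption: axis 0 is the batch axis; every other axis is reduced.
    """
    predictions = jnp.asarray(predictions, dtype=jnp.float32)
    targets = jnp.asarray(targets, dtype=jnp.float32)
    if apply_sigmoid:
        # Map raw logits to probabilities in (0, 1).
        predictions = jax.nn.sigmoid(predictions)
    # Sum over every axis except the leading batch axis (assumption).
    reduce_axes = tuple(range(1, predictions.ndim))
    intersection = jnp.sum(predictions * targets, axis=reduce_axes)
    denominator = jnp.sum(predictions, axis=reduce_axes) + jnp.sum(targets, axis=reduce_axes)
    # `smooth` is added to both numerator and denominator.
    dice = (2.0 * intersection + smooth) / (denominator + smooth)
    return 1.0 - dice


# Two examples: the first logits match the mask, the second are inverted.
logits = jnp.array([[10.0, -10.0, 10.0, -10.0],
                    [-10.0, 10.0, -10.0, 10.0]])
masks = jnp.array([[1.0, 0.0, 1.0, 0.0],
                   [1.0, 0.0, 1.0, 0.0]])
loss = soft_binary_dice_loss(logits, masks)
print(loss.shape)  # (2,)

# Probabilities instead of logits: skip the sigmoid.
exact = soft_binary_dice_loss(masks, masks, apply_sigmoid=False)
# An empty mask predicted as empty: `smooth` keeps the ratio at 1.
empty = soft_binary_dice_loss(jnp.zeros((1, 4)), jnp.zeros((1, 4)), apply_sigmoid=False)
```

With `smooth=1.0`, the fully inverted second example yields a loss close to 0.8 rather than 1.0, because the smoothing term keeps the Dice ratio away from zero. Larger `smooth` values pull the loss further toward 0 for small masks.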