optax.losses.huber_loss
- optax.losses.huber_loss(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike | None = None, *, delta: jax.typing.ArrayLike = 1.0) -> Array
Huber loss, which behaves like an L2 loss for errors close to zero and like an L1 loss for errors far from zero.
Applying gradient descent to the Huber loss is equivalent to clipping the gradients of an l2_loss to [-delta, delta] in the backward pass.
- Parameters:
predictions – a vector of arbitrary shape [...].
targets – a vector with shape broadcastable to that of predictions; if not provided, it is assumed to be a vector of zeros.
delta – the bound for the Huber loss transformation; defaults to 1.
- Returns:
elementwise Huber losses, with the same shape as predictions.
References
Huber loss, Wikipedia.