optax.losses.log_cosh

optax.losses.log_cosh(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike | None = None) → Array

Calculates the log-cosh loss for a set of predictions.

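log(cosh(x)), where x = predictions - targets, is approximately x**2 / 2 for small |x| and |x| - log(2) for large |x|. Because it is twice differentiable everywhere, it is a smooth alternative to the Huber loss, whose second derivative is discontinuous where its quadratic and linear pieces meet.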
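The two regimes and the smoothness claim can be checked numerically. The sketch below uses only Python's standard `math` module and is not optax's implementation. It evaluates log(cosh(x)) through the identity log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2), which avoids overflowing `cosh` for large |x|. The helper names (`log_cosh`, `log_cosh_hess`, `huber_hess`) are hypothetical, and the Huber threshold `delta=1.0` is an assumed value chosen only for the comparison.

```python
import math


def log_cosh(x: float) -> float:
    """Numerically stable log(cosh(x)) for a single float (hypothetical helper)."""
    a = abs(x)
    # cosh(x) = exp(|x|) * (1 + exp(-2|x|)) / 2, so taking logs avoids cosh overflow.
    return a + math.log1p(math.exp(-2.0 * a)) - math.log(2.0)


def log_cosh_hess(x: float) -> float:
    """Second derivative of log(cosh(x)): 1 - tanh(x)**2, continuous everywhere."""
    return 1.0 - math.tanh(x) ** 2


def huber_hess(x: float, delta: float = 1.0) -> float:
    """Second derivative of the Huber loss (hypothetical helper, assumed delta=1.0).

    It jumps from 1 (quadratic region) to 0 (linear region) at |x| == delta.
    """
    return 1.0 if abs(x) <= delta else 0.0


# Small errors: log(cosh(x)) is close to x**2 / 2.
for x in (0.01, 0.1):
    print(f"x={x}: log_cosh={log_cosh(x):.10f}  x**2/2={x * x / 2:.10f}")

# Large errors: log(cosh(x)) is close to |x| - log(2).
for x in (10.0, -50.0):
    print(f"x={x}: log_cosh={log_cosh(x):.10f}  |x|-log(2)={abs(x) - math.log(2.0):.10f}")

# The naive formula overflows for large |x|; the stable form does not.
try:
    math.log(math.cosh(1000.0))
except OverflowError:
    print("math.cosh(1000.0) overflows; log_cosh(1000.0) =", log_cosh(1000.0))

# Just below and above |x| = 1: log-cosh curvature barely changes, Huber's jumps.
for x in (0.999, 1.001):
    print(f"x={x}: log_cosh''={log_cosh_hess(x):.4f}  huber''={huber_hess(x):.1f}")
```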

Parameters:
  • predictions – an array of arbitrary shape […].

  • targets – an array with shape broadcastable to that of predictions; if not provided, it is assumed to be an array of zeros.

Returns:

the log-cosh loss, with the same shape as predictions.

References

Chen et al., Log Hyperbolic Cosine Loss Improves Variational Auto-Encoder, 2019. https://openreview.net/pdf?id=rkglvsC9Ym