optax.losses.l2_loss

optax.losses.l2_loss(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike | None = None) → Array

Calculates the L2 loss for a set of predictions.

Parameters:
  • predictions – a vector of arbitrary shape […].

  • targets – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

Returns:

0.5 times the elementwise squared differences, i.e. 0.5 * (predictions - targets)**2, with the same shape as predictions.

Note

The 0.5 factor is standard in “Pattern Recognition and Machine Learning” by Bishop, but not in “The Elements of Statistical Learning” by Hastie, Tibshirani, and Friedman.