optax.losses.squared_error

optax.losses.squared_error(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike | None = None) -> Array

Calculates the squared error for a set of predictions.

The mean squared error (MSE) can be computed as squared_error(a, b).mean().

Parameters:
  • predictions – an array of arbitrary shape […].

  • targets – an array with a shape broadcastable to that of predictions; if not provided, it is assumed to be an array of zeros.

Returns:

Elementwise squared differences, with the same shape as predictions.

Note

l2_loss = 0.5 * squared_error, where the 0.5 factor is standard in "Pattern Recognition and Machine Learning" by Bishop but not in "The Elements of Statistical Learning" by Hastie, Tibshirani, and Friedman.