optax.losses.squared_error
- optax.losses.squared_error(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike | None = None) → jax.Array
Calculates the squared error for a set of predictions.
Mean Squared Error can be computed as squared_error(a, b).mean().
- Parameters:
predictions – a vector of arbitrary shape […].
targets – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.
- Returns:
elementwise squared differences, with the same shape as predictions.
Note
l2_loss = 0.5 * squared_error, where the 0.5 factor is standard in “Pattern Recognition and Machine Learning” by Bishop, but not in “The Elements of Statistical Learning” by Hastie, Tibshirani, and Friedman.