optax.safe_root_mean_squares
- optax.safe_root_mean_squares(x: jax.typing.ArrayLike, min_rms: jax.typing.ArrayLike) → Array
Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct gradients.
The gradient of maximum(sqrt(mean(abs_sq(x))), min_rms) at 0.0 is NaN, because JAX evaluates both branches of jnp.maximum. This function instead returns the correct gradient of 0.0 in that case.
- Parameters:
x – a JAX array.
min_rms – lower bound for the returned RMS.
- Returns:
The safe RMS of the input array, with a well-defined gradient where the lower bound is active.
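The NaN gradient described above can be reproduced with plain JAX. The following is a minimal sketch, not necessarily how optax implements the function: `naive_rms` and `safe_rms_sketch` are hypothetical names introduced for illustration, and the sketch assumes real-valued input (so `jnp.square` stands in for `abs_sq`). It uses the common "double where" pattern, which keeps `sqrt` from being differentiated at zero.

```python
import jax
import jax.numpy as jnp


def naive_rms(x, min_rms):
    # Direct translation of the formula; sqrt has an infinite derivative at 0,
    # which turns into NaN once multiplied by the zero cotangent from maximum.
    return jnp.maximum(jnp.sqrt(jnp.mean(jnp.square(x))), min_rms)


def safe_rms_sketch(x, min_rms):
    # Hypothetical helper (illustration only): "double where" pattern.
    rms = jnp.sqrt(jnp.mean(jnp.square(x)))
    below = rms <= min_rms
    # Swap in a harmless input where the bound is active, so the
    # differentiated sqrt is never evaluated at zero.
    x_safe = jnp.where(below, jnp.ones_like(x), x)
    return jnp.where(below, min_rms, jnp.sqrt(jnp.mean(jnp.square(x_safe))))


x0 = jnp.zeros(3)
naive_grad = jax.grad(naive_rms)(x0, 1e-3)       # NaN entries
safe_grad = jax.grad(safe_rms_sketch)(x0, 1e-3)  # zero entries
print(naive_grad, safe_grad)
```

Away from the bound, both versions compute the ordinary RMS. `optax.safe_root_mean_squares(x, min_rms)` takes the same two arguments as the sketch.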