optax.safe_root_mean_squares


optax.safe_root_mean_squares(x: jax.typing.ArrayLike, min_rms: jax.typing.ArrayLike) → Array

Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct gradients.

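The gradient of maximum(sqrt(mean(abs_sq(x))), min_rms) at x = 0.0 is NaN. JAX evaluates both branches of jnp.maximum, and the derivative of sqrt is infinite at 0, so the zero gradient of the unselected branch is multiplied by an infinite value, which yields NaN. This function instead returns the correct gradient of 0.0 in that case.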
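A short chain-rule derivation shows where the NaN arises. It assumes real-valued x of length n (complex inputs behave analogously), with r denoting the unclipped RMS:

```latex
r(x) = \sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^{2}},
\qquad
\frac{\partial r}{\partial x_i}
  = \frac{1}{2\,r(x)} \cdot \frac{2\,x_i}{n}
  = \frac{x_i}{n\,r(x)}.
```

At x = 0 the factor 1/(2r) is infinite while 2x_i/n is zero. In floating point, any product of that infinity with a zero is NaN, including the zero gradient that jnp.maximum passes to its unselected argument.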

Parameters:
  • x – input array (any jax.typing.ArrayLike).

  • min_rms – lower bound for the returned RMS value.

Returns:

The RMS of x, bounded below by min_rms, with a correct (non-NaN) gradient even at x = 0.
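One way to obtain a 0.0 gradient instead of NaN is the "double where" pattern, sketched below. It assumes real-valued x, and `naive_rms` and `safe_rms_sketch` are hypothetical names for illustration only, not the optax source. The idea is to substitute ones for x wherever the floor min_rms is selected, so the sqrt that is actually differentiated never receives a zero argument. The first sqrt only feeds a boolean comparison, so no gradient flows back through it.

```python
import jax
import jax.numpy as jnp


def naive_rms(x, min_rms):
    # Direct formula: max(sqrt(mean(x**2)), min_rms). Its gradient is NaN at
    # x == 0 because sqrt's derivative is infinite there.
    return jnp.maximum(jnp.sqrt(jnp.mean(jnp.square(x))), min_rms)


def safe_rms_sketch(x, min_rms):
    # Hypothetical illustration of the "double where" trick (real x assumed).
    rms = jnp.sqrt(jnp.mean(jnp.square(x)))
    # The comparison is non-differentiable, so no gradient flows back
    # through this first (possibly zero) sqrt.
    use_floor = rms <= min_rms
    # Where the floor applies, feed ones instead of x, so the sqrt that is
    # differentiated below never sees a zero argument.
    x_safe = jnp.where(use_floor, jnp.ones_like(x), x)
    safe_rms = jnp.sqrt(jnp.mean(jnp.square(x_safe)))
    # Return the constant floor (zero gradient) or the real RMS.
    return jnp.where(use_floor, min_rms, safe_rms)


zeros = jnp.zeros(3)
print(jax.grad(naive_rms)(zeros, 1e-6))        # all NaN
print(jax.grad(safe_rms_sketch)(zeros, 1e-6))  # all zeros

x = jnp.array([3.0, 4.0])
print(safe_rms_sketch(x, 1.0))   # sqrt(12.5): the RMS exceeds min_rms
print(safe_rms_sketch(x, 10.0))  # 10.0: the floor is returned
```

In practice you would call optax.safe_root_mean_squares(x, min_rms) directly. The sketch only illustrates why a 0.0 gradient is achievable where the plain jnp.maximum formulation yields NaN.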