optax.scale_by_rms

optax.scale_by_rms(decay: jax.typing.ArrayLike = 0.9, eps: jax.typing.ArrayLike = 1e-08, initial_scale: jax.typing.ArrayLike = 0.0, eps_in_sqrt: bool = True, bias_correction: bool = False) → optax.GradientTransformation

Rescale updates by the square root of the exponential moving average of their squares.

See optax.rmsprop() for more details.
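Written out per element, the rescaling can be sketched as follows, with g_t the incoming update at step t, β standing for decay, ε for eps, and all operations elementwise. This is a restatement of the RMSProp-style rule implied by the parameters below, not a substitute for optax.rmsprop() or the linked source.

```latex
\begin{aligned}
\nu_0 &= \text{initial\_scale}, \\
\nu_t &= \beta\,\nu_{t-1} + (1 - \beta)\,g_t^2, \qquad \beta = \text{decay}, \\
\hat{\nu}_t &= \begin{cases}
  \nu_t / (1 - \beta^t) & \text{if bias\_correction}, \\
  \nu_t & \text{otherwise},
\end{cases} \\
\tilde{g}_t &= \begin{cases}
  g_t / \sqrt{\hat{\nu}_t + \epsilon} & \text{if eps\_in\_sqrt}, \\
  g_t / \bigl(\sqrt{\hat{\nu}_t} + \epsilon\bigr) & \text{otherwise}.
\end{cases}
\end{aligned}
```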

Parameters:
  • decay – Decay rate for the exponentially weighted average of squared gradients.

  • eps – Term added to the denominator to improve numerical stability.

  • initial_scale – Initial value for the second-moment estimate.

  • eps_in_sqrt – Whether to add eps inside the square root of the denominator or outside it.

  • bias_correction – Whether to apply bias correction to the exponentially weighted average of squared gradients.
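To make the effect of each parameter concrete, here is a minimal plain-Python sketch of the rescaling applied to a flat list of floats. The names RmsState, rms_init and rms_update are hypothetical; optax's real transformation operates on JAX pytrees and is returned as an optax.GradientTransformation, which this sketch does not reproduce.

```python
import math
from dataclasses import dataclass


@dataclass
class RmsState:
    """Hypothetical state: per-element second-moment estimate and a step counter."""
    nu: list[float]
    count: int = 0


def rms_init(params: list[float], initial_scale: float = 0.0) -> RmsState:
    # The second-moment estimate starts at initial_scale for every element.
    return RmsState(nu=[initial_scale] * len(params))


def rms_update(
    updates: list[float],
    state: RmsState,
    decay: float = 0.9,
    eps: float = 1e-8,
    eps_in_sqrt: bool = True,
    bias_correction: bool = False,
) -> tuple[list[float], RmsState]:
    # Exponential moving average of the squared updates.
    nu = [decay * n + (1.0 - decay) * g * g for n, g in zip(state.nu, updates)]
    count = state.count + 1
    # Optional bias correction: divide by (1 - decay**t).
    if bias_correction:
        nu_hat = [n / (1.0 - decay**count) for n in nu]
    else:
        nu_hat = nu
    # eps either inside the square root or added to the root.
    if eps_in_sqrt:
        scaled = [g / math.sqrt(n + eps) for g, n in zip(updates, nu_hat)]
    else:
        scaled = [g / (math.sqrt(n) + eps) for g, n in zip(updates, nu_hat)]
    return scaled, RmsState(nu=nu, count=count)
```

On the first step with initial_scale=0.0, the estimate is (1 - decay) * g**2, so without bias correction each entry comes out close to sign(g) * sqrt(1 / (1 - decay)), about ±3.16 for decay=0.9. With bias_correction=True the estimate becomes g**2 and each output is close to ±1.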

Returns:

An optax.GradientTransformation object.

Note

Using scale_by_rms(decay=b2, eps_in_sqrt=False, bias_correction=True) matches the behavior of scale_by_adam(b1=0, b2=b2) while avoiding the memory cost of storing the first moment.
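The equivalence in the note above can be checked numerically with a scalar sketch. The helpers rms_step and adam_step are hypothetical plain-Python restatements of the two update rules, assuming Adam's eps_root is left at 0; they are not calls into optax. With b1=0 the first moment equals the current gradient and its bias correction divides by 1, so only the second-moment accumulator matters.

```python
import math


def rms_step(g: float, nu: float, t: int, decay: float, eps: float) -> tuple[float, float]:
    # scale_by_rms(decay=decay, eps_in_sqrt=False, bias_correction=True) for one scalar.
    nu = decay * nu + (1.0 - decay) * g * g
    nu_hat = nu / (1.0 - decay**t)
    return g / (math.sqrt(nu_hat) + eps), nu


def adam_step(
    g: float, mu: float, nu: float, t: int, b1: float, b2: float, eps: float
) -> tuple[float, float, float]:
    # scale_by_adam(b1=b1, b2=b2) update direction for one scalar, with eps_root = 0.
    mu = b1 * mu + (1.0 - b1) * g
    nu = b2 * nu + (1.0 - b2) * g * g
    mu_hat = mu / (1.0 - b1**t)
    nu_hat = nu / (1.0 - b2**t)
    return mu_hat / (math.sqrt(nu_hat) + eps), mu, nu


grads = [0.3, -1.2, 0.7, 2.5, -0.1]
b2, eps = 0.999, 1e-8

nu_rms = 0.0             # scale_by_rms keeps a single accumulator per element.
mu_adam = nu_adam = 0.0  # scale_by_adam keeps two.
rms_out, adam_out = [], []
for t, g in enumerate(grads, start=1):
    u_rms, nu_rms = rms_step(g, nu_rms, t, b2, eps)
    u_adam, mu_adam, nu_adam = adam_step(g, mu_adam, nu_adam, t, 0.0, b2, eps)
    rms_out.append(u_rms)
    adam_out.append(u_adam)

print(all(abs(a - b) < 1e-12 for a, b in zip(rms_out, adam_out)))
```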