optax.scale_by_rms
- optax.scale_by_rms(decay: jax.typing.ArrayLike = 0.9, eps: jax.typing.ArrayLike = 1e-08, initial_scale: jax.typing.ArrayLike = 0.0, eps_in_sqrt: bool = True, bias_correction: bool = False) → optax.GradientTransformation
Rescale updates by the root of the exponential moving average of the squared gradients.
See optax.rmsprop() for more details.
- Parameters:
decay – Decay rate for the exponentially weighted average of squared grads.
eps – Term added to the denominator to improve numerical stability.
initial_scale – Initial value for second moment.
eps_in_sqrt – Whether to add eps in the square root of the denominator or outside the square root.
bias_correction – Whether to apply bias correction to the exponentially weighted average of squared grads.
- Returns:
An optax.GradientTransformation object.
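To make the roles of the parameters concrete, the following is a minimal scalar sketch of an RMS-style rescaling step in plain Python. It is an illustration under stated assumptions, not the library's implementation: the helper name rms_scale_step and the explicit count argument are hypothetical, and the update rule is the standard RMSProp-style form inferred from the parameter descriptions above.

```python
import math


def rms_scale_step(grad, nu, count, decay=0.9, eps=1e-8,
                   eps_in_sqrt=True, bias_correction=False):
    """One scalar step of an RMS-style rescaling (illustrative sketch only)."""
    # Exponentially weighted average of the squared gradient.
    nu = decay * nu + (1.0 - decay) * grad ** 2
    count += 1
    # Optional bias correction of the second-moment estimate.
    nu_hat = nu / (1.0 - decay ** count) if bias_correction else nu
    if eps_in_sqrt:
        # eps is added inside the square root.
        update = grad / math.sqrt(nu_hat + eps)
    else:
        # eps is added outside the square root.
        update = grad / (math.sqrt(nu_hat) + eps)
    return update, nu, count


# nu starts at initial_scale, which is 0.0 by default.
update, nu, count = rms_scale_step(grad=2.0, nu=0.0, count=0)
print(update, nu, count)

# Same gradient, with bias correction and eps outside the square root.
update_bc, _, _ = rms_scale_step(
    grad=2.0, nu=0.0, count=0, eps_in_sqrt=False, bias_correction=True)
print(update_bc)
```

With bias correction enabled, the first rescaled update has magnitude close to 1 regardless of the gradient's scale, because the corrected second moment on the first step equals the squared gradient.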
Note
Using scale_by_rms(decay=b2, eps_in_sqrt=False, bias_correction=True) will match the behavior of scale_by_adam(b1=0, b2=b2), while sparing the memory cost of storing the first moment.