optax.scale_by_stddev
- optax.scale_by_stddev(decay: jax.typing.ArrayLike = 0.9, eps: jax.typing.ArrayLike = 1e-08, initial_scale: jax.typing.ArrayLike = 0.0, eps_in_sqrt: bool = True, bias_correction: bool = False) → optax.GradientTransformation
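Rescale updates by the inverse square root of the centered exponential moving average of their squares, i.e. divide each update by a running estimate of its standard deviation.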
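Written out per element, the description and the parameter list suggest the following rule. This is a sketch, not a quote from the optax source: it uses β = decay, ε = eps, g_t for the incoming update at step t, and ν_0 = initial_scale. It also assumes the first moment starts at μ_0 = 0, since only the second moment's initial value is documented.

```latex
\begin{aligned}
\mu_t &= \beta\,\mu_{t-1} + (1-\beta)\,g_t, \qquad
\nu_t = \beta\,\nu_{t-1} + (1-\beta)\,g_t^2, \\
\hat\mu_t &= \frac{\mu_t}{1-\beta^t},\quad \hat\nu_t = \frac{\nu_t}{1-\beta^t}
  \quad \text{(bias\_correction=True; otherwise } \hat\mu_t=\mu_t,\ \hat\nu_t=\nu_t\text{)}, \\
u_t &= \frac{g_t}{\sqrt{\hat\nu_t - \hat\mu_t^2 + \epsilon}} \quad \text{(eps\_in\_sqrt=True)}, \qquad
u_t = \frac{g_t}{\sqrt{\hat\nu_t - \hat\mu_t^2} + \epsilon} \quad \text{(eps\_in\_sqrt=False)}.
\end{aligned}
```

The term $\hat\nu_t - \hat\mu_t^2$ is the "centered" part: it estimates the variance rather than the raw second moment.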
See optax.rmsprop() for more details.
- Parameters:
decay – Decay rate for the exponentially weighted averages of the gradients and squared gradients.
eps – Term added to the denominator to improve numerical stability.
initial_scale – Initial value for the second moment.
eps_in_sqrt – Whether to add eps inside the square root of the denominator or outside it.
bias_correction – Whether to apply bias correction to the first and second moments.
- Returns:
A optax.GradientTransformation object.
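A minimal JAX sketch of the same computation for a single array may help make the parameters concrete. It assumes the moments are plain exponential moving averages with standard (1 - decay**t) bias correction. The names StddevState, stddev_init and stddev_update are hypothetical and not part of optax. The real transformation operates on arbitrary pytrees and is packaged as an optax.GradientTransformation.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class StddevState(NamedTuple):
    """Hypothetical state: running moments plus a step counter."""
    mu: jax.Array     # EMA of updates (first moment), assumed to start at 0
    nu: jax.Array     # EMA of squared updates (second moment)
    count: jax.Array  # number of steps taken so far


def stddev_init(param, initial_scale=0.0):
    # First moment starts at zero; second moment starts at `initial_scale`.
    return StddevState(
        mu=jnp.zeros_like(param),
        nu=jnp.full_like(param, initial_scale),
        count=jnp.zeros([], jnp.int32),
    )


def stddev_update(grad, state, decay=0.9, eps=1e-8,
                  eps_in_sqrt=True, bias_correction=False):
    # Exponential moving averages of the update and of its square.
    mu = decay * state.mu + (1 - decay) * grad
    nu = decay * state.nu + (1 - decay) * grad ** 2
    count = state.count + 1
    if bias_correction:
        # Standard EMA bias correction: divide by (1 - decay**t).
        correction = 1 - decay ** count.astype(mu.dtype)
        mu_hat, nu_hat = mu / correction, nu / correction
    else:
        mu_hat, nu_hat = mu, nu
    variance = nu_hat - mu_hat ** 2  # centered second moment
    if eps_in_sqrt:
        scale = 1.0 / jnp.sqrt(variance + eps)
    else:
        scale = 1.0 / (jnp.sqrt(variance) + eps)
    # The state keeps the uncorrected moments; correction is applied per step.
    return scale * grad, StddevState(mu=mu, nu=nu, count=count)


g = jnp.array([1.0, -2.0])
update, state = stddev_update(g, stddev_init(g))
print(update)  # roughly [3.333, -3.333]: g / sqrt(0.1 * g**2 - (0.1 * g)**2)
```

Because bias_correction and eps_in_sqrt select Python-level branches, they would need to be static arguments if stddev_update were wrapped in jax.jit. Note also that, in exact arithmetic, the first step with bias_correction=True gives mu_hat = g and nu_hat = g², so the centered variance is zero and the step reduces to g / sqrt(eps) (or g / eps with eps_in_sqrt=False). Rounding may push that variance slightly below zero.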