optax.scale_by_stddev

optax.scale_by_stddev(decay: jax.typing.ArrayLike = 0.9, eps: jax.typing.ArrayLike = 1e-08, initial_scale: jax.typing.ArrayLike = 0.0, eps_in_sqrt: bool = True, bias_correction: bool = False) → optax.GradientTransformation

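Rescale updates by the inverse square root of the centered exponential moving average of squared gradients, that is, by an estimate of the gradients' running standard deviation.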

See optax.rmsprop() for more details.
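As a reading aid, the rule can be summarized as below. This is a summary, not text copied from the optax source, and it assumes both moving averages share the decay β (decay), start from μ₀ = 0 and ν₀ = initial_scale, and that g_t is the incoming update at step t and u_t the rescaled update.

```latex
\begin{aligned}
\mu_t &= \beta\,\mu_{t-1} + (1-\beta)\,g_t, \\
\nu_t &= \beta\,\nu_{t-1} + (1-\beta)\,g_t^2, \\
u_t   &= \frac{g_t}{\sqrt{\nu_t - \mu_t^2 + \epsilon}}.
\end{aligned}
```

With eps_in_sqrt=False the denominator becomes \sqrt{\nu_t - \mu_t^2} + \epsilon instead; with bias_correction=True, μ_t and ν_t in the denominator are replaced by μ_t / (1 − β^t) and ν_t / (1 − β^t).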

Parameters:
  • decay – Decay rate for the exponential moving averages of the gradients and of the squared gradients.

  • eps – Term added to the denominator to improve numerical stability.

  • initial_scale – Initial value for the second moment (the moving average of squared gradients).

  • eps_in_sqrt – Whether to add eps inside the square root of the denominator or outside it.

  • bias_correction – Whether to apply bias correction to the first and second moments.
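To make the interaction of these parameters concrete, here is a minimal NumPy sketch of the same rule for a single array rather than a pytree. The function scale_by_stddev_sketch and its dict-based state are hypothetical illustrations, not optax's implementation; they assume the standard centered RMSProp moving averages (one shared decay for both moments).

```python
import numpy as np


def scale_by_stddev_sketch(decay=0.9, eps=1e-8, initial_scale=0.0,
                           eps_in_sqrt=True, bias_correction=False):
    """Hypothetical NumPy re-implementation of the rescaling rule (not optax code).

    Returns an (init, update) pair that works on one array instead of a pytree.
    """

    def init(params):
        # First moment starts at zero; second moment starts at `initial_scale`.
        mu = np.zeros_like(params, dtype=float)
        nu = np.full_like(params, initial_scale, dtype=float)
        return {"count": 0, "mu": mu, "nu": nu}

    def update(grads, state):
        count = state["count"] + 1
        # Exponential moving averages of the gradients and the squared gradients.
        mu = decay * state["mu"] + (1.0 - decay) * grads
        nu = decay * state["nu"] + (1.0 - decay) * grads**2
        if bias_correction:
            # Undo the pull toward the initial values during early steps.
            mu_hat = mu / (1.0 - decay**count)
            nu_hat = nu / (1.0 - decay**count)
        else:
            mu_hat, nu_hat = mu, nu
        # Centered second moment: a running estimate of the gradient variance.
        var = nu_hat - mu_hat**2
        if eps_in_sqrt:
            scaled = grads / np.sqrt(var + eps)
        else:
            scaled = grads / (np.sqrt(var) + eps)
        # The state keeps the uncorrected moments; correction is applied on the fly.
        return scaled, {"count": count, "mu": mu, "nu": nu}

    return init, update
```

With the default initial_scale=0.0 and no bias correction, the first step gives var = decay · (1 − decay) · g², so the first rescaled update is about sign(g) / sqrt(0.9 × 0.1) ≈ ±3.33 for decay=0.9, regardless of the scale of g (ignoring eps).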

Returns:

An optax.GradientTransformation object.