optax.scale_by_adan



optax.scale_by_adan(b1: jax.typing.ArrayLike = 0.98, b2: jax.typing.ArrayLike = 0.92, b3: jax.typing.ArrayLike = 0.99, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0) → optax.GradientTransformation

Rescale updates according to the Adan algorithm.

See optax.adan() for more details.

Parameters:
  • b1 – Decay rate for the EWMA (exponentially weighted moving average) of gradients.

  • b2 – Decay rate for the EWMA of differences of gradients.

  • b3 – Decay rate for the EWMA of the algorithm’s squared term.

  • eps – Term added to the denominator to improve numerical stability.

  • eps_root – Term added to the denominator inside the square root to improve numerical stability when backpropagating gradients through the rescaling.
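The parameter descriptions place eps outside the square root and eps_root inside it. The sketch below isolates only that denominator to show why eps_root matters when differentiating through the rescaling. It is not optax's implementation: the helper `adan_denominator` and the argument `n` (a stand-in for the EWMA of the squared term) are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def adan_denominator(n, eps=1e-8, eps_root=0.0):
    # Hypothetical helper: `n` stands in for the EWMA of the squared term.
    # eps_root is added inside the square root, eps outside it.
    return jnp.sqrt(n + eps_root) + eps


def scaled_unit_update(n, eps_root):
    # Rescale a unit update by 1 / denominator.
    return 1.0 / adan_denominator(n, eps_root=eps_root)


# Differentiate with respect to `n` at n = 0.
# With eps_root = 0 the derivative of sqrt at 0 is infinite, so the gradient blows up.
grad_no_root = jax.grad(scaled_unit_update)(0.0, 0.0)
# With a small positive eps_root the gradient stays finite.
grad_with_root = jax.grad(scaled_unit_update)(0.0, 1e-8)

print(grad_no_root, grad_with_root)
```

With the default eps_root = 0.0 the forward pass is still well defined (eps keeps the denominator positive); a nonzero eps_root only matters when gradients flow back through the rescaling itself, for example in meta-learning setups.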

Returns:

An optax.GradientTransformation object.