optax.scale_by_adan
- optax.scale_by_adan(b1: jax.typing.ArrayLike = 0.98, b2: jax.typing.ArrayLike = 0.92, b3: jax.typing.ArrayLike = 0.99, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0) → optax.GradientTransformation
Rescale updates according to the Adan algorithm.
See optax.adan() for more details.
- Parameters:
b1 – Decay rate for the EWMA of gradients.
b2 – Decay rate for the EWMA of differences of gradients.
b3 – Decay rate for the EWMA of the algorithm's squared term.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the denominator inside the square root to improve numerical stability when backpropagating gradients through the rescaling.
- Returns:
An optax.GradientTransformation object.