optax.contrib.scale_by_simplified_ademamix
- optax.contrib.scale_by_simplified_ademamix(b1: jax.typing.ArrayLike = 0.99, b2: jax.typing.ArrayLike = 0.95, alpha: base.ScalarOrSchedule = 0.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0) → base.GradientTransformation
Scale updates according to the Simplified AdEMAMix optimizer.
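For intuition, one step of the scaling can be sketched in plain JAX on a single array. This is a minimal illustration, not optax's implementation. It assumes the rule as we read it from Morwani et al.: an undamped gradient accumulator (no `(1 - b1)` factor), the current gradient mixed in with weight `alpha`, and normalization by an Adam-style second moment. The Adam-style bias correction of the second moment and the helper name `simplified_ademamix_step` are assumptions made for this sketch.

```python
import jax.numpy as jnp


def simplified_ademamix_step(grad, m, v, count, *, b1=0.99, b2=0.95,
                             alpha=0.0, eps=1e-8, eps_root=0.0):
    """Hypothetical helper: one Simplified-AdEMAMix-style scaling step.

    Assumed rule (a sketch, not necessarily what optax computes):
      m_t   = b1 * m_{t-1} + g_t                  # undamped accumulation
      v_t   = b2 * v_{t-1} + (1 - b2) * g_t**2    # second-moment EMA
      v_hat = v_t / (1 - b2**t)                   # assumed bias correction
      u_t   = (alpha * g_t + m_t) / (sqrt(v_hat + eps_root) + eps)
    """
    count = count + 1
    # Accumulate gradients without the (1 - b1) damping used by Adam.
    m = b1 * m + grad
    # Exponential moving average of squared gradients.
    v = b2 * v + (1.0 - b2) * grad**2
    # Adam-style bias correction of the second moment (an assumption here).
    v_hat = v / (1.0 - b2**count)
    # eps_root sits inside the square root, eps outside it.
    update = (alpha * grad + m) / (jnp.sqrt(v_hat + eps_root) + eps)
    return update, m, v, count


# Usage: start from zero state; on the first step with alpha = 0 the
# update is approximately sign(grad).
g = jnp.array([2.0, -0.5])
zeros = jnp.zeros_like(g)
update, m, v, count = simplified_ademamix_step(g, zeros, zeros, 0)
```

Under this assumed rule, the accumulator can grow to roughly `1 / (1 - b1)` times a constant gradient (about 100 for the default `b1 = 0.99`), which is why a smaller learning rate than Adam's is typically paired with it.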
See optax.contrib.simplified_ademamix() for a full description.
References
Morwani et al., Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants, 2025
- Parameters:
b1 – Exponential decay rate used to track the EMA of past gradients.
b2 – Exponential decay rate used to track the second moment of past gradients.
alpha – Coefficient for mixing the current gradient with the EMA. May be a scalar or a schedule.
eps – A small constant added to the denominator outside the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root – A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling. This is needed, for instance, when computing (meta-)gradients through Adam.
- Returns:
The corresponding GradientTransformation.