optax.contrib.scale_by_simplified_ademamix


optax.contrib.scale_by_simplified_ademamix(b1: jax.typing.ArrayLike = 0.99, b2: jax.typing.ArrayLike = 0.95, alpha: base.ScalarOrSchedule = 0.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0) → base.GradientTransformation

Scale updates according to the Simplified AdEMAMix optimizer.

See optax.contrib.simplified_ademamix() for a full description.
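For orientation, the per-step scaling can be written as below. This is our reading of the Simplified AdEMAMix rule from Morwani et al. (2025), not a transcription of the optax source. The symbols are our own: g_t is the incoming gradient, m_t the momentum accumulator, ν_t the second-moment EMA, ε̄ stands for eps_root, and u_t is the scaled update before any learning rate is applied. Two details are assumptions: the undamped momentum (no (1 − β₁) factor) and bias correction applied only to the second moment.

```latex
\begin{aligned}
m_t       &= \beta_1 \, m_{t-1} + g_t \\
\nu_t     &= \beta_2 \, \nu_{t-1} + (1 - \beta_2) \, g_t^2 \\
\hat{\nu}_t &= \frac{\nu_t}{1 - \beta_2^{\,t}} \\
u_t       &= \frac{\alpha \, g_t + m_t}{\sqrt{\hat{\nu}_t + \bar{\varepsilon}} + \varepsilon}
\end{aligned}
```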

References

Morwani et al., Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants, 2025.

Parameters:
  • b1 – Exponential decay rate used to track the EMA of past gradients.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • alpha – Mixing coefficient that weights the current gradient when it is combined with the EMA. May be a scalar or a schedule.

  • eps – A small constant added to the denominator outside the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling. This is needed, for instance, when computing (meta-)gradients through Adam.
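The plain-Python sketch below applies one step of the rule to a single scalar parameter, so you can see where each argument enters. It rests on several assumptions:

- simplified_ademamix_step is a hypothetical helper, not an optax function.
- The momentum is undamped, m ← b1·m + g.
- Bias correction is applied only to the second moment.
- alpha is treated as a plain float rather than a schedule.

```python
import math


def simplified_ademamix_step(g, m, nu, count, b1=0.99, b2=0.95, alpha=0.0,
                             eps=1e-8, eps_root=0.0):
    """Hypothetical single-scalar sketch of one Simplified AdEMAMix scaling step.

    Returns (update, new_m, new_nu, new_count). The update is the scaled
    direction only; no learning rate is applied here.
    """
    # Momentum accumulator: assumed undamped (no (1 - b1) factor on g).
    m = b1 * m + g
    # Exponential moving average of squared gradients, as in Adam.
    nu = b2 * nu + (1.0 - b2) * g * g
    count += 1
    # Assumed: bias correction on the second moment only.
    nu_hat = nu / (1.0 - b2 ** count)
    # alpha weights the raw gradient added to the momentum; eps_root sits
    # inside the square root and eps outside it.
    update = (alpha * g + m) / (math.sqrt(nu_hat + eps_root) + eps)
    return update, m, nu, count
```

Under these assumptions, a constant gradient drives the update magnitude toward 1 / (1 − b1). For the default b1 = 0.99 that is about 100, whereas Adam produces roughly unit-sized steps. The learning rate applied afterwards would therefore need to be correspondingly smaller.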

Returns:

The corresponding GradientTransformation.