optax.contrib.simplified_ademamix


optax.contrib.simplified_ademamix(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.99, b2: jax.typing.ArrayLike = 0.95, alpha: base.ScalarOrSchedule = 0.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) -> base.GradientTransformation

Simplified AdEMAMix.

Simplified AdEMAMix (Adaptive EMA Mixture) is a simplified version of AdEMAMix that eliminates the need for maintaining two separate momentum buffers and removes the requirement for scheduling the mixing parameter \(\alpha\). Setting \(\alpha = 0\) recovers the standard Adam optimizer, subject to appropriate transformations of \(\eta\) and \(\beta_1\).
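The learning-rate part of the α = 0 claim can be checked numerically. Simplified AdEMAMix accumulates its momentum as m ← β1·m + g, without Adam's (1 − β1) factor, so its momentum equals Adam's first moment divided by (1 − β1). The minimal plain-JAX sketch below runs both recursions on an assumed random gradient sequence and compares Simplified AdEMAMix with α = 0 and learning rate η(1 − β1) against bias-corrected Adam with learning rate η. The rescaling η → η(1 − β1) is our reading of the "appropriate transformation", not a statement from the reference; the β1 transformation is not illustrated, and because Adam's first-moment bias correction 1/(1 − β1^t) has no counterpart here, the two only agree once β1^t is negligible.

```python
import jax
import jax.numpy as jnp

b1, b2, eps = 0.9, 0.999, 1e-8
eta_adam = 1e-3
eta_simplified = eta_adam * (1 - b1)  # assumed learning-rate rescaling

# Assumed gradient sequence: 200 steps of a 4-dimensional parameter.
grads = jax.random.normal(jax.random.PRNGKey(0), (200, 4))

m_s = jnp.zeros(4)  # Simplified AdEMAMix momentum (no (1 - b1) factor)
mu = jnp.zeros(4)   # Adam first moment
nu = jnp.zeros(4)   # second-moment EMA, identical in both optimizers
for t, g in enumerate(grads, start=1):
    m_s = b1 * m_s + g
    mu = b1 * mu + (1 - b1) * g
    nu = b2 * nu + (1 - b2) * g**2
    denom = jnp.sqrt(nu / (1 - b2**t)) + eps
    step_simplified = eta_simplified * m_s / denom       # alpha = 0
    step_adam = eta_adam * (mu / (1 - b1**t)) / denom    # bias-corrected Adam

# Largest per-coordinate difference between the two final steps.
print(jnp.max(jnp.abs(step_simplified - step_adam)))
```

After 200 steps β1^t = 0.9^200 is far below float32 resolution, so the remaining difference is rounding noise.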

Let \(\eta\) represent the learning rate and \(\beta_1, \beta_2, \alpha, \varepsilon, \bar{\varepsilon}\) represent the arguments b1, b2, alpha, eps and eps_root, respectively. Let \(\lambda\) be the weight decay and \(\theta^{(t)}\) the parameter vector at time \(t\).

The init function of this optimizer initializes an internal state \(S^{(0)} := (m_1^{(0)}, \nu^{(0)}) = (0, 0)\), holding the initial values of the momentum \(m_1\) and of the second-moment estimate \(\nu\). In practice, these values are stored as pytrees of zeros with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g^{(t)}\), the optimizer state \(S^{(t-1)}\) and the parameters \(\theta^{(t-1)}\). It then computes the updates that produce \(\theta^{(t)}\), together with the new state \(S^{(t)}\). Thus, for \(t > 0\), we have,

\[\begin{align*}
m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + g^{(t)} \\
\nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot {g^{(t)}}^2 \\
\hat{\nu}^{(t)} &\leftarrow \nu^{(t)} / (1-\beta_2^{t}) \\
\theta^{(t)} &\leftarrow \theta^{(t-1)} - \eta \cdot \left( \frac{m_1^{(t)} + \alpha g^{(t)}}{\sqrt{\hat{\nu}^{(t)} + \bar{\varepsilon}} + \varepsilon} + \lambda \theta^{(t-1)} \right) \\
S^{(t)} &\leftarrow (m_1^{(t)}, \nu^{(t)}).
\end{align*}\]
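As a cross-check of the update rule, here is a minimal plain-JAX sketch that applies it to a single array of parameters. It is written from the equations, not from Optax's source: the function name simplified_ademamix_step is hypothetical, the explicit 1-based step counter t is an assumption needed for the \(1-\beta_2^{t}\) correction, and the state is kept as two loose arrays rather than a pytree. With the default b1 and b2 and a learning rate of 0.01 on the quadratic from the Examples section, it should print a decreasing sequence of objective values of the same magnitude as the doctest; we have not compared it digit-for-digit with Optax.

```python
import jax
import jax.numpy as jnp


def simplified_ademamix_step(theta, m1, nu, t, g, *, eta, b1=0.99, b2=0.95,
                             alpha=0.0, eps=1e-8, eps_root=0.0, weight_decay=0.0):
    """Sketch of one step of the update rule; `t` is the 1-based step index."""
    m1 = b1 * m1 + g                        # momentum: no (1 - b1) factor on g
    nu = b2 * nu + (1 - b2) * g**2          # EMA of squared gradients
    nu_hat = nu / (1 - b2**t)               # bias correction is applied to nu only
    direction = (m1 + alpha * g) / (jnp.sqrt(nu_hat + eps_root) + eps)
    # `theta` still holds theta^(t-1) here, so the decay term uses the old params.
    theta = theta - eta * (direction + weight_decay * theta)
    return theta, m1, nu


def f(x):
    return jnp.sum(jnp.square(x))  # same quadratic as in the Examples section


theta = jnp.array([1.0, 2.0, 3.0])
m1 = jnp.zeros_like(theta)  # S^(0) = (0, 0): zeros shaped like the params
nu = jnp.zeros_like(theta)
history = []
for t in range(1, 6):
    g = jax.grad(f)(theta)
    theta, m1, nu = simplified_ademamix_step(theta, m1, nu, t, g, eta=0.01)
    history.append(float(f(theta)))
    print("Objective function: {:.2E}".format(history[-1]))
```

On the first step \(\hat{\nu}^{(1)} = {g^{(1)}}^2\), so each coordinate moves by almost exactly \(\eta\); later steps grow because \(m_1\) keeps summing gradients without normalization.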

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(jnp.square(x))  # simple quadratic function
>>> solver = optax.contrib.simplified_ademamix(learning_rate=0.01)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.36E+01
Objective function: 1.33E+01
Objective function: 1.28E+01
Objective function: 1.23E+01

References

Morwani et al., Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants, 2025

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler; see optax.scale_by_learning_rate().

  • b1 – Exponential decay rate of the momentum \(m_1\). Unlike in Adam, the incoming gradient is added without a \((1-\beta_1)\) factor.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • alpha – Mixing coefficient: the weight \(\alpha\) of the current gradient added to the momentum \(m_1\) in the numerator of the update.

  • eps – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant applied to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling. This is needed, for instance, when computing (meta-)gradients through this optimizer.

  • weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied by the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al., 2019), where the weight decay is only multiplied by the “schedule multiplier”, not by the base learning rate.

  • mask – A tree with the same structure as (or a prefix of) the params pytree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Simplified AdEMAMix scaling is applied to all parameters; the mask only affects the weight decay.
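The interaction of weight_decay and mask can be sketched in plain JAX. Following the parameter descriptions, the sketch assumes that the adaptive direction is applied to every leaf, while the decay term \(\lambda \theta\) is added only where the mask is True and is then scaled by the learning rate together with the direction. The helpers decoupled_weight_decay and no_decay_on_vectors are hypothetical names, and this sketch only handles a mask with the same structure as the params, not a prefix. With Optax itself, a callable like no_decay_on_vectors would be passed as the mask argument of optax.contrib.simplified_ademamix.

```python
import jax
import jax.numpy as jnp


def decoupled_weight_decay(direction, params, mask, weight_decay, learning_rate):
    """Hypothetical helper: -lr * (direction + wd * params), decaying only masked leaves."""
    def leaf_update(d, p, decay):
        # `decay` is a Python bool taken from the mask pytree.
        return -learning_rate * (d + weight_decay * p) if decay else -learning_rate * d
    return jax.tree_util.tree_map(leaf_update, direction, params, mask)


def no_decay_on_vectors(params):
    """Hypothetical mask callable: decay matrices, skip 1-D leaves such as biases."""
    return jax.tree_util.tree_map(lambda p: p.ndim > 1, params)


params = {"w": jnp.ones((2, 2)), "b": jnp.ones((2,))}
# A zero adaptive direction isolates the contribution of the decay term.
direction = jax.tree_util.tree_map(jnp.zeros_like, params)
updates = decoupled_weight_decay(
    direction, params, no_decay_on_vectors(params), weight_decay=0.1, learning_rate=0.01
)
```

Here updates["w"] is \(-0.01 \cdot 0.1 \cdot 1 = -0.001\) everywhere, while the bias leaf "b" receives no decay at all.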

Returns:

The corresponding GradientTransformation.

See also

See the related functions optax.adam(), optax.nadamw(), as well as the example Recreate AdeMAMix Rosenbrock Plot from Paper.