optax.contrib.scale_by_ademamix
- optax.contrib.scale_by_ademamix(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, b3: base.ScalarOrSchedule = 0.9999, alpha: base.ScalarOrSchedule = 6.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: jax.typing.DTypeLike | None = None) → base.GradientTransformation
Scale updates according to the AdEMAMix algorithm.
See optax.contrib.ademamix() for a full description of the algorithm.
References
Pagliardini et al., The AdEMAMix Optimizer: Better, Faster, Older, 2024
- Parameters:
b1 – Exponential decay rate to track the fast EMA.
b2 – Exponential decay rate to track the second moment of past gradients.
b3 – Exponential decay rate to track the slow EMA.
alpha – Mixing coefficient in the linear combination for the fast and slow EMAs.
eps – A small constant added to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root – A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling. This is needed, for instance, when computing (meta-)gradients through Adam.
mu_dtype – Optional dtype to be used for the first-order accumulator; if None, the dtype is inferred from params and updates.
- Returns:
The corresponding GradientTransformation.
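To make the role of each parameter concrete, the following is a minimal scalar sketch of one AdEMAMix scaling step as described in the Pagliardini et al. paper. It is an illustration under stated assumptions, not the Optax source: the function name `ademamix_scale_step` and the tuple-based state are hypothetical, `b3` and `alpha` are treated as fixed scalars rather than schedules, and `mu_dtype` is ignored.

```python
import math


def ademamix_scale_step(grad, state, b1=0.9, b2=0.999, b3=0.9999,
                        alpha=6.0, eps=1e-8, eps_root=0.0):
    """One scalar AdEMAMix scaling step (illustrative sketch only).

    `state` is a hypothetical tuple (count, m1, m2, nu): the step counter,
    the fast EMA, the slow EMA, and the second-moment estimate.
    """
    count, m1, m2, nu = state
    count += 1

    # Fast EMA (b1), slow EMA (b3), and second moment (b2) of the gradient.
    m1 = b1 * m1 + (1.0 - b1) * grad
    m2 = b3 * m2 + (1.0 - b3) * grad
    nu = b2 * nu + (1.0 - b2) * grad * grad

    # Bias-correct the fast EMA and the second moment, as in Adam.
    # The slow EMA is used without bias correction.
    m1_hat = m1 / (1.0 - b1 ** count)
    nu_hat = nu / (1.0 - b2 ** count)

    # alpha mixes the slow EMA into the numerator; eps_root sits inside
    # the square root and eps outside it.
    update = (m1_hat + alpha * m2) / (math.sqrt(nu_hat + eps_root) + eps)
    return update, (count, m1, m2, nu)


# First step from a zero-initialised state with a gradient of 2.0.
state = (0, 0.0, 0.0, 0.0)
update, state = ademamix_scale_step(2.0, state)
```

On the first step the fast term behaves like Adam (magnitude close to 1), and the slow EMA contributes only about `alpha * (1 - b3)` on top of that; its influence grows as it accumulates gradients over many steps. Like other `scale_by_*` transformations, this one only rescales the updates, so in practice it would typically be combined with a learning-rate scaling, for example through `optax.chain`.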