optax.contrib.scale_by_ademamix


optax.contrib.scale_by_ademamix(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, b3: base.ScalarOrSchedule = 0.9999, alpha: base.ScalarOrSchedule = 6.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: jax.typing.DTypeLike | None = None) → base.GradientTransformation

Scale updates according to the AdEMAMix algorithm.

See optax.contrib.ademamix() for a full description of the algorithm.
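For orientation, the per-array arithmetic behind this transformation can be sketched in jax.numpy. This is a minimal sketch, not the optax source: it assumes the update rule from Pagliardini et al. (2024), with bias correction applied to the fast EMA and the second moment but not to the slow EMA. The learning rate and weight decay are omitted because this transformation only rescales gradients. The helper name `ademamix_scale_step` is hypothetical.

```python
import jax.numpy as jnp


def ademamix_scale_step(g, m1, m2, nu, t, b1=0.9, b2=0.999, b3=0.9999,
                        alpha=6.0, eps=1e-8, eps_root=0.0):
  """Hypothetical helper: one AdEMAMix rescaling step for a single array.

  g is the current gradient; m1, m2 and nu are the previous fast EMA, slow EMA
  and second-moment EMA; t is the 1-based step index used for bias correction.
  """
  m1 = b1 * m1 + (1.0 - b1) * g        # fast EMA of gradients
  m2 = b3 * m2 + (1.0 - b3) * g        # slow EMA of gradients (not bias-corrected)
  nu = b2 * nu + (1.0 - b2) * g**2     # EMA of squared gradients
  m1_hat = m1 / (1.0 - b1**t)          # bias-corrected fast EMA
  nu_hat = nu / (1.0 - b2**t)          # bias-corrected second moment
  # eps_root is added inside the square root, eps outside it.
  update = (m1_hat + alpha * m2) / (jnp.sqrt(nu_hat + eps_root) + eps)
  return update, m1, m2, nu


# First step from zero-initialized accumulators.
g = jnp.array([0.5, -1.0, 2.0])
zeros = jnp.zeros_like(g)
update, m1, m2, nu = ademamix_scale_step(g, zeros, zeros, zeros, t=1)
```

At the first step the bias-corrected fast EMA equals g and the square root of the bias-corrected second moment equals |g|, while the slow EMA is only (1 - b3) * g. Each coordinate of `update` is therefore close to sign(g) * (1 + alpha * (1 - b3)), which is about 1.0006 with the defaults. The slow EMA contributes little until it has accumulated many steps.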

References

Pagliardini et al., The AdEMAMix Optimizer: Better, Faster, Older, 2024

Parameters:
  • b1 – Exponential decay rate to track the fast EMA.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • b3 – Exponential decay rate to track the slow EMA. Can be a scalar or a schedule.

  • alpha – Mixing coefficient that scales the slow EMA in the linear combination of the fast and slow EMAs. Can be a scalar or a schedule.

  • eps – A small constant added to the denominator outside the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling. This is needed, for instance, when computing (meta-)gradients through Adam-style updates such as this one.

  • mu_dtype – Optional dtype to be used for the first-order accumulator; if None, the dtype is inferred from params and updates.
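Because b3 and alpha are typed base.ScalarOrSchedule, either can be a callable of the step count instead of a constant. The sketch below builds two such callables. The warmup shapes are an assumption, loosely modeled on the schedulers discussed in the reference above: alpha ramps linearly to its final value, and b3 rises so that the slow EMA's half-life grows linearly. The helpers `alpha_warmup` and `b3_warmup` are hypothetical and are not part of optax.

```python
import jax.numpy as jnp


def alpha_warmup(alpha_final, warmup_steps):
  """Hypothetical helper: ramp alpha linearly from 0 to alpha_final."""

  def schedule(count):
    # Fraction of the warmup completed, clipped at 1 once warmup is over.
    frac = jnp.minimum(count / warmup_steps, 1.0)
    return alpha_final * frac

  return schedule


def b3_warmup(b3_start, b3_final, warmup_steps):
  """Hypothetical helper: raise b3 so the slow EMA's half-life grows linearly.

  An EMA with decay rate b has half-life ln(0.5) / ln(b), so interpolating
  1 / ln(b) linearly in the step count interpolates the half-life linearly.
  """
  log_start = jnp.log(b3_start)
  log_final = jnp.log(b3_final)

  def schedule(count):
    frac = jnp.minimum(count / warmup_steps, 1.0)
    inv_log = (1.0 - frac) / log_start + frac / log_final
    return jnp.exp(1.0 / inv_log)

  return schedule


alpha_schedule = alpha_warmup(8.0, 10_000)
b3_schedule = b3_warmup(0.9, 0.9999, 10_000)
```

These callables could then be passed as, for example, optax.contrib.scale_by_ademamix(alpha=alpha_schedule, b3=b3_schedule). They use only jax.numpy operations, so they stay traceable if the update step is jit-compiled. The values 8.0, 0.9, 0.9999 and 10_000 are placeholders, not tuned recommendations.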

Returns:

The corresponding GradientTransformation.