optax.contrib.scale_by_madgrad

optax.contrib.scale_by_madgrad(learning_rate: base.ScalarOrSchedule, momentum: float = 0.9, eps: float = 1e-06) -> base.GradientTransformation

Rescale updates according to the MADGRAD algorithm.

MADGRAD is a dual averaging method that maintains weighted sums of the gradients and of the squared gradients, and uses them to compute adaptive updates. It aims to combine the generalization performance of SGD with the convergence speed of adaptive methods such as Adam.
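To make the two accumulators concrete, here is a minimal sketch of a single MADGRAD step in plain JAX, following the update rule described in the referenced paper. It is not the optax implementation: the names `madgrad_init`, `madgrad_step`, and `loss_fn` are hypothetical, the state layout is an assumption chosen for readability, and details such as where `eps` enters may differ in the library.

```python
import jax
import jax.numpy as jnp


def madgrad_init(params):
    """Build the (hypothetical) optimizer state for a single parameter array."""
    return {
        "count": 0,                             # number of steps taken so far
        "grad_sum": jnp.zeros_like(params),     # weighted sum of gradients
        "grad_sq_sum": jnp.zeros_like(params),  # weighted sum of squared gradients
        "x0": params,                           # dual averaging is anchored at the initial point
    }


def madgrad_step(params, grads, state, learning_rate=0.1, momentum=0.9, eps=1e-6):
    """Apply one MADGRAD step and return (new_params, new_state)."""
    k = state["count"]
    # The per-step weight grows with the square root of the step index.
    lam = learning_rate * jnp.sqrt(k + 1.0)
    grad_sum = state["grad_sum"] + lam * grads
    grad_sq_sum = state["grad_sq_sum"] + lam * grads**2
    # Dual-averaged iterate: the denominator is a cube root, not a square root.
    z = state["x0"] - grad_sum / (jnp.cbrt(grad_sq_sum) + eps)
    # Momentum as a moving average between the current parameters and z.
    c = 1.0 - momentum
    new_params = (1.0 - c) * params + c * z
    new_state = {
        "count": k + 1,
        "grad_sum": grad_sum,
        "grad_sq_sum": grad_sq_sum,
        "x0": state["x0"],
    }
    return new_params, new_state


def loss_fn(x):
    """Toy quadratic objective used only for illustration."""
    return 0.5 * jnp.sum(x**2)


params = jnp.array([1.0, -2.0])
state = madgrad_init(params)
for _ in range(50):
    grads = jax.grad(loss_fn)(params)
    params, state = madgrad_step(params, grads, state)
print("loss after 50 steps:", loss_fn(params))
```

Note that the sketch returns new parameters directly, whereas a GradientTransformation returns an update to be applied to the parameters; the arithmetic on the two accumulators is the part being illustrated here.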

Parameters:
  • learning_rate – A global scaling factor, either fixed or varying over iterations according to a schedule.

  • momentum – Momentum parameter (default: 0.9).

  • eps – Term added to the denominator to improve numerical stability.

Returns:

An optax.GradientTransformation object.

References

Defazio and Jelassi, Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization, 2021.