optax.contrib.madgrad

optax.contrib.madgrad(learning_rate: base.ScalarOrSchedule, momentum: float = 0.9, weight_decay: float = 0.0, eps: float = 1e-06) → base.GradientTransformation

The MADGRAD optimizer.

MADGRAD is a general-purpose optimizer that, according to its authors, matches the performance of SGD with momentum on vision tasks and of Adam on NLP tasks.
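As a reading aid, the per-coordinate update of Algorithm 1 in Defazio et al. is summarized below, as we understand it from the paper (not from this optax implementation). Here γ_k is the learning rate at step k, ε is `eps`, x_0 is the initial parameter vector, g_k is the stochastic gradient at x_k, ∘ is the elementwise product, and s_0 = ν_0 = 0. Writing the averaging weight as c = 1 - momentum is an assumption taken from the authors' PyTorch reference code.

```latex
\begin{aligned}
\lambda_k &= \gamma_k \sqrt{k+1} \\
s_{k+1} &= s_k + \lambda_k \, g_k \\
\nu_{k+1} &= \nu_k + \lambda_k \, (g_k \circ g_k) \\
z_{k+1} &= x_0 - \frac{s_{k+1}}{\nu_{k+1}^{1/3} + \epsilon} \\
x_{k+1} &= (1 - c)\, x_k + c\, z_{k+1}, \qquad c = 1 - \mathrm{momentum}
\end{aligned}
```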

Parameters:
  • learning_rate – A global scaling factor, either fixed or varying across iterations according to a schedule.

  • momentum – Momentum parameter (default: 0.9).

  • weight_decay – Strength of the L2 weight decay regularization (default: 0.0).

  • eps – Term added to the denominator to improve numerical stability (default: 1e-06).
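To make the role of each parameter concrete, here is a minimal NumPy sketch of the MADGRAD update for a single parameter array. It is not optax's implementation: `madgrad_init` and `madgrad_step` are hypothetical helpers, the step returns new parameters rather than an optax update, and two details are assumptions borrowed from the authors' PyTorch reference code: `weight_decay` is applied as an L2 term added to the gradient, and `momentum` maps to the averaging weight c = 1 - momentum.

```python
import numpy as np


def madgrad_init(params):
    """Hypothetical helper: builds MADGRAD state for one NumPy parameter array."""
    return {
        "k": 0,                        # number of steps taken so far
        "x0": params.copy(),           # initial point that dual averaging is anchored to
        "s": np.zeros_like(params),    # lambda-weighted sum of gradients
        "nu": np.zeros_like(params),   # lambda-weighted sum of squared gradients
    }


def madgrad_step(params, grads, state, learning_rate=1e-2, momentum=0.9,
                 weight_decay=0.0, eps=1e-6):
    """Hypothetical helper: one MADGRAD step; returns (new_params, new_state)."""
    # Assumption: L2 weight decay is folded into the gradient, as in the
    # authors' PyTorch reference code.
    g = grads + weight_decay * params
    k = state["k"]
    lam = learning_rate * np.sqrt(k + 1.0)       # lambda_k = gamma * sqrt(k + 1)
    s = state["s"] + lam * g                     # s_{k+1}
    nu = state["nu"] + lam * g * g               # nu_{k+1}
    z = state["x0"] - s / (np.cbrt(nu) + eps)    # dual-averaged point z_{k+1}
    # Assumption: `momentum` maps to the averaging weight c = 1 - momentum.
    c = 1.0 - momentum
    new_params = (1.0 - c) * params + c * z      # x_{k+1}
    new_state = {"k": k + 1, "x0": state["x0"], "s": s, "nu": nu}
    return new_params, new_state


# Usage: minimize f(x) = 0.5 * ||x - target||^2, whose gradient is x - target.
target = np.array([3.0, -2.0])
x = np.zeros(2)
state = madgrad_init(x)
for _ in range(500):
    x, state = madgrad_step(x, x - target, state, learning_rate=0.1)
print(x)
```

The cube root of the accumulated squared gradients, rather than the square root used by Adam-style methods, is the distinctive choice in MADGRAD. With `momentum=0.0` the sketch reduces to plain dual averaging, where x_{k+1} = z_{k+1}.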

Returns:

The corresponding optax.GradientTransformation.

References

Defazio et al., Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization, 2021.