optax.contrib.prodigy

optax.contrib.prodigy(learning_rate: base.ScalarOrSchedule = 1.0, betas: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike] = (0.9, 0.999), beta3: jax.typing.ArrayLike | None = None, eps: jax.typing.ArrayLike = 1e-08, estim_lr0: jax.typing.ArrayLike = 1e-06, estim_lr_coef: jax.typing.ArrayLike = 1.0, weight_decay: jax.typing.ArrayLike = 0.0, safeguard_warmup: bool = False) → base.GradientTransformationExtraArgs

Learning-rate-free AdamW with Prodigy.

Implementation of the Prodigy method from “Prodigy: An Expeditiously Adaptive Parameter-Free Learner”, a version of D-Adapt AdamW that adapts the baseline learning rate faster by weighting the gradients so that more recent gradients receive higher weights. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually the maximum) value.

Parameters:
  • learning_rate – Learning rate, given as a scalar or a schedule. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a learning rate warmup over 0-20% of the training steps.

  • betas – Coefficients (beta1, beta2) for the running averages of the gradient and its square in the underlying AdamW optimizer.

  • beta3 – Optional momentum parameter for the estimation of D, the distance-to-solution estimate that sets the adapted learning rate.

  • eps – Epsilon term for numerical stability in the underlying AdamW optimizer.

  • estim_lr0 – Initial (under-)estimate of the learning rate.

  • estim_lr_coef – Coefficient by which the learning rate estimates are multiplied.

  • weight_decay – AdamW-style (decoupled) weight decay. To use regular Adam-style decay instead, chain with add_decayed_weights.

  • safeguard_warmup – Remove the learning rate from the denominator of the D estimate to avoid issues during the warm-up stage. Off by default.

Returns:

An optax.GradientTransformationExtraArgs object.

References

Mishchenko and Defazio, Prodigy: An Expeditiously Adaptive Parameter-Free Learner, 2023.