optax.contrib.prodigy

optax.contrib.prodigy(learning_rate: base.ScalarOrSchedule = 1.0, betas: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike] = (0.9, 0.999), beta3: jax.typing.ArrayLike | None = None, eps: jax.typing.ArrayLike = 1e-08, estim_lr0: jax.typing.ArrayLike = 1e-06, estim_lr_coef: jax.typing.ArrayLike = 1.0, weight_decay: jax.typing.ArrayLike = 0.0, safeguard_warmup: bool = False, weight_decay_mask: Any | Callable[[base.Params], Any] | None = None) → base.GradientTransformationExtraArgs

Learning-rate-free AdamW with Prodigy.

Implementation of the Prodigy method from “Prodigy: An Expeditiously Adaptive Parameter-Free Learner”, a version of D-Adapt AdamW that adapts the baseline learning rate faster by weighting the gradients so that more recent gradients receive higher weights. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually maximum) value.

Parameters:
  • learning_rate – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.

  • betas – Coefficients (beta1, beta2) for the running averages of the gradient and its square in the underlying AdamW optimizer.

  • beta3 – Optional momentum parameter for the estimation of D, Prodigy's learning-rate estimate.

  • eps – Epsilon term for numerical stability in the underlying AdamW optimizer.

  • estim_lr0 – Initial (under-)estimate of the learning rate.

  • estim_lr_coef – Learning-rate estimates are multiplied by this parameter.

  • weight_decay – AdamW-style (decoupled) weight decay. To use regular Adam decay instead, chain with add_decayed_weights.

  • safeguard_warmup – Remove lr from the denominator of the D estimate to avoid issues during the warm-up stage. Off by default.

  • weight_decay_mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

Returns:

An optax.GradientTransformationExtraArgs object (a subclass of optax.GradientTransformation).

References

Mishchenko et al., Prodigy: An Expeditiously Adaptive Parameter-Free Learner, 2023