optax.contrib.dadapt_adamw

optax.contrib.dadapt_adamw(learning_rate: base.ScalarOrSchedule = 1.0, betas: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike] = (0.9, 0.999), eps: jax.typing.ArrayLike = 1e-08, estim_lr0: jax.typing.ArrayLike = 1e-06, weight_decay: jax.typing.ArrayLike = 0.0) → base.GradientTransformationExtraArgs

Learning-rate-free AdamW via D-Adaptation.

Adapts the baseline learning rate of AdamW automatically by estimating the initial distance to the solution in the infinity norm. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually maximum) value.
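To make the distance estimate more concrete, here is a simplified NumPy sketch of one D-Adaptation Adam step on a single flat parameter vector. It is for illustration only: it is modeled on the Adam variant described in the D-Adaptation paper, omits weight decay, and is not the optax source. The names `init_state` and `dadapt_adam_step` and the toy quadratic objective are hypothetical. The idea it shows is that `d` (the running learning-rate estimate, initialised to `estim_lr0`) is replaced by a new distance estimate whenever that estimate is larger, so `d` never decreases.

```python
import numpy as np


def init_state(x, estim_lr0=1e-6):
    """Hypothetical sketch state: (m, v, s, d, r, k).

    m, v: Adam first/second moments; s: weighted gradient sum;
    d: learning-rate (distance) estimate; r: weighted numerator; k: step count.
    """
    zeros = np.zeros_like(x)
    return (zeros, zeros.copy(), zeros.copy(), float(estim_lr0), 0.0, 0)


def dadapt_adam_step(x, g, state, sched=1.0, beta1=0.9, beta2=0.999, eps=1e-8):
    """One simplified D-Adaptation Adam step (no weight decay)."""
    m, v, s, d, r, k = state
    sb2 = np.sqrt(beta2)
    # Adam bias correction folded into the effective step size dlr.
    bc = np.sqrt(1.0 - beta2 ** (k + 1)) / (1.0 - beta1 ** (k + 1))
    dlr = d * sched * bc
    # <g, s / (sqrt(v) + eps)> uses s and v from *before* this step.
    numerator_acum = float(np.dot(g, s / (np.sqrt(v) + eps)))
    m = beta1 * m + (1.0 - beta1) * dlr * g
    v = beta2 * v + (1.0 - beta2) * g * g
    s = sb2 * s + (1.0 - sb2) * dlr * g
    r = sb2 * r + (1.0 - sb2) * dlr * numerator_acum
    # Distance estimate: weighted numerator over the L1 norm of s
    # (L1 is the dual of the infinity norm). max() keeps d non-decreasing.
    d_hat = r / ((1.0 - sb2) * np.abs(s).sum())
    d = max(d, float(d_hat))
    x = x - m / (np.sqrt(v) + eps)
    return x, (m, v, s, d, r, k + 1)


# Toy problem: f(x) = 0.5 * ||x - target||^2, so grad f(x) = x - target.
target = np.array([1.0, -2.0, 3.0])
x = np.zeros(3)
state = init_state(x)
d_history = []
for _ in range(200):
    x, state = dadapt_adam_step(x, x - target, state)
    d_history.append(state[3])
print(f"estimated d after 200 steps: {d_history[-1]:.3g}")
```

Because `estim_lr0` is small, the first steps are tiny; `d` then ramps up as the weighted gradient sums accumulate evidence about how far the solution is.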

Parameters:
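  • learning_rate – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0–20% learning rate warmup.

  • betas – Exponential decay rates (beta1, beta2) for the first- and second-moment estimates of the underlying AdamW optimizer.

  • eps – Epsilon for the underlying AdamW optimizer: a small constant added to the denominator for numerical stability.

  • estim_lr0 – Initial (under-)estimate of the learning rate, from which the D-Adaptation estimate is grown.

  • weight_decay – AdamW-style (decoupled) weight decay. To use regular Adam-style (L2) decay instead, chain add_decayed_weights before this transformation and leave weight_decay at 0.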

Returns:

The corresponding optax.GradientTransformation.

References

Defazio et al., Learning-Rate-Free Learning by D-Adaptation, 2023.