optax.contrib.dadapt_adamw
- optax.contrib.dadapt_adamw(learning_rate: base.ScalarOrSchedule = 1.0, betas: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike] = (0.9, 0.999), eps: jax.typing.ArrayLike = 1e-08, estim_lr0: jax.typing.ArrayLike = 1e-06, weight_decay: jax.typing.ArrayLike = 0.0) → base.GradientTransformationExtraArgs
Learning-rate-free AdamW by D-Adaptation.
Adapts the baseline learning rate of AdamW automatically by estimating the initial distance to the solution in the infinity norm. The learning_rate value acts as a multiplier on this adapted step size, so this method works best when combined with a learning rate schedule that treats 1.0 as the base (usually maximum) value.
- Parameters:
learning_rate – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a learning rate warmup over the first 0-20% of training steps.
betas – Betas for the underlying AdamW optimizer.
eps – Epsilon term for numerical stability in the underlying AdamW optimizer.
estim_lr0 – Initial (under-)estimate of the learning rate.
weight_decay – AdamW-style (decoupled) weight decay. To use regular Adam-style decay instead, chain with add_decayed_weights.
- Returns:
The corresponding optax.GradientTransformation.
References
Defazio et al., Learning-Rate-Free Learning by D-Adaptation, 2023