optax.contrib.schedule_free

optax.contrib.schedule_free(base_optimizer: base.GradientTransformation, learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, weight_lr_power: jax.typing.ArrayLike = 2.0, state_dtype: jax.typing.DTypeLike | None = None) -> base.GradientTransformationExtraArgs

Wraps base_optimizer to make it schedule-free.

Accumulates the updates returned by base_optimizer, which should be run without momentum, and replaces the momentum of the underlying optimizer with a combination of interpolation and averaging. In the case of gradient descent, the update is

\[\begin{align*} y_{t} & = (1-\beta_1)z_{t} + \beta_1 x_{t},\\ z_{t+1} & =z_{t}-\gamma\nabla f(y_{t}),\\ x_{t+1} & =\left(1-\frac{1}{t}\right)x_{t}+\frac{1}{t}z_{t+1}, \end{align*}\]

Here \(x\) is the sequence at which the test/validation loss should be evaluated. It differs from the primary iterates \(z\) and from the gradient evaluation locations \(y\). The updates to \(z\) correspond to the underlying optimizer, in this case a simple gradient step. Note that \(\beta_1\) corresponds to b1 in the code.
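The recursion above can be sketched in plain Python for a scalar parameter. This is only an illustration of the three equations, not the optax implementation; the names schedule_free_gd and toy_grad and the quadratic objective are hypothetical.

```python
def schedule_free_gd(grad_fn, w0, lr=0.1, b1=0.9, num_steps=100):
    """Schedule-free gradient descent on a scalar, following the recursion above.

    Returns (x, z): the averaged iterate used for evaluation and the primary iterate.
    """
    z = float(w0)  # primary iterate z_t
    x = float(w0)  # averaged iterate x_t (evaluate test/val loss here)
    for t in range(1, num_steps + 1):
        y = (1.0 - b1) * z + b1 * x  # interpolation: where the gradient is taken
        z = z - lr * grad_fn(y)      # base optimizer step (plain gradient step)
        x = (1.0 - 1.0 / t) * x + (1.0 / t) * z  # uniform running average of z
    return x, z


def toy_grad(w):
    """Gradient of the hypothetical objective f(w) = 0.5 * (w - 3)**2."""
    return w - 3.0


# Two steps from w0 = 0 with the default lr and b1.
x_eval, z_last = schedule_free_gd(toy_grad, w0=0.0, num_steps=2)
```

With these settings the two steps give z = 0.57 and x = 0.435 (up to floating-point rounding): the average x trails the primary iterate z, and x is the value that would be used for evaluation.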

As the name suggests, Schedule-Free learning does not require a decreasing learning rate schedule, yet it typically outperforms, or at worst matches, state-of-the-art schedules such as cosine decay and linear decay. Only two sequences need to be stored at a time (the third can be computed from the other two on the fly), so this method has the same memory requirements as the base optimizer (parameter buffer + momentum).
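One way to see why two buffers suffice, assuming the stored pair is \(y\) and \(z\) (an assumption made here for illustration; the text above does not say which pair is kept): inverting the first equation of the update recovers the evaluation sequence, provided \(\beta_1 > 0\):

\[x_{t} = \frac{y_{t} - (1-\beta_1)\, z_{t}}{\beta_1}.\]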

In practice, the authors recommend tuning \(\beta_1\), warmup_steps and peak_lr separately for each problem. The default for \(\beta_1\) is 0.9, but 0.95 and 0.98 may also work well. Schedule-Free can be wrapped around any optax optimizer. At test time, the parameters should be evaluated using optax.contrib.schedule_free_eval_params() as presented below.

For example, change this:

learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=b1)

To:

learning_rate_fn = optax.warmup_constant_schedule(peak_value=retuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=0.)
optimizer = optax.contrib.schedule_free(optimizer, learning_rate_fn, b1=b1)
..
params_for_eval = optax.contrib.schedule_free_eval_params(state, params)

Note in particular that it is important to switch off the momentum of the base optimizer. As of April 2024, schedule_free is tested with SGD and Adam.

Parameters:
  • base_optimizer – Base optimizer to compute updates from.

  • learning_rate – Learning rate schedule without decay but with warmup.

  • b1 – beta_1 parameter in the y update.

  • weight_lr_power – Power used to down-weight the averaging. This is especially helpful in early iterations during warmup.

  • state_dtype – dtype for the z sequence in the schedule-free method.

Returns:

An optax.GradientTransformationExtraArgs.

References

Defazio et al., The Road Less Scheduled, 2024

Defazio et al., Schedule-Free Learning - A New Way to Train, 2024

Warning

The current implementation requires the parameter b1 to be strictly positive.