optax.contrib.schedule_free_adamw

optax.contrib.schedule_free_adamw(learning_rate: jax.typing.ArrayLike = 0.0025, warmup_steps: int | None = None, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike = 0.0, weight_lr_power: jax.typing.ArrayLike = 2.0, state_dtype: jax.typing.DTypeLike | None = None) → base.GradientTransformationExtraArgs

Schedule-Free wrapper for AdamW.

Shortcut for using schedule_free with AdamW, which is a common use case. This is just one example; other use cases are possible, e.g., using a weight decay mask or Nesterov momentum. Note also that the EMA parameter of the schedule-free method (b1) must be strictly positive.

Parameters:
  • learning_rate – AdamW learning rate.

  • warmup_steps – Positive integer, the number of steps of linear learning-rate warmup. If None, no warmup is applied.

  • b1 – beta_1 parameter in the y update.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • weight_decay – Strength of the weight decay regularization.

  • weight_lr_power – Exponent used to weight iterates in the average: each step's weight is its learning rate raised to this power, so iterates taken with small learning rates contribute less. This is especially helpful in early iterations during warmup.

  • state_dtype – dtype of the z sequence stored in the schedule-free state.
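The interplay of warmup_steps and weight_lr_power can be sketched in plain Python. This assumes the coefficient that mixes each new z iterate into the average is c_t = γ_t^r / Σ_{i≤t} γ_i^r, where γ_t is the learning rate at step t and r is weight_lr_power, and that warmup ramps linearly from 0. The helpers linear_warmup_lr and averaging_coefficients are hypothetical, not optax functions.

```python
def linear_warmup_lr(step, peak_lr, warmup_steps):
  """Hypothetical linear warmup: 0 at step 0, peak_lr from warmup_steps on."""
  if warmup_steps is None or step >= warmup_steps:
    return peak_lr
  return peak_lr * step / warmup_steps


def averaging_coefficients(num_steps, peak_lr, warmup_steps, weight_lr_power):
  """Returns c_t = lr_t**r / sum_{i<=t} lr_i**r for t = 0, ..., num_steps - 1."""
  coeffs = []
  weight_sum = 0.0
  for step in range(num_steps):
    lr = linear_warmup_lr(step, peak_lr, warmup_steps)
    weight = lr ** weight_lr_power  # 0.0 ** 0.0 == 1.0 in Python
    weight_sum += weight
    # A zero learning rate at step 0 gives 0/0; treat that coefficient as 0.
    coeffs.append(weight / weight_sum if weight_sum > 0 else 0.0)
  return coeffs


# r = 0: every step has weight 1, so c_t = 1/(t+1), a uniform average.
uniform = averaging_coefficients(6, 0.0025, 4, 0.0)
# r = 2 with a 4-step warmup: the small early learning rates get little weight.
weighted = averaging_coefficients(6, 0.0025, 4, 2.0)
print([round(c, 3) for c in uniform])
print([round(c, 3) for c in weighted])
```

In this sketch, weight_lr_power = 0 recovers a plain uniform average, while r = 2 keeps the coefficients above 1/(t+1) after warmup, so the later, full-learning-rate iterates dominate the average. Because c_t is a ratio of learning-rate powers, it depends only on the shape of the schedule, not on its peak value.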

Returns:

An optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.contrib.schedule_free_adamw(1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  eval_params = optax.contrib.schedule_free_eval_params(
...      opt_state, params)
...  print('Objective function: {:.2E}'.format(f(eval_params)))
Objective function: 5.00E+00
Objective function: 3.05E+00
Objective function: 1.73E+00
Objective function: 8.94E-01
Objective function: 4.13E-01

Note

optax.scale_by_adam() with b1=0 stores in its state a first moment that is never used in later steps (with b1=0 it merely duplicates the current gradient). To avoid this waste of memory, we replace optax.scale_by_adam() with b1=0 by the equivalent optax.scale_by_rms() with eps_in_sqrt=False and bias_correction=True.