optax.lookahead
- optax.lookahead(fast_optimizer: optax.GradientTransformation, sync_period: jax.typing.ArrayLike, slow_step_size: jax.typing.ArrayLike, reset_state: bool = False) -> optax.GradientTransformation
Lookahead optimizer.
Performs steps with a fast optimizer and periodically updates a set of slow parameters. Optionally resets the fast optimizer state after synchronization by calling the init function of the fast optimizer.
Updates returned by the lookahead optimizer should not be modified before they are applied; otherwise the fast and slow parameters will not be synchronized correctly.
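The interplay between fast steps and periodic synchronization can be illustrated with a minimal, dependency-free sketch for a single scalar parameter. It assumes plain SGD as the fast optimizer and follows the rule described in Zhang et al. (2019); the function `lookahead_scalar` and its argument names are hypothetical and are not part of Optax, and this is not how Optax implements the transformation internally.

```python
def lookahead_scalar(grad_fn, theta0, fast_lr, sync_period, slow_step_size, num_steps):
    """Illustrative lookahead loop for one scalar parameter (hypothetical helper)."""
    fast = slow = theta0
    history = []
    for step in range(1, num_steps + 1):
        # Fast (inner-loop) step: the gradient is evaluated at the fast parameter.
        fast = fast - fast_lr * grad_fn(fast)
        if step % sync_period == 0:
            # Synchronization: move the slow parameter part of the way toward
            # the fast one, then restart the fast parameter from the slow one.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow
        history.append((fast, slow))
    return history


# Minimize p**2 (gradient 2 * p), starting from 1.0, for one full sync period.
history = lookahead_scalar(
    grad_fn=lambda p: 2.0 * p,
    theta0=1.0,
    fast_lr=1e-2,
    sync_period=5,
    slow_step_size=0.5,
    num_steps=5,
)
```

In this sketch the slow parameter stays at its initial value for the first four steps and only moves on the fifth, at which point the fast and slow values coincide.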
- Parameters:
fast_optimizer – The optimizer to use in the inner loop of lookahead.
sync_period – Number of fast optimizer steps to take before synchronizing parameters. Must be >= 1.
slow_step_size – Step size of the slow parameter updates.
reset_state – Whether to reset the optimizer state of the fast optimizer after each synchronization.
- Returns:
An optax.GradientTransformation with init and update functions. The updates passed to the update function should be calculated using the fast lookahead parameters only.
Example
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> fast_opt = optax.sgd(1e-2)
>>> opt = optax.lookahead(fast_opt, sync_period=5, slow_step_size=0.5)
>>> params = optax.LookaheadParams.init_synced(jnp.ones((2,)))
>>> state = opt.init(params)
>>> loss_fn = lambda p: jnp.sum(p**2)
>>> # Calculate gradients wrt the fast parameters
>>> grads = jax.grad(loss_fn)(params.fast)
>>> updates, state = opt.update(grads, state, params)
>>> params = optax.apply_updates(params, updates)
>>> # Calculate the eval loss wrt the slow parameters
>>> loss_fn(params.slow)
Array(2., dtype=float32)
References
Zhang et al., Lookahead Optimizer: k steps forward, 1 step back, 2019