optax.contrib.momo
- optax.contrib.momo(learning_rate: base.ScalarOrSchedule = 1.0, beta: jax.typing.ArrayLike = 0.9, lower_bound: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, adapt_lower_bound: bool = False) → base.GradientTransformationExtraArgs
Adaptive Learning Rates for SGD with momentum.
MoMo typically needs less tuning of learning_rate because it exploits the fact that a lower bound of the loss (or the optimal value) is known. For most tasks, zero is a lower bound and an accurate estimate of the final loss.
MoMo performs SGD with momentum with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.
Note that one needs to pass the latest (batch) loss value to the update function using the keyword argument value.
- Parameters:
learning_rate – User-specified learning rate. A rather large value is recommended; defaults to 1.0.
beta – Momentum coefficient (for EMA).
lower_bound – Lower bound of the loss. Zero should be a good choice for many tasks.
weight_decay – Weight-decay parameter.
adapt_lower_bound – If no good guess for the lower bound is available, set this to True to estimate the lower bound on the fly (see the paper for details).
- Returns:
An optax.GradientTransformation object.
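To make the min(learning_rate, <adaptive term>) rule more concrete, here is a minimal sketch of a capped classical Polyak step for plain SGD. It is a simplification and not MoMo's actual update, which additionally averages losses and gradients with momentum; the helper name capped_polyak_step is hypothetical and not part of Optax.

```python
import jax
import jax.numpy as jnp


def capped_polyak_step(loss_fn, params, learning_rate=1.0, lower_bound=0.0):
    """One plain-SGD step with a capped Polyak step size (no momentum).

    Hypothetical helper for illustration only; not an Optax function.
    """
    value, grad = jax.value_and_grad(loss_fn)(params)
    grad_sq_norm = jnp.sum(grad ** 2)
    # Classical Polyak term: (f(x) - lower_bound) / ||grad||^2, clipped at zero.
    polyak = jnp.maximum(value - lower_bound, 0.0) / grad_sq_norm
    # Take no step when the gradient vanishes (avoids a 0/0 result).
    polyak = jnp.where(grad_sq_norm > 0, polyak, 0.0)
    # The user-specified learning rate only acts as an upper cap.
    step_size = jnp.minimum(learning_rate, polyak)
    return params - step_size * grad, step_size


f = lambda x: jnp.sum(x ** 2)  # simple quadratic with optimal value 0
new_params, step_size = capped_polyak_step(f, jnp.array([1.0, 2.0, 3.0]))
print(step_size, new_params)
```

In this sketch a large learning_rate rarely binds because the Polyak term is usually smaller. That is the intuition behind choosing learning_rate rather large.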
Examples
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     value, grad = jax.value_and_grad(f)(params)
...     params, opt_state = solver.update(grad, opt_state, params, value=value)
...     print('Objective function: ', f(params))
Objective function:  3.5
Objective function:  0.0
Objective function:  0.0
Objective function:  0.0
Objective function:  0.0
References
Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023
Added in version 0.2.3.