optax.contrib.momo


optax.contrib.momo(learning_rate: base.ScalarOrSchedule = 1.0, beta: jax.typing.ArrayLike = 0.9, lower_bound: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, adapt_lower_bound: bool = False) → base.GradientTransformationExtraArgs

Adaptive Learning Rates for SGD with momentum.

MoMo typically needs less tuning of learning_rate because it exploits a known lower bound of the loss (or its optimal value). For most tasks, zero is a lower bound and an accurate estimate of the final loss.

MoMo performs SGD with momentum with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.
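To make the min(learning_rate, <adaptive term>) rule concrete, here is a simplified sketch of a single Polyak-type step in plain Python. It is not the optax implementation: it assumes no momentum averaging and no weight decay, and the helper name polyak_capped_step is hypothetical.

```python
def polyak_capped_step(params, grads, loss, learning_rate=1.0, lower_bound=0.0):
    """One SGD step with a Polyak-type step size capped by learning_rate.

    Simplified illustration only: no momentum averaging, no weight decay.
    """
    grad_sq_norm = sum(g * g for g in grads)
    if grad_sq_norm == 0.0:
        # Zero gradient: nothing to update.
        return list(params), 0.0
    # Adaptive term: gap to the lower bound divided by the squared gradient norm.
    adaptive = max(loss - lower_bound, 0.0) / grad_sq_norm
    # The effective step size never exceeds the user-specified learning rate.
    step = min(learning_rate, adaptive)
    new_params = [p - step * g for p, g in zip(params, grads)]
    return new_params, step


# f(x) = sum(x_i ** 2), so the gradient is 2 * x.
params = [1.0, 2.0, 3.0]
loss = sum(p * p for p in params)
grads = [2.0 * p for p in params]

new_params, step = polyak_capped_step(params, grads, loss)
new_loss = sum(p * p for p in new_params)
```

In this sketch the adaptive term (14 / 56 = 0.25) is smaller than the default learning_rate of 1.0, so it determines the step. This is why a large learning_rate acts mostly as a cap rather than as the step size itself.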

Note that one needs to pass the latest (batch) loss value to the update function using the keyword argument value.

Parameters:
  • learning_rate – User-specified learning rate. Recommended to be chosen rather large, by default 1.0.

  • beta – Momentum coefficient (for EMA).

  • lower_bound – Lower bound of the loss. Zero should be a good choice for many tasks.

  • weight_decay – Weight-decay parameter.

  • adapt_lower_bound – If no good guess for the lower bound is available, set this to true, in order to estimate the lower bound on the fly (see the paper for details).

Returns:

An optax.GradientTransformationExtraArgs object.

Examples

>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  # update() returns parameter updates; the latest loss is passed via `value`
...  updates, opt_state = solver.update(grad, opt_state, params, value=value)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: ', f(params))

References

Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023

Added in version 0.2.3.