optax.contrib.momo_adam

optax.contrib.momo_adam(learning_rate: base.ScalarOrSchedule = 0.01, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, lower_bound: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, adapt_lower_bound: bool = False) → base.GradientTransformationExtraArgs

Adaptive Learning Rates for Adam(W).

MoMo-Adam typically needs less tuning of learning_rate because it exploits a known lower bound on the loss (or its optimal value). For most tasks, zero is a valid lower bound and an accurate estimate of the final loss.

MoMo-Adam performs Adam(W) with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.
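To make the min(learning_rate, <adaptive term>) rule concrete, here is a minimal sketch of a capped Polyak step applied to plain gradient descent. It is a simplified illustration under stated assumptions, not MoMo-Adam's actual update, which additionally uses momentum averages and Adam's preconditioner. The helper name capped_polyak_step is hypothetical.

```python
def capped_polyak_step(loss, grad, learning_rate, lower_bound=0.0):
    """Return min(learning_rate, (loss - lower_bound) / ||grad||^2)."""
    grad_sq_norm = sum(g * g for g in grad)
    if grad_sq_norm == 0.0:
        return 0.0  # zero gradient: take no step
    polyak = max(loss - lower_bound, 0.0) / grad_sq_norm
    return min(learning_rate, polyak)


def f(x):
    # Simple quadratic whose minimum value (0) is a known lower bound.
    return sum(xi * xi for xi in x)


def grad_f(x):
    return [2.0 * xi for xi in x]


x = [1.0, 2.0, 3.0]
losses = [f(x)]
for _ in range(5):
    g = grad_f(x)
    # Deliberately large user learning rate; the Polyak term caps it.
    step = capped_polyak_step(f(x), g, learning_rate=1.0)
    x = [xi - step * gi for xi, gi in zip(x, g)]
    losses.append(f(x))
```

With learning_rate=1.0, plain gradient descent on this quadratic would flip the sign of x at every step and never converge. The Polyak term limits the step to 0.25, so the loss shrinks by a factor of four per iteration.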

Note that you must pass the latest (batch) loss value to the update function using the keyword argument value.

Parameters:
  • learning_rate – User-specified learning rate. Recommended to be chosen rather large, by default 0.01.

  • b1 – Exponential decay rate to track the first moment of past gradients.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – Small constant added to the denominator of the underlying Adam optimizer for numerical stability.

  • lower_bound – Lower bound of the loss. Zero should be a good choice for many tasks.

  • weight_decay – Weight-decay parameter. MoMo-Adam performs weight decay in a similar fashion to AdamW.

  • adapt_lower_bound – If no good guess for the lower bound is available, set this to True to estimate the lower bound on the fly (see the paper for details).

Returns:

A GradientTransformationExtraArgs object.

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> from optax import contrib
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo_adam()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  # update() returns the parameter updates, not the new parameters
...  updates, opt_state = solver.update(grad, opt_state, params, value=value)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: ', f(params))

References

Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023

Added in version 0.2.3.