optax.contrib.momo_adam
- optax.contrib.momo_adam(learning_rate: base.ScalarOrSchedule = 0.01, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, lower_bound: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, adapt_lower_bound: bool = False) → base.GradientTransformationExtraArgs
Adaptive Learning Rates for Adam(W).
MoMo-Adam typically needs less tuning of learning_rate, because it exploits the fact that a lower bound of the loss (or the optimal value) is known. For most tasks, zero is a lower bound (it bounds any non-negative loss) and often a close estimate of the final loss.
MoMo performs Adam(W) with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.
Note that the latest (batch) loss value must be passed to the update function using the keyword argument value.
- Parameters:
learning_rate – User-specified learning rate. Recommended to be chosen rather large (default: 0.01).
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – Small constant added to the denominator of the underlying Adam optimizer for numerical stability.
lower_bound – Lower bound of the loss. Zero should be a good choice for many tasks.
weight_decay – Weight-decay parameter. MoMo-Adam performs weight decay in a similar fashion to AdamW.
adapt_lower_bound – If no good guess for the lower bound is available, set this to True to estimate the lower bound on the fly (see the paper for details).
- Returns:
A GradientTransformationExtraArgs object (a GradientTransformation whose update accepts extra keyword arguments such as value).
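To make the Polyak-type step size concrete, the sketch below applies the classic truncated Polyak step, min(learning_rate, (loss - lower_bound)_+ / ||g||^2), to plain gradient descent. It is an illustrative simplification under stated assumptions: it omits the momentum averaging and Adam preconditioning that MoMo-Adam uses to build its adaptive term, and the helper polyak_step_size is hypothetical, not part of optax.

```python
import jax
import jax.numpy as jnp


def polyak_step_size(loss, grad, learning_rate, lower_bound=0.0, eps=1e-8):
    """Hypothetical helper: truncated Polyak step size for plain gradient descent.

    This is NOT optax's MoMo-Adam rule, which additionally averages the loss
    and gradients with momentum and preconditions with Adam's second moment.
    """
    # Adaptive term: (loss - lower_bound)_+ / ||grad||^2; eps avoids division by zero.
    adaptive = jnp.maximum(loss - lower_bound, 0.0) / (jnp.sum(grad ** 2) + eps)
    # The user-specified learning rate caps the step: min(learning_rate, adaptive).
    return jnp.minimum(learning_rate, adaptive)


def f(x):
    return jnp.sum(x ** 2)  # simple quadratic; its minimum value 0 is the lower bound


params = jnp.array([1.0, 2.0, 3.0])
history = []
for _ in range(3):
    value, grad = jax.value_and_grad(f)(params)
    step = polyak_step_size(value, grad, learning_rate=1.0)
    params = params - step * grad  # plain gradient step with the Polyak step size
    history.append(float(f(params)))

print(history)  # [3.5, 0.875, 0.21875]
```

With learning_rate=1.0 the cap never binds here, so every step is the pure Polyak step 0.25, which halves x and quarters f. A small learning_rate would override the adaptive term, which is why a rather large learning_rate is recommended.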
Examples
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo_adam()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   # MoMo-Adam needs the current loss, passed through `value`.
...   updates, opt_state = solver.update(grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
>>> print(bool(f(params) < 14.0))
True
References
Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023
Added in version 0.2.3.