optax.adan

optax.adan(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.98, b2: jax.typing.ArrayLike = 0.92, b3: jax.typing.ArrayLike = 0.99, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 1e-08, weight_decay: base.ScalarOrSchedule = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) → base.GradientTransformationExtraArgs

The ADAptive Nesterov momentum algorithm (Adan).

Adan first reformulates vanilla Nesterov acceleration into a new Nesterov momentum estimation (NME) method, which avoids the extra overhead of computing the gradient at the extrapolation point. Adan then uses NME to estimate the first- and second-order moments of the gradient, as adaptive gradient algorithms do, to accelerate convergence.
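A toy illustration of why past gradients can stand in for the extrapolation-point gradient (this is not the paper's derivation): when the objective is quadratic, its gradient is affine, so the gradient at an extrapolated point \(\theta_t + c(\theta_t - \theta_{t-1})\) equals \(g_t + c(g_t - g_{t-1})\), which requires no extra gradient evaluation. The matrix `A`, vector `b`, and coefficient `c` below are arbitrary values chosen for the sketch.

```python
import jax
import jax.numpy as jnp

# Quadratic objective f(x) = 0.5 * x^T A x + b^T x; with A symmetric, grad f = A x + b (affine).
A = jnp.array([[3.0, 0.5], [0.5, 1.0]])
b = jnp.array([1.0, -2.0])
f = lambda x: 0.5 * x @ A @ x + b @ x
grad_f = jax.grad(f)

theta_prev = jnp.array([1.0, 1.0])  # theta_{t-1}
theta = jnp.array([0.5, 0.8])       # theta_t
c = 0.92                            # extrapolation coefficient (illustrative value)

# Vanilla Nesterov style: one extra gradient call at the extrapolated point.
g_extrap = grad_f(theta + c * (theta - theta_prev))

# NME style: combine gradients that were already computed at theta_t and theta_{t-1}.
g_t, g_prev = grad_f(theta), grad_f(theta_prev)
g_nme = g_t + c * (g_t - g_prev)

print(jnp.allclose(g_extrap, g_nme))  # the two agree for a quadratic objective
```

For non-quadratic objectives the two generally differ. The gradient-plus-difference term \(g_t + c(g_t - g_{t-1})\) has the same shape as the term that is squared inside the \(n_t\) update below.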

The algorithm is as follows. First, we define the following parameters:

  • \(\eta > 0\): the step size.

  • \(\beta_1 \in [0, 1]\): the decay rate for the exponentially weighted average of gradients.

  • \(\beta_2 \in [0, 1]\): the decay rate for the exponentially weighted average of differences of gradients.

  • \(\beta_3 \in [0, 1]\): the decay rate for the exponentially weighted average of the squared term.

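  • \(\varepsilon > 0\): a small constant added to the denominator for numerical stability.

  • \(\bar{\varepsilon} \ge 0\): a small constant added inside the square root for numerical stability.

  • \(\lambda \ge 0\): the weight decay coefficient.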

Second, we define the following variables:

  • \(\theta_t\): the parameters.

  • \(g_t\): the incoming stochastic gradient.

  • \(m_t\): the exponentially weighted average of gradients.

  • \(v_t\): the exponentially weighted average of differences of gradients.

  • \(n_t\): the exponentially weighted average of the squared term.

  • \(u_t\): the outgoing update vector.

  • \(S_t\): the saved state of the optimizer.

Third, we initialize these variables as follows:

  • \(m_0 = g_0\)

  • \(v_0 = 0\)

  • \(v_1 = g_1 - g_0\)

  • \(n_0 = g_0^2\)

Finally, on each iteration, we update the variables as follows:

\[\begin{align*} m_t &\gets \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t &\gets \beta_2 v_{t-1} + (1 - \beta_2) (g_t - g_{t-1}) \\ n_t &\gets \beta_3 n_{t-1} + (1 - \beta_3) (g_t + \beta_2 (g_t - g_{t-1}))^2 \\ \eta_t &\gets \eta / ({\sqrt{n_t + \bar{\varepsilon}} + \varepsilon}) \\ \theta_{t+1} &\gets (\theta_t - \eta_t \circ (m_t + \beta_2 v_t)) / (1 + \lambda \eta) \\ u_t &\gets \theta_{t+1} - \theta_t \\ S_t &\leftarrow (m_t, v_t, n_t). \end{align*}\]
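A minimal, un-jitted JAX sketch of one iteration for a single parameter array. It assumes each \(\beta\) is a decay rate that weights the previous value (so the defaults b1=0.98, b2=0.92, b3=0.99 keep most of the history), follows the initialization listed above, and applies weight decay by dividing by \(1 + \lambda\eta\). The names `AdanState` and `adan_update` are hypothetical. This is not optax's internal implementation, and its printed values are not guaranteed to match the Examples output below beyond the first step.

```python
from typing import NamedTuple, Optional

import jax
import jax.numpy as jnp


class AdanState(NamedTuple):
    m: jnp.ndarray       # EWMA of gradients
    v: jnp.ndarray       # EWMA of gradient differences
    n: jnp.ndarray       # EWMA of the squared term
    g_prev: jnp.ndarray  # previous gradient, needed for g_t - g_{t-1}
    t: int               # iteration counter


def adan_update(theta, g, state: Optional[AdanState], eta=0.1, b1=0.98,
                b2=0.92, b3=0.99, eps=1e-8, eps_root=1e-8, wd=0.0):
    """One Adan iteration (sketch); returns (new parameters, new state)."""
    if state is None:
        # t = 0: m_0 = g_0, v_0 = 0, n_0 = g_0^2.
        t, m, v, n = 0, g, jnp.zeros_like(g), g ** 2
    else:
        t = state.t + 1
        diff = g - state.g_prev                          # g_t - g_{t-1}
        m = b1 * state.m + (1 - b1) * g
        # v_1 = g_1 - g_0 is set directly; later steps use the EWMA recurrence.
        v = diff if t == 1 else b2 * state.v + (1 - b2) * diff
        n = b3 * state.n + (1 - b3) * (g + b2 * diff) ** 2
    eta_t = eta / (jnp.sqrt(n + eps_root) + eps)         # per-coordinate step size
    theta_next = (theta - eta_t * (m + b2 * v)) / (1 + wd * eta)
    return theta_next, AdanState(m, v, n, g, t)


# Usage on the same quadratic as the Examples section.
f = lambda x: x @ x
theta = jnp.array([1.0, 2.0, 3.0])
state = None
for _ in range(5):
    theta, state = adan_update(theta, jax.grad(f)(theta), state)
    print(f"{f(theta):.2E}")
```

Returning the new parameters, rather than an update vector, keeps the sketch close to the equations; the outgoing update \(u_t\) is simply `theta_next - theta`.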
Parameters:
  • learning_rate – A global scaling factor (the step size \(\eta\)), either fixed or a schedule that evolves over iterations.

  • b1 – Decay rate for the EWMA of gradients.

  • b2 – Decay rate for the EWMA of differences of gradients.

  • b3 – Decay rate for the EWMA of the algorithm’s squared term.

  • eps – Term added to the denominator to improve numerical stability.

  • eps_root – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • weight_decay – Strength of the weight decay regularization.

  • mask – A tree with the same structure as (or a prefix of) the params pytree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees to which weight decay should be applied, and False for those to skip.

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> f = lambda x: x @ x  # simple quadratic function
>>> solver = optax.adan(learning_rate=1e-1)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.28E+01
Objective function: 1.17E+01
Objective function: 1.07E+01
Objective function: 9.68E+00
Objective function: 8.76E+00
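The first printed value can be checked by hand. Assume the first iteration uses the initial values listed above (\(m_0 = g_0\), \(v_0 = 0\), \(n_0 = g_0^2\)), that \(\lambda = 0\) (the default weight_decay), and that \(\varepsilon\) and \(\bar{\varepsilon}\) are negligible. The \(v_0\) term then vanishes and each coordinate moves by \(\eta = 0.1\) against the sign of its gradient:

```latex
\begin{align*}
g_0 &= \nabla f(\theta_0) = 2\theta_0 = (2, 4, 6) \\
\eta_0 &= \eta / (\sqrt{g_0^2 + \bar{\varepsilon}} + \varepsilon) \approx (0.1/2,\ 0.1/4,\ 0.1/6) \\
\theta_1 &\approx \theta_0 - \eta_0 \circ m_0 = (1, 2, 3) - (0.1, 0.1, 0.1) = (0.9, 1.9, 2.9) \\
f(\theta_1) &= 0.81 + 3.61 + 8.41 = 12.83
\end{align*}
```

This rounds to the first printed value, 1.28E+01. The check says nothing about later iterations, where the \(\beta\) parameters come into play.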

References

Xie et al., Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models, 2022.