optax.nadam


optax.nadam(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: Any | None = None, *, nesterov: bool = True) → base.GradientTransformationExtraArgs

The NAdam optimizer.

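Nadam is a variant of optax.adam() that uses Nesterov momentum. Let \(g_t\) denote the incoming gradient at step \(t\), \(\alpha_t\) the learning rate, and \(\beta_1, \beta_2, \varepsilon, \bar{\varepsilon}\) the arguments b1, b2, eps and eps_root respectively; \(m_t\) and \(v_t\) are the first and second moment estimates kept in the state \(S_t\). The update rule of this solver is as follows: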

\[\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
\hat{m}_t &\leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)} \\
\hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right) \\
S_t &\leftarrow (m_t, v_t).
\end{align*}\]
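
The update rule above can be traced by hand with a minimal scalar sketch in plain Python. This is not optax's implementation: the helper name nadam_step is hypothetical, the step counter t is assumed to start at 1, and the state is assumed to be a bare (m, v) pair initialized to zero.

```python
import math


def nadam_step(g, m, v, t, lr=0.003, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """One scalar step of the update rule above (hypothetical helper; t starts at 1)."""
    m = b1 * m + (1 - b1) * g      # first-moment estimate m_t
    v = b2 * v + (1 - b2) * g * g  # second-moment estimate v_t
    # Bias-corrected first moment with the Nesterov look-ahead term.
    m_hat = b1 * m / (1 - b1 ** (t + 1)) + (1 - b1) * g / (1 - b1 ** t)
    # Bias-corrected second moment.
    v_hat = v / (1 - b2 ** t)
    # eps_root sits inside the square root, eps outside it.
    u = -lr * m_hat / (math.sqrt(v_hat + eps_root) + eps)
    return u, m, v


# Minimize f(x) = x^2 starting from x = 2.0; the gradient is 2 * x.
x, m, v = 2.0, 0.0, 0.0
for t in range(1, 6):
    u, m, v = nadam_step(2.0 * x, m, v, t)
    x += u
print(x)
```

On the first step, with zero-initialized moments and negligible eps, the ratio \(\hat{m}_1 / \sqrt{\hat{v}_1}\) equals \(\beta_1 (1-\beta_1) / (1-\beta_1^2) + 1 \approx 1.47\) in magnitude with the default b1, so the first update has size of about 1.47 times the learning rate regardless of the gradient's scale.
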
Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 – Exponential decay rate to track the first moment of past gradients.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – A small constant added to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant added to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling. This is needed, for example, when computing (meta-)gradients through Adam.

  • mu_dtype – Optional dtype to be used for the first-order accumulator; if None, the dtype is inferred from params and updates.

  • nesterov – Whether to use Nesterov momentum. Keyword-only; defaults to True for this optimizer.
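
The difference between eps and eps_root can be illustrated with a small plain-Python sketch of the denominator \(\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon\) and its derivative with respect to \(\hat{v}_t\). The function names are hypothetical and the derivative is written out by hand rather than obtained from JAX; the point is only that eps keeps the forward value nonzero, while eps_root is what keeps the derivative finite when \(\hat{v}_t = 0\).

```python
import math


def denominator(v_hat, eps=1e-8, eps_root=0.0):
    # sqrt(v_hat + eps_root) + eps, as in the update rule.
    return math.sqrt(v_hat + eps_root) + eps


def denominator_grad(v_hat, eps_root=0.0):
    # d/dv_hat [sqrt(v_hat + eps_root) + eps] = 1 / (2 * sqrt(v_hat + eps_root)).
    return 1.0 / (2.0 * math.sqrt(v_hat + eps_root))


# eps alone keeps the forward value nonzero at v_hat = 0 ...
print(denominator(0.0))

# ... but the derivative of the square root is unbounded there.
try:
    denominator_grad(0.0)
except ZeroDivisionError:
    print("derivative undefined at v_hat = 0 when eps_root = 0")

# A positive eps_root bounds the derivative.
print(denominator_grad(0.0, eps_root=1e-8))
```

This is why a nonzero eps_root matters mainly when differentiating through the optimizer update itself, as in the meta-gradient case mentioned above.
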

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.nadam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01

References

Dozat, Incorporating Nesterov Momentum into Adam, 2016

Added in version 0.1.9.