optax.nadamw
optax.nadamw(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: Any | None = None, weight_decay: base.ScalarOrSchedule = 0.0001, mask: Any | Callable[[base.Params], Any] | None = None, *, nesterov: bool = True) → base.GradientTransformationExtraArgs
NAdamW optimizer, implemented as part of the AdamW optimizer.
NAdamW is a variant of optax.adamw() with Nesterov’s momentum. Compared to AdamW, this optimizer replaces the assignment
\[\hat{m}_t \leftarrow m_t / {(1-\beta_1^t)}\]
with
\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)},\]
where \(m_t\) is the first-moment accumulator and \(g_t\) is the current gradient.
Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler; see optax.scale_by_learning_rate().
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – A small constant applied to the denominator outside the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root – A small constant applied to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling. This is needed, for instance, when computing (meta-)gradients through Adam.
mu_dtype – Optional dtype to use for the first-order accumulator. If None, the dtype is inferred from params and updates.
weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied by the learning rate. This is consistent with other frameworks such as PyTorch, but differs from (Loshchilov et al., 2019), where the weight decay is multiplied only by the “schedule multiplier” and not by the base learning rate.
mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a PyTree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.
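The weight_decay and mask entries can be illustrated in plain Python. This is a minimal sketch, not optax's implementation. Flat dicts stand in for parameter PyTrees, and the helper name `apply_decoupled_decay` is hypothetical. Following the weight_decay entry, it assumes the decay term is added to the Adam direction before the learning rate is applied.

```python
from typing import Callable, Dict, Union

Tree = Dict[str, float]  # hypothetical flat stand-in for a params PyTree


def apply_decoupled_decay(
    direction: Tree,
    params: Tree,
    lr: float,
    weight_decay: float,
    mask: Union[Dict[str, bool], Callable[[Tree], Dict[str, bool]]],
) -> Tree:
    """Hypothetical helper: scale an Adam-style direction plus masked decay by lr."""
    # mask may be a tree of booleans or a callable that builds one from params.
    mask_tree = mask(params) if callable(mask) else mask
    updates = {}
    for name, d in direction.items():
        # Only leaves whose mask is True receive the decay term.
        decay = weight_decay * params[name] if mask_tree[name] else 0.0
        # Decay is added before scaling, so it is multiplied by lr as well.
        updates[name] = -lr * (d + decay)
    return updates


params = {"w": 2.0, "b": 1.0}
direction = {"w": 0.5, "b": 0.5}  # stand-in for the Adam direction
# Skip weight decay on the bias, a common use of mask.
updates = apply_decoupled_decay(
    direction, params, lr=0.1, weight_decay=0.01,
    mask=lambda p: {k: k != "b" for k in p},
)
```

Passing the boolean dict `{"w": True, "b": False}` directly instead of the callable selects the same leaves. Both forms are accepted here because the mask entry allows either.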
Returns:
The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.nadamw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
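To see where the Nesterov correction enters, the following scalar sketch performs NAdamW steps by hand. It is an illustrative re-derivation in plain Python, not optax's code. The function `nadamw_step` is hypothetical, and `t` is assumed to be the 1-based step count. The direction \(\hat{m}_t / (\sqrt{\hat{v}_t + \text{eps\_root}} + \text{eps})\) and the learning-rate-scaled weight decay follow the parameter descriptions. Setting `nesterov=False` recovers the plain AdamW bias correction.

```python
import math


def nadamw_step(param, grad, m, v, t, lr=0.003, b1=0.9, b2=0.999,
                eps=1e-8, eps_root=0.0, weight_decay=1e-4, nesterov=True):
    """Hypothetical scalar NAdamW step; t is the 1-based step count."""
    # Exponential moving averages of the gradient and the squared gradient.
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad ** 2
    if nesterov:
        # Nesterov form: momentum term uses the t+1 bias correction,
        # current-gradient term uses the t bias correction.
        m_hat = b1 * m / (1 - b1 ** (t + 1)) + (1 - b1) * grad / (1 - b1 ** t)
    else:
        # Plain AdamW bias correction.
        m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    # eps_root sits inside the square root, eps outside it.
    direction = m_hat / (math.sqrt(v_hat + eps_root) + eps)
    # Decoupled weight decay, scaled by the learning rate like the direction.
    new_param = param - lr * (direction + weight_decay * param)
    return new_param, m, v


# Minimize f(x) = x**2 (gradient 2x) for a few steps from x = 3.0.
x, m, v = 3.0, 0.0, 0.0
for t in range(1, 6):
    x, m, v = nadamw_step(x, 2 * x, m, v, t)
```

Consider the first step from zero moments with weight_decay set to 0. The plain bias correction gives a direction of magnitude about 1. The Nesterov form gives about \(1 + \beta_1/(1+\beta_1) \approx 1.47\) for \(\beta_1 = 0.9\).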
References
Loshchilov et al., Decoupled Weight Decay Regularization, 2019
Dozat, Incorporating Nesterov Momentum into Adam, 2016
See also
Added in version 0.1.9.