optax.adam
- optax.adam(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: Any | None = None, *, nesterov: bool = False) -> base.GradientTransformationExtraArgs
The Adam optimizer.
Adam is an SGD variant with gradient scaling adaptation. The scaling used for each parameter is computed from estimates of first and second-order moments of the gradients (using suitable exponential moving averages).
Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.
The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,
\[\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
\hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
\hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\
S_t &\leftarrow (m_t, v_t).
\end{align*}\]
With the keyword argument nesterov=True, the optimizer uses Nesterov momentum, replacing the above \(\hat{m}_t\) with
\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}. \]
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.
mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
nesterov – Whether to use Nesterov momentum. The solver with nesterov=True is equivalent to the optax.nadam() optimizer, and described in [Dozat 2016].
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
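The update rule and the roles of b1, b2, eps, eps_root and nesterov described above can be sketched by hand in plain JAX. This is only an illustrative sketch of the formulas on this page, not the optax implementation: the helper names adam_init and adam_step are hypothetical, the state carries an explicit Python step counter, and mu_dtype is not modelled.

```python
import jax
import jax.numpy as jnp

tmap = jax.tree_util.tree_map


def adam_init(params):
    """Hypothetical helper: initial state (m_0, v_0, t) = (0, 0, 0)."""
    zeros = tmap(jnp.zeros_like, params)
    return zeros, zeros, 0


def adam_step(grads, state, lr, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0,
              nesterov=False):
    """Hypothetical helper: one Adam step following the formulas above."""
    m, v, t = state
    t = t + 1  # the formulas are stated for t > 0
    # Exponential moving averages of the gradient and the squared gradient.
    m = tmap(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)
    v = tmap(lambda v_, g: b2 * v_ + (1 - b2) * g ** 2, v, grads)
    # Bias-corrected first moment, optionally with the Nesterov variant.
    if nesterov:
        m_hat = tmap(
            lambda m_, g: b1 * m_ / (1 - b1 ** (t + 1))
            + (1 - b1) * g / (1 - b1 ** t),
            m, grads)
    else:
        m_hat = tmap(lambda m_: m_ / (1 - b1 ** t), m)
    # Bias-corrected second moment.
    v_hat = tmap(lambda v_: v_ / (1 - b2 ** t), v)
    # eps_root sits inside the square root, eps outside it.
    updates = tmap(
        lambda mh, vh: -lr * mh / (jnp.sqrt(vh + eps_root) + eps),
        m_hat, v_hat)
    return updates, (m, v, t)


def f(x):
    return jnp.sum(x ** 2)  # simple quadratic function


params = jnp.array([1., 2., 3.])
state = adam_init(params)
for _ in range(5):
    grads = jax.grad(f)(params)
    updates, state = adam_step(grads, state, lr=0.003)
    params = params + updates
    print('Objective function: {:.2E}'.format(float(f(params))))
```

On the first step \(\hat{m}_1 = g_1\) and \(\hat{v}_1 = g_1^2\), so each coordinate moves by roughly the learning rate regardless of the gradient's magnitude. This is why the objective decreases slowly from 14.0 with learning_rate=0.003.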
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References
Kingma et al, Adam: A Method for Stochastic Optimization, 2014
Dozat, Incorporating Nesterov Momentum into Adam, 2016
Warning
The PyTorch and optax implementations follow Algorithm 1 of [Kingma et al. 2014]. Note that TensorFlow instead used the formulation just before Section 2.1 of the paper. See deepmind/optax#571 for more detail.
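As a rough sketch of how the two formulations in the warning above differ, take eps_root = 0 and a constant learning rate \(\alpha\), and write \(\hat{\varepsilon}\) for the epsilon of the alternative formulation (a symbol borrowed from the paper, not an argument of this function). As we read the paper, the updates are
\[\begin{align*}
\text{Algorithm 1:}\quad u_t &= -\alpha \cdot \hat{m}_t / \left(\sqrt{\hat{v}_t} + \varepsilon\right) \\
\text{Alternative:}\quad u_t &= -\alpha \cdot \frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t} \cdot m_t / \left(\sqrt{v_t} + \hat{\varepsilon}\right) \\
&= -\alpha \cdot \hat{m}_t / \left(\sqrt{\hat{v}_t} + \hat{\varepsilon} / \sqrt{1-\beta_2^t}\right).
\end{align*}\]
Under these assumptions the two coincide when both epsilons are zero. Otherwise the alternative behaves like Algorithm 1 with a step-dependent epsilon, \(\hat{\varepsilon} / \sqrt{1-\beta_2^t}\), which is largest at early steps.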
See also