optax.adamax

optax.adamax(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08) → base.GradientTransformationExtraArgs

A variant of the Adam optimizer that uses the infinity norm.

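AdaMax is a variant of the optax.adam() optimizer. Adam scales each update by an estimate based on the \(L^2\) norm of past gradients. Generalizing this to an \(L^p\) norm and taking the limit as \(p \rightarrow \infty\) replaces that estimate with an exponentially decayed running maximum of gradient magnitudes, which gives a simple and stable update rule.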
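The limit can be checked numerically. In the plain-Python sketch below (the helper names are illustrative, not optax functions), the second-moment estimate is generalized to a \(p\)-th power, so after \(t\) steps it equals \((1-\beta_2^p)^{1/p} \left(\sum_{i=1}^{t} (\beta_2^{t-i} |g_i|)^p\right)^{1/p}\). As \(p\) grows this approaches \(\max_i \beta_2^{t-i} |g_i|\), which is what the recursion \(v_t = \max(|g_t|, \beta_2 v_{t-1})\) computes (with \(\varepsilon\) omitted here for clarity).

```python
def weighted_lp_estimate(grads, beta2, p):
    """p-norm generalization of Adam's second-moment estimate over g_1..g_t."""
    t = len(grads)
    # Each past gradient is down-weighted by beta2 once per step since it was seen.
    total = sum((beta2 ** (t - i) * abs(g)) ** p for i, g in enumerate(grads, start=1))
    # The normalizing factor (1 - beta2**p)**(1/p) tends to 1 as p grows.
    return (1.0 - beta2 ** p) ** (1.0 / p) * total ** (1.0 / p)


def running_max_estimate(grads, beta2):
    """The p -> infinity limit: v_i = max(|g_i|, beta2 * v_{i-1}), v_0 = 0 (eps omitted)."""
    v = 0.0
    for g in grads:
        v = max(abs(g), beta2 * v)
    return v


grads = [0.5, 4.0, -1.0, 0.2, 0.3]
beta2 = 0.9  # smaller than the default 0.999 so the decay is visible over 5 steps
for p in (2, 8, 32, 128):
    print(p, weighted_lp_estimate(grads, beta2, p))
print("limit", running_max_estimate(grads, beta2))
```

Here the limit is set by the second gradient, \(4.0\), decayed three times by \(\beta_2\).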

Let \(\alpha_t\) denote the learning rate, and let \(\beta_1\), \(\beta_2\) and \(\varepsilon\) denote the arguments b1, b2 and eps, respectively. The learning rate is indexed by \(t\) because it may be given by a schedule function rather than a constant.

The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), holding the initial estimates of the first moment and of the exponentially weighted infinity norm of the gradients. In practice these values are stored as pytrees of zeros with the same shape as the model updates. At step \(t\), the update function takes the incoming gradients \(g_t\) and the optimizer state \(S_{t-1}\) and computes the updates \(u_t\) and the new state \(S_t\). Thus, for \(t > 0\), we have

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \max(\left| g_t \right| + \varepsilon, \beta_2 \cdot v_{t-1}) \\ \hat{m}_t &\leftarrow m_t / (1-\beta_1^t) \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / v_t \\ S_t &\leftarrow (m_t, v_t). \end{align*}\]
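The update equations above translate almost line for line into code. The sketch below is a plain-Python rendering on lists of floats, written for illustration only: it is not optax's implementation, which operates on pytrees of JAX arrays, and the names adamax_init and adamax_update are hypothetical. The learning rate may be a constant or a callable schedule; passing the schedule the number of completed steps (so the first update uses step 0) is an assumption of this sketch.

```python
from typing import Callable, List, Tuple, Union

# A learning rate is either a constant or a function of the step count (a schedule).
LearningRate = Union[float, Callable[[int], float]]
State = Tuple[List[float], List[float], int]  # (m, v, number of completed steps)


def adamax_init(params: List[float]) -> State:
    """S_0 = (m_0, v_0) = (0, 0), stored with the same shape as params."""
    return [0.0] * len(params), [0.0] * len(params), 0


def adamax_update(grads: List[float], state: State, learning_rate: LearningRate,
                  b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8
                  ) -> Tuple[List[float], State]:
    """One AdaMax step: returns the updates u_t and the new state S_t."""
    m_prev, v_prev, steps_done = state
    t = steps_done + 1
    # Assumption of this sketch: a schedule is queried with the completed step count.
    alpha = learning_rate(steps_done) if callable(learning_rate) else learning_rate
    m = [b1 * mp + (1 - b1) * g for mp, g in zip(m_prev, grads)]
    v = [max(abs(g) + eps, b2 * vp) for vp, g in zip(v_prev, grads)]
    m_hat = [mi / (1 - b1 ** t) for mi in m]  # bias correction for the first moment only
    updates = [-alpha * mh / vi for mh, vi in zip(m_hat, v)]
    return updates, (m, v, t)


# Usage: minimize f(x) = sum(x**2), mirroring the optax example below.
def f(x: List[float]) -> float:
    return sum(xi ** 2 for xi in x)


params = [1.0, 2.0, 3.0]
state = adamax_init(params)
history = []
for _ in range(5):
    grads = [2.0 * xi for xi in params]  # gradient of sum(x**2)
    updates, state = adamax_update(grads, state, learning_rate=0.003)
    params = [p + u for p, u in zip(params, updates)]
    history.append('{:.2E}'.format(f(params)))
print(history)
```

Formatted the same way, the recorded objective values should match those printed by the optax example below, since the arithmetic is the same up to floating-point precision. To mimic a schedule, pass a callable instead of a constant, for example learning_rate=lambda step: 0.003 * 0.99 ** step (an arbitrary decay chosen for illustration).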

Parameters:
  • learning_rate – A global scaling factor, either fixed or varying over iterations according to a schedule; see optax.scale_by_learning_rate().

  • b1 – Exponential decay rate used to track the first moment of past gradients.

  • b2 – Exponential decay rate used to track the maximum of past gradients.

  • eps – A small constant added to \(|g_t|\) in the denominator \(v_t\) to avoid dividing by zero when rescaling.
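The roles of b2 and eps can be seen by iterating only the \(v_t\) recursion from the update rule. The plain-Python sketch below (the helper name infinity_moment_trace is hypothetical, not an optax function) feeds one large gradient followed by small ones: with b2 = 0.9 the tracked maximum shrinks by a factor of b2 per step. With all-zero gradients, \(v_t\) never falls below eps, so the division in \(u_t\) stays well defined.

```python
def infinity_moment_trace(grads, b2=0.999, eps=1e-8):
    """Return v_1..v_T from v_t = max(|g_t| + eps, b2 * v_{t-1}), starting at v_0 = 0."""
    v, trace = 0.0, []
    for g in grads:
        v = max(abs(g) + eps, b2 * v)
        trace.append(v)
    return trace


# One large gradient followed by small ones: v forgets the spike geometrically at rate b2.
spike = [10.0] + [0.1] * 5
print(infinity_moment_trace(spike, b2=0.9))    # decays quickly
print(infinity_moment_trace(spike, b2=0.999))  # default b2: the spike is remembered much longer

# Zero gradients: eps keeps the denominator strictly positive.
print(infinity_moment_trace([0.0, 0.0, 0.0]))
```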

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamax(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

References

Kingma and Ba, 2014: https://arxiv.org/abs/1412.6980