optax.adamax
- optax.adamax(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08) → base.GradientTransformationExtraArgs
A variant of the Adam optimizer that uses the infinity norm.
AdaMax is a variant of the optax.adam() optimizer. By generalizing Adam’s \(L^2\) norm to an \(L^p\) norm and taking the limit as \(p \rightarrow \infty\), we obtain a simple and stable update rule.
Let \(\alpha_t\) represent the learning rate and \(\beta_1\), \(\beta_2\), \(\varepsilon\) represent the arguments b1, b2 and eps respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.
The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_{t-1}\) and computes updates \(u_t\) and new state \(S_t\). Thus, for \(t > 0\), we have,
\[\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
v_t &\leftarrow \max(\left| g_t \right| + \varepsilon, \beta_2 \cdot v_{t-1}) \\
\hat{m}_t &\leftarrow m_t / (1-\beta_1^t) \\
u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / v_t \\
S_t &\leftarrow (m_t, v_t).
\end{align*}\]
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the maximum of past gradients.
eps – A small constant applied to the denominator to avoid dividing by zero when rescaling.
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
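The update rule given in the description can be sketched directly in JAX for a single array. This is a minimal illustration under stated assumptions, not Optax's actual implementation: the helper name adamax_step and its signature are hypothetical, the 1-based step counter t is passed in explicitly instead of being stored in the optimizer state, and the learning rate is a fixed 0.003 with the default b1, b2 and eps.

```python
import jax.numpy as jnp


def adamax_step(grad, m, v, t, lr=0.003, b1=0.9, b2=0.999, eps=1e-8):
    """Hypothetical helper: one AdaMax step for a single array (t starts at 1)."""
    m = b1 * m + (1 - b1) * grad                  # first-moment estimate m_t
    v = jnp.maximum(jnp.abs(grad) + eps, b2 * v)  # infinity-norm estimate v_t
    m_hat = m / (1 - b1 ** t)                     # bias-correct the first moment only
    update = -lr * m_hat / v                      # u_t
    return update, m, v


params = jnp.array([1.0, 2.0, 3.0])
m = jnp.zeros_like(params)  # m_0 = 0
v = jnp.zeros_like(params)  # v_0 = 0

for t in range(1, 6):
    grad = 2.0 * params  # analytic gradient of f(x) = sum(x ** 2)
    update, m, v = adamax_step(grad, m, v, t)
    params = params + update
```

On the first step the bias-corrected moment equals the gradient and v is its absolute value plus eps, so each coordinate moves by roughly the learning rate in the direction opposite to the gradient's sign. No bias correction is applied to v, matching the equations in the description.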
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamax(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References
Kingma et al, 2014: https://arxiv.org/abs/1412.6980
See also