Optimizers#

adabelief(learning_rate[, b1, b2, eps, eps_root])

The AdaBelief optimizer.

adadelta([learning_rate, rho, eps, ...])

The Adadelta optimizer.

adafactor([learning_rate, ...])

The Adafactor optimizer.

adagrad(learning_rate[, ...])

The Adagrad optimizer.

adam(learning_rate[, b1, b2, eps, eps_root, ...])

The Adam optimizer.

adamw(learning_rate[, b1, b2, eps, ...])

Adam with weight decay regularization.

adamax(learning_rate[, b1, b2, eps])

A variant of the Adam optimizer that uses the infinity norm.

adamaxw(learning_rate[, b1, b2, eps, ...])

Adamax with weight decay regularization.

amsgrad(learning_rate[, b1, b2, eps, ...])

The AMSGrad optimiser.

fromage(learning_rate[, min_norm])

The Frobenius matched gradient descent (Fromage) optimizer.

lamb(learning_rate[, b1, b2, eps, eps_root, ...])

The LAMB optimizer.

lars(learning_rate[, weight_decay, ...])

The LARS optimizer.

lion(learning_rate[, b1, b2, mu_dtype, ...])

The Lion optimizer.

nadam(learning_rate[, b1, b2, eps, ...])

The NAdam optimizer.

nadamw(learning_rate[, b1, b2, eps, ...])

NAdamW optimizer, implemented as part of the AdamW optimizer.

noisy_sgd(learning_rate[, eta, gamma, seed])

A variant of SGD with added noise.

novograd(learning_rate[, b1, b2, eps, ...])

NovoGrad optimizer.

optimistic_gradient_descent(learning_rate[, ...])

An Optimistic Gradient Descent optimizer.

polyak_sgd([max_learning_rate, scaling, ...])

SGD with Polyak step-size.

radam(learning_rate[, b1, b2, eps, ...])

The Rectified Adam optimizer.

rmsprop(learning_rate[, decay, eps, ...])

A flexible RMSProp optimizer.

sgd(learning_rate[, momentum, nesterov, ...])

A canonical Stochastic Gradient Descent optimizer.

sm3(learning_rate[, momentum])

The SM3 optimizer.

yogi(learning_rate[, b1, b2, eps])

The Yogi optimizer.

AdaBelief#

optax.adabelief(learning_rate, b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16)[source]#

The AdaBelief optimizer.

AdaBelief is an adaptive learning rate optimizer that focuses on fast convergence, generalization, and stability. It adapts the step size depending on its “belief” in the gradient direction — the optimizer adaptively scales the step size by the difference between the predicted and observed gradients. AdaBelief is a modified version of optax.adam() and contains the same number of parameters.

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.

The init function of this optimizer initializes an internal state \(S_0 := (m_0, s_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ s_t &\leftarrow \beta_2 \cdot s_{t-1} + (1-\beta_2) \cdot (g_t - m_t)^2 + \bar{\varepsilon} \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{s}_t &\leftarrow s_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left(\sqrt{\hat{s}_t} + \varepsilon \right) \\ S_t &\leftarrow (m_t, s_t). \end{align*}\]

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adabelief(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01

References

Zhuang et al, 2020: https://arxiv.org/abs/2010.07468

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

AdaDelta#

optax.adadelta(learning_rate=None, rho=0.9, eps=1e-06, weight_decay=0.0, weight_decay_mask=None)[source]#

The Adadelta optimizer.

Adadelta is a stochastic gradient descent method that adapts learning rates based on a moving window of gradient updates. Adadelta is a modification of Adagrad.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> f = lambda x: jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adadelta(learning_rate=10.)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.36E+01
Objective function: 1.32E+01
Objective function: 1.29E+01
Objective function: 1.25E+01
Objective function: 1.21E+01

References

Zeiler, 2012: https://arxiv.org/pdf/1212.5701.pdf

Parameters:
  • learning_rate (Optional[base.ScalarOrSchedule]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • rho (float) – A coefficient used for computing a running average of squared gradients.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • weight_decay (float) – Optional rate at which to decay weights.

  • weight_decay_mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

Return type:

optax.GradientTransformation

Returns:

The corresponding GradientTransformation.

AdaGrad#

optax.adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-07)[source]#

The Adagrad optimizer.

Adagrad is an algorithm for gradient based optimization that anneals the learning rate for each parameter during the course of training.

Warning

Adagrad’s main limitation is the monotonic accumulation of squared gradients in the denominator: since all terms are positive, the sum keeps growing during training, and the effective learning rate eventually becomes vanishingly small.
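
To illustrate (a standalone numerical sketch, ignoring eps and the learning rate): with a constant gradient of 2 and an initial accumulator of 0.1, the per-parameter scaling \(g / \sqrt{\sum g^2}\) shrinks at every step.

>>> g, accumulator = 2.0, 0.1
>>> for _ in range(3):
...   accumulator += g ** 2  # monotonically growing sum of squared gradients
...   print(round(g / accumulator ** 0.5, 3))
0.988
0.703
0.575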

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adagrad(learning_rate=1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 5.01E+00
Objective function: 2.40E+00
Objective function: 1.25E+00
Objective function: 6.86E-01
Objective function: 3.85E-01

References

Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • initial_accumulator_value (float) – Initial value for the accumulator.

  • eps (float) – A small constant applied to denominator inside of the square root (as in RMSProp) to avoid dividing by zero when rescaling.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

AdaFactor#

optax.adafactor(learning_rate=None, min_dim_size_to_factor=128, decay_rate=0.8, decay_offset=0, multiply_by_parameter_scale=True, clipping_threshold=1.0, momentum=None, dtype_momentum=<class 'jax.numpy.float32'>, weight_decay_rate=None, eps=1e-30, factored=True, weight_decay_mask=None)[source]#

The Adafactor optimizer.

Adafactor is an adaptive learning rate optimizer that focuses on fast training of large scale neural networks. It saves memory by using a factored estimate of the second order moments used to scale gradients.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adafactor(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.36E+01

References

Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235

Parameters:
  • learning_rate (Optional[base.ScalarOrSchedule]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate(). Note that the natural scale for Adafactor’s learning rate is markedly different from Adam’s; with attention-based models, one does not use the 1/sqrt(hidden) correction for this optimizer.

  • min_dim_size_to_factor (int) – Only factor the statistics if two array dimensions have at least this size.

  • decay_rate (float) – Controls second-moment exponential decay schedule.

  • decay_offset (int) – For fine-tuning, one may set this to the starting step number of the fine-tuning phase.

  • multiply_by_parameter_scale (bool) – If True, then scale learning_rate by parameter norm. If False, the provided learning_rate is the absolute step size.

  • clipping_threshold (Optional[float]) – Optional clipping threshold. Must be >= 1. If None, clipping is disabled.

  • momentum (Optional[float]) – Optional value between 0 and 1, enables momentum and uses extra memory if non-None! None by default.

  • dtype_momentum (Any) – Data type of momentum buffers.

  • weight_decay_rate (Optional[float]) – Optional rate at which to decay weights.

  • eps (float) – Regularization constant for root mean squared gradient.

  • factored (bool) – Whether to use factored second-moment estimates.

  • weight_decay_mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

Return type:

optax.GradientTransformation

Returns:

The corresponding GradientTransformation.

Adam#

optax.adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None, *, nesterov=False)[source]#

The Adam optimizer.

Adam is an SGD variant with gradient scaling adaptation. The scaling used for each parameter is computed from estimates of first and second-order moments of the gradients (using suitable exponential moving averages).

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.

The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}\]

With the keyword argument nesterov=True, the optimizer uses Nesterov momentum, replacing the above \(\hat{m}_t\) with

\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}. \]
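
As noted in the parameters below, adam with nesterov=True is equivalent to optax.nadam(); a minimal sketch checking this on a single update (assuming matching hyperparameters) is:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> params = jnp.array([1., 2., 3.])
>>> grads = jax.grad(lambda x: jnp.sum(x ** 2))(params)
>>> adam_nesterov = optax.adam(learning_rate=0.003, nesterov=True)
>>> nadam = optax.nadam(learning_rate=0.003)
>>> updates_a, _ = adam_nesterov.update(grads, adam_nesterov.init(params), params)
>>> updates_b, _ = nadam.update(grads, nadam.init(params), params)
>>> bool(jnp.allclose(updates_a, updates_b))
True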

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

References

Kingma et al, Adam: A Method for Stochastic Optimization, 2014

Dozat, Incorporating Nesterov Momentum into Adam, 2016

Warning

PyTorch and optax’s implementations follow Algorithm 1 of [Kingma et al. 2014]. Note that TensorFlow instead uses the formulation given just before Section 2.1 of the paper. See deepmind/optax#571 for more detail.

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

  • mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • nesterov (bool) – Whether to use Nesterov momentum. The solver with nesterov=True is equivalent to the optax.nadam() optimizer, and described in [Dozat 2016].

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

Adamax#

optax.adamax(learning_rate, b1=0.9, b2=0.999, eps=1e-08)[source]#

A variant of the Adam optimizer that uses the infinity norm.

AdaMax is a variant of the optax.adam() optimizer. By generalizing Adam’s \(L^2\) norm to an \(L^p\) norm and taking the limit as \(p \rightarrow \infty\), we obtain a simple and stable update rule.

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\) represent the arguments b1, b2 and eps respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.

The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \max(\left| g_t \right| + \varepsilon, \beta_2 \cdot v_{t-1}) \\ \hat{m}_t &\leftarrow m_t / (1-\beta_1^t) \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / v_t \\ S_t &\leftarrow (m_t, v_t). \end{align*}\]

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamax(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

References

Kingma et al, 2014: https://arxiv.org/abs/1412.6980

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the maximum of past gradients.

  • eps (float) – A small constant applied to denominator to avoid dividing by zero when rescaling.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

AdamaxW#

optax.adamaxw(learning_rate, b1=0.9, b2=0.999, eps=1e-08, weight_decay=0.0001, mask=None)[source]#

Adamax with weight decay regularization.

AdamaxW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also implement this with L2 regularization as an additive loss term; however, L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam.

Warning

Sometimes you may want to skip weight decay for BatchNorm scale or for the bias parameters. You can use optax.masked to make your own AdamaxW variant where additive_weight_decay is applied only to a subset of params.
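
For instance (a minimal sketch with a hypothetical params dict holding a kernel 'w' and a bias 'b'), passing a boolean pytree as mask applies the weight decay to 'w' only:

>>> import optax
>>> import jax.numpy as jnp
>>> params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
>>> decay_mask = {'w': True, 'b': False}  # skip weight decay for the bias
>>> solver = optax.adamaxw(learning_rate=0.003, weight_decay=1e-4, mask=decay_mask)
>>> opt_state = solver.init(params)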

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamaxw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

References

Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101

Parameters:
  • learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the maximum of past gradients.

  • eps (float) – A small constant applied to denominator to avoid dividing by zero when rescaling.

  • weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adamax gradient transformations are applied to all parameters.

Return type:

optax.GradientTransformation

Returns:

The corresponding GradientTransformation.

AdamW#

optax.adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None, weight_decay=0.0001, mask=None, *, nesterov=False)[source]#

Adam with weight decay regularization.

AdamW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also implement this with L2 regularization as an additive loss term; however, L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam, see [Loshchilov et al, 2019].
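
In practice, weight decay is often applied only to weight matrices and skipped for biases or normalization parameters. A minimal sketch (with a hypothetical params dict) passes a callable mask that selects every parameter of rank two or higher:

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}
>>> decay_mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim > 1, p)  # True only for matrices
>>> solver = optax.adamw(learning_rate=0.003, weight_decay=1e-4, mask=decay_mask_fn)
>>> opt_state = solver.init(params)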

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function. Let \(\lambda\) be the weight decay and \(\theta_t\) the parameter vector at time \(t\).

The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\), the optimizer state \(S_t\) and the parameters \(\theta_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \left( \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right) + \lambda \theta_{t} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}\]

This implementation can incorporate Nesterov-style momentum, as introduced by [Dozat 2016]. The resulting optimizer is then often referred to as NAdamW. With the keyword argument nesterov=True, the optimizer uses Nesterov momentum, replacing the above \(\hat{m}_t\) with

\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}. \]

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

References

Loshchilov et al, Decoupled Weight Decay Regularization, 2019

Dozat, Incorporating Nesterov Momentum into Adam, 2016

Parameters:
  • learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

  • nesterov (bool) – Whether to use Nesterov momentum. The solver with nesterov=True is equivalent to the optax.nadamw() optimizer. This modification is described in [Dozat 2016].

Return type:

optax.GradientTransformation

Returns:

The corresponding GradientTransformation.

AMSGrad#

optax.amsgrad(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)[source]#

The AMSGrad optimiser.

The original Adam can fail to converge to the optimal solution in some cases. AMSGrad guarantees convergence by using a long-term memory of past gradients.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.amsgrad(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

References

Reddi et al, 2018: https://openreview.net/forum?id=ryQu7f-RZ

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

Fromage#

optax.fromage(learning_rate, min_norm=1e-06)[source]#

The Frobenius matched gradient descent (Fromage) optimizer.

Fromage is a learning algorithm that does not require learning rate tuning. The optimizer is based on modeling neural network gradients via deep relative trust (a distance function on deep neural networks). Fromage is similar to the LARS optimizer and can work on a range of standard neural network benchmarks, such as natural language Transformers and generative adversarial networks.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.fromage(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.37E+01
Objective function: 1.36E+01

References

Bernstein et al, 2020: https://arxiv.org/abs/2002.03432

Parameters:
  • learning_rate (float) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • min_norm (float) – A minimum value that the norm of the gradient updates and the norm of the layer parameters can be clipped to, in order to avoid dividing by zero when computing the trust ratio (as in the LARS paper).

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

Lamb#

optax.lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-06, eps_root=0.0, weight_decay=0.0, mask=None)[source]#

The LAMB optimizer.

LAMB is a general purpose layer-wise adaptive large batch optimizer designed to provide consistent training performance across a wide range of tasks, including those that use attention-based models (such as Transformers) and ResNet-50. The optimizer is able to work with small and large batch sizes. LAMB was inspired by the LARS learning algorithm.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lamb(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.36E+01

References

You et al, 2019: https://arxiv.org/abs/1904.00962

Parameters:
  • learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • weight_decay (float) – Strength of the weight decay regularization.

  • mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

Return type:

optax.GradientTransformation

Returns:

The corresponding GradientTransformation.

Lars#

optax.lars(learning_rate, weight_decay=0.0, weight_decay_mask=True, trust_coefficient=0.001, eps=0.0, trust_ratio_mask=True, momentum=0.9, nesterov=False)[source]#

The LARS optimizer.

LARS is a layer-wise adaptive optimizer introduced to help scale SGD to larger batch sizes. LARS later inspired the LAMB optimizer.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lars(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01

References

You et al, 2017: https://arxiv.org/abs/1708.03888

Parameters:
  • learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • weight_decay (float) – Strength of the weight decay regularization.

  • weight_decay_mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

  • trust_coefficient (float) – A multiplier for the trust ratio.

  • eps (float) – Optional additive constant in the trust ratio denominator.

  • trust_ratio_mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

  • momentum (float) – Decay rate for momentum.

  • nesterov (bool) – Whether to use Nesterov momentum.

Return type:

optax.GradientTransformation

Returns:

The corresponding GradientTransformation.

Lion#

optax.lion(learning_rate, b1=0.9, b2=0.99, mu_dtype=None, weight_decay=0.001, mask=None)[source]#

The Lion optimizer.

Lion was discovered by symbolic program search. Unlike most adaptive optimizers such as AdamW, Lion only tracks momentum, making it more memory-efficient. The updates of Lion are produced through the sign operation, resulting in a larger norm than updates produced by other optimizers such as SGD and AdamW. A suitable learning rate for Lion is typically 3-10x smaller than that for AdamW; the weight decay for Lion should in turn be 3-10x larger than that for AdamW to maintain a similar strength (lr * wd).
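
For example (an illustrative sketch, not a tuned recipe), starting from a hypothetical AdamW configuration with learning rate 3e-4 and weight decay 1e-2, a comparable Lion configuration might shrink the learning rate and grow the weight decay by the same factor, keeping the product lr * wd roughly constant:

>>> import optax
>>> adamw_lr, adamw_wd = 3e-4, 1e-2
>>> solver = optax.lion(learning_rate=adamw_lr / 10, weight_decay=adamw_wd * 10)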

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lion(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

References

Chen et al, 2023: https://arxiv.org/abs/2302.06675

Parameters:
  • learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Rate to combine the momentum and the current gradient.

  • b2 (float) – Exponential decay rate to track the momentum of past gradients.

  • mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Lion gradient transformations are applied to all parameters.

Return type:

optax.GradientTransformation

Returns:

The corresponding GradientTransformation.

Nadam#

optax.nadam(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None, *, nesterov: bool = True) → base.GradientTransformation#

The NAdam optimizer.

Nadam is a variant of optax.adam() with Nesterov’s momentum. The update rule of this solver is as follows:

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}\\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}\]

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.nadam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01

References

Dozat, Incorporating Nesterov Momentum into Adam, 2016

Added in version 0.1.9.

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 – Exponential decay rate to track the first moment of past gradients.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

  • mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Returns:

The corresponding GradientTransformation.

NadamW#

optax.nadamw(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None, weight_decay: float = 0.0001, mask: Any | Callable[[base.Params], Any] | None = None, *, nesterov: bool = True) → base.GradientTransformation#

NAdamW optimizer, implemented as part of the AdamW optimizer.

NadamW is a variant of optax.adamw() with Nesterov’s momentum. Compared to AdamW, this optimizer replaces the assignment

\[\hat{m}_t \leftarrow m_t / {(1-\beta_1^t)}\]

with

\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.\]

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.nadamw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01

References

Loshchilov et al, Decoupled Weight Decay Regularization, 2019

Dozat, Incorporating Nesterov Momentum into Adam, 2016

Added in version 0.1.9.

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 – Exponential decay rate to track the first moment of past gradients.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

Returns:

The corresponding GradientTransformation.

Noisy SGD#

optax.noisy_sgd(learning_rate, eta=0.01, gamma=0.55, seed=0)[source]#

A variant of SGD with added noise.

Noisy SGD is a variant of optax.sgd() that incorporates Gaussian noise into the updates. It has been found that adding noise to the gradients can improve both the training error and the generalization error in very deep networks.

The update \(u_t\) is modified to include this noise as follows:

\[u_t \leftarrow -\alpha_t (g_t + N(0, \sigma_t^2)), \]

where \(N(0, \sigma_t^2)\) represents Gaussian noise with zero mean and a variance of \(\sigma_t^2\).

The variance of this noise decays over time according to the formula

\[\sigma_t^2 = \frac{\eta}{(1+t)^\gamma}, \]

where \(\gamma\) is the decay rate parameter gamma and \(\eta\) represents the initial variance eta.
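
With the default eta=0.01 and gamma=0.55, the first few noise variances are (a standalone numerical sketch):

>>> eta, gamma = 0.01, 0.55
>>> [round(eta / (1 + t) ** gamma, 4) for t in range(3)]
[0.01, 0.0068, 0.0055]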

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.noisy_sgd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01

References

Neelakantan et al, 2014: https://arxiv.org/abs/1511.06807

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • eta (float) – Initial variance for the Gaussian noise added to gradients.

  • gamma (float) – A parameter controlling the annealing of noise over time t; the variance decays proportionally to (1+t)**(-gamma).

  • seed (int) – Seed for the pseudo-random generation process.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

Novograd#

optax.novograd(learning_rate, b1=0.9, b2=0.25, eps=1e-06, eps_root=0.0, weight_decay=0.0)[source]#

NovoGrad optimizer.

NovoGrad is more robust to the initial learning rate and weight initialization than other methods. For example, NovoGrad works well without LR warm-up, while other methods require it. NovoGrad performs exceptionally well for large batch training, e.g. it outperforms other methods for ResNet-50 for all batch sizes up to 32K. In addition, NovoGrad requires half the memory compared to Adam. It was introduced together with the Jasper ASR model.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.novograd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.37E+01

References

Ginsburg et al, 2019: https://arxiv.org/abs/1905.11286

Li et al, 2019: https://arxiv.org/abs/1904.03288

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – An exponential decay rate to track the first moment of past gradients.

  • b2 (float) – An exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • weight_decay (float) – Strength of the weight decay regularization.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

Optimistic GD#

optax.optimistic_gradient_descent(learning_rate, alpha=1.0, beta=1.0)[source]#

An Optimistic Gradient Descent optimizer.

Optimistic gradient descent is an approximation of extra-gradient methods which require multiple gradient calls to compute the next update. It has strong formal guarantees for last-iterate convergence in min-max games, for which standard gradient descent can oscillate or even diverge.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.optimistic_gradient_descent(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01
Objective function: 1.30E+01

References

Mokhtari et al, 2019: https://arxiv.org/abs/1901.08511v2

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • alpha (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – Coefficient for generalized OGD.

  • beta (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – Coefficient for generalized OGD negative momentum.

Return type:

GradientTransformation

Returns:

A GradientTransformation.

Polyak step-size SGD#

optax.polyak_sgd(max_learning_rate=1.0, scaling=1.0, f_min=0.0, eps=0.0)[source]#

SGD with Polyak step-size.

This solver implements SGD with the Polyak step size of Loizou et al. (2021). It sets the step size as

\[s \min\left\{\frac{f(x) - f^\star}{\|\nabla f(x)\|^2 + \epsilon}, \gamma_{\max}\right\}\,, \]

where \(f\) is the function from which a gradient is computed, \(\gamma_{\max}\) is a maximal acceptable learning rate set by max_learning_rate, \(\epsilon\) is a constant preventing division by zero set with eps, \(s\) scales the formula by scaling, and \(f^\star\) is a guess of the minimum value of the function set with f_min.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.polyak_sgd()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  params, opt_state = solver.update(grad, opt_state, params, value=value)
...  print('Objective function: ', f(params))
Objective function:  3.5
Objective function:  0.875
Objective function:  0.21875
Objective function:  0.0546875
Objective function:  0.013671875

Warning

This method requires knowledge of an approximate value of the objective function minimum, passed through the f_min argument. For models that interpolate the data, this can be set to 0 (the default value). Failing to set an appropriate value for f_min can lead to divergence or convergence to a suboptimal solution.
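
For example (a minimal sketch), for an objective whose minimum value is known to be 1 rather than 0, f_min should be set accordingly:

>>> import optax
>>> import jax.numpy as jnp
>>> f = lambda x: jnp.sum(x ** 2) + 1.  # minimum value is 1, attained at x = 0
>>> solver = optax.polyak_sgd(f_min=1.0)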

References

Loizou et al. Stochastic polyak step-size for SGD: An adaptive learning rate for fast convergence, 2021

Berrada et al., Training neural networks for and by interpolation, 2020

Parameters:
  • max_learning_rate (float) – a maximum step size to use (defaults to 1).

  • scaling (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler (defaults to 1).

  • f_min (float) – a lower bound on the objective function (defaults to 0). Corresponds to \(f^\star\) in the formula above.

  • eps (float) – a value to add in the denominator of the update (defaults to 0).

Return type:

GradientTransformationExtraArgs

Returns:

A GradientTransformationExtraArgs, where the update function takes an additional keyword argument value containing the current value of the objective function.

RAdam#

optax.radam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, threshold=5.0, *, nesterov=False)[source]#

The Rectified Adam optimizer.

The adaptive learning rate in Adam has undesirably large variance in early stages of training, due to the limited number of training samples used to estimate the optimizer’s statistics. Rectified Adam addresses this issue by analytically reducing the large variance.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.radam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01

References

Liu et al, 2020: https://arxiv.org/abs/1908.03265

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • threshold (float) – Threshold for variance tractability.

  • nesterov (bool) – Whether to use Nesterov momentum.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

RMSProp#

optax.rmsprop(learning_rate, decay=0.9, eps=1e-08, initial_scale=0.0, centered=False, momentum=None, nesterov=False)[source]#

A flexible RMSProp optimizer.

RMSProp is an SGD variant with learning rate adaptation. The learning_rate used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy to configure RMSProp optimizer that can be used to switch between several of these variants.
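
For instance (a minimal sketch), the centered variant with momentum, similar to the formulation used by Graves (2013), is obtained by setting the corresponding flags:

>>> import optax
>>> solver = optax.rmsprop(learning_rate=0.003, centered=True, momentum=0.9)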

Warning

PyTorch and optax’s RMSprop implementations differ and could impact performance. In the denominator, optax uses \(\sqrt{v + \epsilon}\) whereas PyTorch uses \(\sqrt{v} + \epsilon\). See google-deepmind/optax#532 for more detail.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.rmsprop(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.37E+01
Objective function: 1.36E+01

References

Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf

Graves, 2013: https://arxiv.org/abs/1308.0850

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • decay (float) – Decay used to track the magnitude of previous gradients.

  • eps (float) – A small numerical constant to avoid dividing by zero when rescaling.

  • initial_scale (float) – Initial value of accumulators tracking the magnitude of previous updates. PyTorch uses 0, TF1 uses 1. When reproducing results from a paper, verify the value used by the authors.

  • centered (bool) – If True, the variance (centered second moment) of the past gradients is used to rescale the latest gradients; if False, the uncentered second moment is used.

  • momentum (Optional[float]) – Decay rate used by the momentum term; if set to None, momentum is not used at all.

  • nesterov (bool) – Whether Nesterov momentum is used.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

RProp#

optax.rprop(learning_rate, eta_minus=0.5, eta_plus=1.2, min_step_size=1e-06, max_step_size=50.0)[source]#

The Rprop optimizer.

Rprop, short for resilient backpropagation, is a first-order variant of gradient descent. It responds only to the sign of the gradient, increasing or decreasing a per-parameter step size multiplicatively to speed up convergence and avoid oscillations.
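
In rough terms (a sketch of the sign-based rule implied by the parameters below; the variant actually implemented may handle sign changes slightly differently), each coordinate \(i\) keeps its own step size \(\sigma^{(i)}_t\), grown by \(\eta_+\) (eta_plus) while the gradient keeps its sign, shrunk by \(\eta_-\) (eta_minus) when the sign flips, and clipped to \([\sigma_{\min}, \sigma_{\max}]\) (min_step_size, max_step_size):

\[\begin{align*} \sigma_t^{(i)} &\leftarrow \begin{cases} \min(\eta_+ \cdot \sigma_{t-1}^{(i)}, \sigma_{\max}) & \text{ if } g_t^{(i)} \cdot g_{t-1}^{(i)} > 0 \\ \max(\eta_- \cdot \sigma_{t-1}^{(i)}, \sigma_{\min}) & \text{ if } g_t^{(i)} \cdot g_{t-1}^{(i)} < 0 \\ \sigma_{t-1}^{(i)} & \text{ otherwise} \end{cases} \\ u_t^{(i)} &\leftarrow -\operatorname{sign}(g_t^{(i)}) \cdot \sigma_t^{(i)}. \end{align*}\]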

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.rprop(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01

References

PyTorch implementation: https://pytorch.org/docs/stable/generated/torch.optim.Rprop.html

Riedmiller and Braun, 1993: https://ieeexplore.ieee.org/document/298623

Igel and HĂĽsken, 2003

Parameters:
  • learning_rate (float) – The initial step size.

  • eta_minus (float) – Multiplicative factor for decreasing step size. This is applied when the gradient changes sign from one step to the next.

  • eta_plus (float) – Multiplicative factor for increasing step size. This is applied when the gradient has the same sign from one step to the next.

  • min_step_size (float) – Minimum allowed step size. Smaller steps will be clipped to this value.

  • max_step_size (float) – Maximum allowed step size. Larger steps will be clipped to this value.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

SGD#

optax.sgd(learning_rate, momentum=None, nesterov=False, accumulator_dtype=None)[source]#

A canonical Stochastic Gradient Descent optimizer.

This implements stochastic gradient descent. It also includes support for momentum and Nesterov acceleration, as these are standard practice when using stochastic gradient descent to train deep neural networks.

The canonical stochastic gradient descent returns an update \(u_t\) of the form

\[u_t \leftarrow -\alpha_t g_t, \]

where \(g_t\) is the gradient of the objective (potentially preprocessed by other transformations) and \(\alpha_t\) is the learning_rate at time \(t\) (constant or selected by an optax.Schedule).

Stochastic gradient descent with momentum takes two possible forms.

\[\begin{align*} m_t &\leftarrow g_t + \mu m_{t-1} \\ u_t &\leftarrow \begin{cases} -\alpha_t m_t & \text{ if } \texttt{nesterov = False} \\ -\alpha_t (g_t + \mu m_t) & \text{ if } \texttt{nesterov = True} \end{cases} \\ S_t &\leftarrow m_t, \end{align*}\]

where \(\mu\) is the momentum parameter and \(S_t\) is the state of the optimizer.
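
For example, the two forms above correspond directly to the momentum and nesterov arguments; the learning rate and momentum values below are illustrative only:

>>> import optax
>>> import jax.numpy as jnp
>>> params = jnp.array([1., 2., 3.])
>>> heavy_ball_sgd = optax.sgd(learning_rate=0.003, momentum=0.9)  # nesterov = False branch
>>> nesterov_sgd = optax.sgd(learning_rate=0.003, momentum=0.9, nesterov=True)  # nesterov = True branch
>>> opt_state = heavy_ball_sgd.init(params)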

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.sgd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01

References

Sutskever et al, On the importance of initialization and momentum in deep learning, 2013

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • momentum (Optional[float]) – Decay rate used by the momentum term; if set to None, momentum is not used at all.

  • nesterov (bool) – Whether Nesterov momentum is used.

  • accumulator_dtype (Optional[Any]) – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

SM3#

optax.sm3(learning_rate, momentum=0.9)[source]#

The SM3 optimizer.

SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a memory-efficient adaptive optimizer designed to decrease memory overhead when training very large models, such as the Transformer for machine translation, BERT for language modeling, and AmoebaNet-D for image classification. SM3: 1) applies to tensors of arbitrary dimensions and any predefined cover of the parameters; 2) adapts the learning rates in a data-driven manner (like Adagrad and unlike Adafactor); and 3) comes with rigorous convergence guarantees in stochastic convex optimization settings.
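
As a rough back-of-the-envelope illustration of the memory saving (plain arithmetic only, not optax’s actual state layout), covering an \(m \times n\) weight matrix with one accumulator per row and one per column needs \(m + n\) statistics instead of the \(m \cdot n\) kept by a per-entry method such as Adagrad:

>>> m, n = 1024, 4096
>>> per_entry_stats = m * n  # one second-moment statistic per parameter
>>> sm3_cover_stats = m + n  # one statistic per row plus one per column
>>> per_entry_stats, sm3_cover_stats
(4194304, 5120)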

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.sm3(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01

References

Anil et al, 2019: https://arxiv.org/abs/1901.11150

Parameters:
  • learning_rate (float) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • momentum (float) – Decay rate used by the momentum term.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.

Yogi#

optax.yogi(learning_rate, b1=0.9, b2=0.999, eps=0.001)[source]#

The Yogi optimizer.

Yogi is an adaptive optimizer, which provides control in tuning the effective learning rate to prevent it from increasing. By doing so, it focuses on addressing the issues of convergence and generalization in exponential moving average-based adaptive methods (such as Adam and RMSprop). Yogi is a modification of Adam and uses the same parameters.
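
Concretely (a sketch following the reference below, with \(\beta_2\), \(g_t\) and \(\nu_t\) denoting b2, the gradient and the second-moment estimate), Yogi replaces the Adam-style update \(\nu_t \leftarrow \beta_2 \cdot \nu_{t-1} + (1-\beta_2) \cdot g_t^2\) with a sign-controlled, additive one,

\[\nu_t \leftarrow \nu_{t-1} - (1-\beta_2) \cdot \operatorname{sign}(\nu_{t-1} - g_t^2) \cdot g_t^2, \]

so \(\nu_t\) changes by at most \((1-\beta_2) \cdot g_t^2\) in magnitude per step, which keeps the effective learning rate from growing too quickly when gradients become small.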

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.yogi(learning_rate=0.002)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01

References

Zaheer et al, 2018: https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

Return type:

GradientTransformation

Returns:

The corresponding GradientTransformation.