Optimizers

optax.adabelief – The AdaBelief optimizer.
optax.adadelta – The Adadelta optimizer.
optax.adafactor – The Adafactor optimizer.
optax.adagrad – The Adagrad optimizer.
optax.adam – The Adam optimizer.
optax.adamw – Adam with weight decay regularization.
optax.adamax – A variant of the Adam optimizer that uses the infinity norm.
optax.adamaxw – Adamax with weight decay regularization.
optax.amsgrad – The AMSGrad optimiser.
optax.fromage – The Frobenius matched gradient descent (Fromage) optimizer.
optax.lamb – The LAMB optimizer.
optax.lars – The LARS optimizer.
optax.lion – The Lion optimizer.
optax.nadam – The NAdam optimizer.
optax.nadamw – NAdamW optimizer, implemented as part of the AdamW optimizer.
optax.noisy_sgd – A variant of SGD with added noise.
optax.novograd – NovoGrad optimizer.
optax.optimistic_gradient_descent – An Optimistic Gradient Descent optimizer.
optax.polyak_sgd – SGD with Polyak step-size.
optax.radam – The Rectified Adam optimizer.
optax.rmsprop – A flexible RMSProp optimizer.
optax.sgd – A canonical Stochastic Gradient Descent optimizer.
optax.sm3 – The SM3 optimizer.
optax.yogi – The Yogi optimizer.
AdaBelief
- optax.adabelief(learning_rate, b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16)
The AdaBelief optimizer.
AdaBelief is an adaptive learning rate optimizer that focuses on fast convergence, generalization, and stability. It adapts the step size depending on its “belief” in the gradient direction: the optimizer adaptively scales the step size by the difference between the predicted and observed gradients. AdaBelief is a modified version of optax.adam() and contains the same number of parameters.
Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.
The init function of this optimizer initializes an internal state \(S_0 := (m_0, s_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and the previous optimizer state \(S_{t-1}\) and computes the updates \(u_t\) and the new state \(S_t\). Thus, for \(t > 0\), we have,
\[\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
s_t &\leftarrow \beta_2 \cdot s_{t-1} + (1-\beta_2) \cdot (g_t - m_t)^2 + \bar{\varepsilon} \\
\hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
\hat{s}_t &\leftarrow s_t / {(1-\beta_2^t)} \\
u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left(\sqrt{\hat{s}_t} + \varepsilon \right) \\
S_t &\leftarrow (m_t, s_t).
\end{align*}\]
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adabelief(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
References
Zhuang et al, 2020: https://arxiv.org/abs/2010.07468
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – Term added to the denominator to improve numerical stability.
eps_root (float) – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
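To make the AdaBelief update equations concrete, here is a minimal plain-Python sketch for a single scalar parameter. It is a direct transcription of the equations in this entry, not Optax's implementation (which applies the same arithmetic elementwise over pytrees), and the helper name `adabelief_step` is hypothetical.

```python
import math


# Hypothetical helper: a scalar transcription of the AdaBelief equations
# above, not Optax's pytree-based implementation.
def adabelief_step(g, m, s, t, lr=0.003, b1=0.9, b2=0.999,
                   eps=1e-16, eps_root=1e-16):
    """Returns (u_t, m_t, s_t) given g_t, m_{t-1}, s_{t-1} and t >= 1."""
    m = b1 * m + (1 - b1) * g                          # first moment m_t
    s = b2 * s + (1 - b2) * (g - m) ** 2 + eps_root    # belief term s_t
    m_hat = m / (1 - b1 ** t)                          # bias corrections
    s_hat = s / (1 - b2 ** t)
    u = -lr * m_hat / (math.sqrt(s_hat) + eps)
    return u, m, s


# Feed a constant gradient and record the step sizes in units of lr.
m = s = 0.0
steps = []
for t in range(1, 6):
    u, m, s = adabelief_step(1.0, m, s, t)
    steps.append(abs(u) / 0.003)
print([round(x, 2) for x in steps])  # [1.11, 1.17, 1.23, 1.28, 1.34]
```

With a gradient that never changes, the prediction error g_t - m_t shrinks, so s_t stays small and the steps grow, whereas Adam's bias-corrected steps would stay at about the learning rate for the same input. This is the "belief" behaviour described above.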
AdaDelta
- optax.adadelta(learning_rate=None, rho=0.9, eps=1e-06, weight_decay=0.0, weight_decay_mask=None)
The Adadelta optimizer.
Adadelta is a stochastic gradient descent method that adapts learning rates based on a moving window of gradient updates. Adadelta is a modification of Adagrad.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> f = lambda x: jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adadelta(learning_rate=10.)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.36E+01
Objective function: 1.32E+01
Objective function: 1.29E+01
Objective function: 1.25E+01
Objective function: 1.21E+01
References
Zeiler, 2012: https://arxiv.org/pdf/1212.5701.pdf
- Parameters:
learning_rate (Optional[base.ScalarOrSchedule]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
rho (float) – A coefficient used for computing a running average of squared gradients.
eps (float) – Term added to the denominator to improve numerical stability.
weight_decay (float) – Optional rate at which to decay weights.
weight_decay_mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
- Return type:
optax.GradientTransformation
- Returns:
The corresponding GradientTransformation.
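The description above does not spell out the Adadelta update. Below is a minimal plain-Python sketch of the rule from Zeiler (2012) for a single scalar. Multiplying Zeiler's step by learning_rate, and the exact placement of eps, are assumptions of this sketch, and `adadelta_step` is a hypothetical helper rather than Optax code. Applied once to the quadratic from the example with learning_rate=10, it gives an objective of about 13.6, the same as the first value printed by the doctest, which suggests the assumptions are reasonable.

```python
import math


# Hypothetical helper: Zeiler's Adadelta rule for one scalar, multiplied by
# a learning rate (an assumption about how learning_rate is applied).
def adadelta_step(g, e_g, e_x, lr=10.0, rho=0.9, eps=1e-6):
    """Returns (update, new_e_g, new_e_x).

    e_g: running average of squared gradients.
    e_x: running average of squared (unscaled) steps.
    """
    e_g = rho * e_g + (1 - rho) * g ** 2
    # RMS of past steps over RMS of gradients sets a per-parameter step size.
    delta = math.sqrt(e_x + eps) / math.sqrt(e_g + eps) * g
    e_x = rho * e_x + (1 - rho) * delta ** 2
    return -lr * delta, e_g, e_x


# One step on f(x) = sum(x ** 2) from the example, starting from zero state.
params = [1.0, 2.0, 3.0]
new_params = [x + adadelta_step(2 * x, 0.0, 0.0)[0] for x in params]
print('{:.2E}'.format(sum(x ** 2 for x in new_params)))  # 1.36E+01
```

Because e_x starts at zero, the first steps are on the order of sqrt(eps), which is why the example needs a learning rate as large as 10 to make visible progress.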
AdaGrad
- optax.adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-07)
The Adagrad optimizer.
Adagrad is an algorithm for gradient-based optimization that anneals the learning rate for each parameter during the course of training.
Warning
Adagrad’s main limitation is the monotonic accumulation of squared gradients in the denominator: since every term is non-negative, the sum keeps growing during training and the effective learning rate eventually becomes vanishingly small.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adagrad(learning_rate=1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 5.01E+00
Objective function: 2.40E+00
Objective function: 1.25E+00
Objective function: 6.86E-01
Objective function: 3.85E-01
References
Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
initial_accumulator_value (float) – Initial value for the accumulator.
eps (float) – A small constant applied to the denominator inside the square root (as in RMSProp) to avoid dividing by zero when rescaling.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
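The Warning above can be seen in a few lines. The plain-Python sketch below accumulates squared gradients starting from initial_accumulator_value and divides each gradient by the square root of that sum plus eps, as the parameter descriptions suggest. It assumes this is the whole rule (no extra terms), and `adagrad_step` is a hypothetical helper, not part of Optax.

```python
import math


# Hypothetical helper: a scalar sketch of the Adagrad rule, not optax code.
def adagrad_step(g, acc, lr=1.0, eps=1e-7):
    """Returns (update, new_accumulator) for one scalar parameter."""
    acc = acc + g ** 2                 # the sum of squares only ever grows
    return -lr * g / math.sqrt(acc + eps), acc


# With a constant unit gradient, the effective step lr / sqrt(acc + eps)
# shrinks at every iteration.
acc = 0.1                              # initial_accumulator_value
steps = []
for _ in range(5):
    u, acc = adagrad_step(1.0, acc)
    steps.append(abs(u))
print([round(x, 3) for x in steps])    # [0.953, 0.69, 0.568, 0.494, 0.443]
```

Applied once to the quadratic from the example (gradient 2x, accumulator 0.1, learning rate 1.0), the same helper gives an objective of about 5.01, consistent with the first value printed by the doctest.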
AdaFactor
- optax.adafactor(learning_rate=None, min_dim_size_to_factor=128, decay_rate=0.8, decay_offset=0, multiply_by_parameter_scale=True, clipping_threshold=1.0, momentum=None, dtype_momentum=<class 'jax.numpy.float32'>, weight_decay_rate=None, eps=1e-30, factored=True, weight_decay_mask=None)
The Adafactor optimizer.
Adafactor is an adaptive learning rate optimizer that focuses on fast training of large-scale neural networks. It saves memory by using a factored estimate of the second-order moments used to scale gradients.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adafactor(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.36E+01
References
Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235
- Parameters:
learning_rate (Optional[base.ScalarOrSchedule]) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate(). Note that the natural scale for Adafactor’s learning rate is markedly different from Adam’s; with attention-based models, one does not use the 1/sqrt(hidden) correction for this optimizer.
min_dim_size_to_factor (int) – Only factor the statistics if two array dimensions have at least this size.
decay_rate (float) – Controls the second-moment exponential decay schedule.
decay_offset (int) – For fine-tuning, one may set this to the starting step number of the fine-tuning phase.
multiply_by_parameter_scale (float) – If True, scale learning_rate by the parameter norm. If False, the provided learning_rate is the absolute step size.
clipping_threshold (Optional[float]) – Optional clipping threshold. Must be >= 1. If None, clipping is disabled.
momentum (Optional[float]) – Optional value between 0 and 1. If not None, momentum is enabled, which uses extra memory. Defaults to None.
dtype_momentum (Any) – Data type of momentum buffers.
weight_decay_rate (Optional[float]) – Optional rate at which to decay weights.
eps (float) – Regularization constant for root mean squared gradient.
factored (bool) – Whether to use factored second-moment estimates.
weight_decay_mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
- Return type:
optax.GradientTransformation
- Returns:
The corresponding GradientTransformation.
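The memory saving comes from the factored estimate of Shazeer and Stern (2018): for a matrix-shaped parameter, only the row sums R and column sums C of the second-moment matrix V are stored, and V is approximated by the outer product R C^T divided by the sum of all entries. The plain-Python sketch below illustrates that reconstruction alone; it leaves out decay, clipping and relative step sizes, so it is not Optax's implementation, and `factored_estimate` is a hypothetical helper.

```python
# Hypothetical helper: rank-1 reconstruction of a non-negative matrix from
# its row and column sums (the factored second-moment idea), not optax code.
def factored_estimate(v):
    """Stores len(v) + len(v[0]) numbers instead of len(v) * len(v[0])."""
    row_sums = [sum(row) for row in v]          # R, one entry per row
    col_sums = [sum(col) for col in zip(*v)]    # C, one entry per column
    total = sum(row_sums)                       # sum of all entries
    return [[r * c / total for c in col_sums] for r in row_sums]


# Squared gradients that happen to be exactly rank 1 are recovered exactly.
v = [[1.0, 4.0, 9.0],
     [2.0, 8.0, 18.0]]
print(factored_estimate(v))  # [[1.0, 4.0, 9.0], [2.0, 8.0, 18.0]]
```

For a general non-negative matrix the estimate is only approximate, but it preserves every row sum and column sum. The min_dim_size_to_factor parameter above controls when Optax applies factoring at all.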
Adam
- optax.adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None, *, nesterov=False)
The Adam optimizer.
Adam is an SGD variant with gradient scaling adaptation. The scaling used for each parameter is computed from estimates of first and second-order moments of the gradients (using suitable exponential moving averages).
Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.
The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and the previous optimizer state \(S_{t-1}\) and computes the updates \(u_t\) and the new state \(S_t\). Thus, for \(t > 0\), we have,
\[\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
\hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
\hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\
S_t &\leftarrow (m_t, v_t).
\end{align*}\]
With the keyword argument nesterov=True, the optimizer uses Nesterov momentum, replacing the above \(\hat{m}_t\) with
\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}. \]
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References
Kingma et al, Adam: A Method for Stochastic Optimization, 2014
Dozat, Incorporating Nesterov Momentum into Adam, 2016
Warning
The PyTorch and optax implementations follow Algorithm 1 of [Kingma et al. 2014]. TensorFlow instead used the formulation given just before Section 2.1 of the paper. See deepmind/optax#571 for more detail.
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed, for example, when computing (meta-)gradients through Adam.
mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
nesterov (bool) – Whether to use Nesterov momentum. The solver with nesterov=True is equivalent to the optax.nadam() optimizer, described in [Dozat 2016].
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
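The Adam equations, including the nesterov=True variant, can be traced for a single scalar with the plain-Python sketch below. It is a transcription of the equations in this entry rather than Optax's pytree implementation, and `adam_step` is a hypothetical helper.

```python
import math


# Hypothetical helper: a scalar transcription of the Adam equations above
# (including the nesterov=True variant), not Optax's implementation.
def adam_step(g, m, v, t, lr=0.003, b1=0.9, b2=0.999, eps=1e-8,
              eps_root=0.0, nesterov=False):
    """Returns (u_t, m_t, v_t) given g_t, m_{t-1}, v_{t-1} and t >= 1."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g ** 2
    if nesterov:
        # Look-ahead estimate using the next step's bias correction.
        m_hat = b1 * m / (1 - b1 ** (t + 1)) + (1 - b1) * g / (1 - b1 ** t)
    else:
        m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    u = -lr * m_hat / (math.sqrt(v_hat + eps_root) + eps)
    return u, m, v


# Thanks to bias correction, the first step is about -lr * sign(g).
print(round(adam_step(2.0, 0.0, 0.0, 1)[0], 6))  # -0.003
```

Applied to the example quadratic, one step with learning_rate=0.003 moves every coordinate by about 0.003 toward zero, giving an objective of about 13.96, which prints as 1.40E+01 in the doctest.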
Adamax
- optax.adamax(learning_rate, b1=0.9, b2=0.999, eps=1e-08)
A variant of the Adam optimizer that uses the infinity norm.
AdaMax is a variant of the optax.adam() optimizer. By generalizing Adam’s \(L^2\) norm to an \(L^p\) norm and taking the limit as \(p \rightarrow \infty\), we obtain a simple and stable update rule.
Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\) represent the arguments b1, b2 and eps respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.
The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and the previous optimizer state \(S_{t-1}\) and computes the updates \(u_t\) and the new state \(S_t\). Thus, for \(t > 0\), we have,
\[\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
v_t &\leftarrow \max(\left| g_t \right| + \varepsilon, \beta_2 \cdot v_{t-1}) \\
\hat{m}_t &\leftarrow m_t / (1-\beta_1^t) \\
u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / v_t \\
S_t &\leftarrow (m_t, v_t).
\end{align*}\]
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamax(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References
Kingma et al, 2014: https://arxiv.org/abs/1412.6980
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the maximum of past gradients.
eps (float) – A small constant applied to the denominator to avoid dividing by zero when rescaling.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
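The Adamax equations can be traced for a single scalar with the plain-Python sketch below. It transcribes the equations in this entry and is not Optax's implementation; `adamax_step` is a hypothetical helper. Note that only the first moment is bias-corrected, since a decayed maximum is not biased toward zero the way an average initialized at zero is.

```python
# Hypothetical helper: a scalar transcription of the Adamax equations above,
# not Optax's implementation.
def adamax_step(g, m, v, t, lr=0.003, b1=0.9, b2=0.999, eps=1e-8):
    """Returns (u_t, m_t, v_t) given g_t, m_{t-1}, v_{t-1} and t >= 1."""
    m = b1 * m + (1 - b1) * g
    v = max(abs(g) + eps, b2 * v)   # exponentially weighted infinity norm
    m_hat = m / (1 - b1 ** t)       # only the first moment is bias-corrected
    return -lr * m_hat / v, m, v


# A single large gradient followed by small ones.
m = v = 0.0
for t, g in enumerate([10.0, 1.0, 1.0], start=1):
    u, m, v = adamax_step(g, m, v, t)
print(round(v, 4))  # 9.98: the early spike still dominates the denominator
```

Because v_t is a decayed maximum rather than an average, one large gradient keeps the steps small for many iterations afterwards.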
AdamaxW
- optax.adamaxw(learning_rate, b1=0.9, b2=0.999, eps=1e-08, weight_decay=0.0001, mask=None)
Adamax with weight decay regularization.
AdamaxW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization to implement this as an additive loss term; however, L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam.
Warning
You may want to skip weight decay for BatchNorm scale parameters or for bias parameters. You can use optax.masked to make your own AdamaxW variant where additive_weight_decay is applied only to a subset of params, or use the mask parameter described below.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamaxw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References
Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the maximum of past gradients.
eps (float) – A small constant applied to the denominator to avoid dividing by zero when rescaling.
weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adamax gradient transformations are applied to all parameters.
- Return type:
optax.GradientTransformation
- Returns:
The corresponding GradientTransformation.
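The remark that L2 regularization does not behave as intended for adaptive methods can be checked by hand. The plain-Python sketch below compares, at the first Adamax step (where the update direction reduces to g / (|g| + eps)), an L2 penalty folded into the gradient with a decay term added after the adaptive scaling. The helper name and the numbers are hypothetical and chosen only for illustration.

```python
# Hypothetical helper: the Adamax direction at t = 1, where m_hat = g and
# v = |g| + eps (from the Adamax update equations).
def first_adamax_direction(g, eps=1e-8):
    return g / (abs(g) + eps)


theta, wd, lr = 5.0, 0.1, 0.003   # illustrative values only

for g in (0.01, 10.0):
    # L2 regularization: the penalty gradient wd * theta passes through
    # the adaptive normalization together with g.
    l2 = -lr * first_adamax_direction(g + wd * theta)
    # Decoupled weight decay: the decay term is added after normalization.
    decoupled = -lr * (first_adamax_direction(g) + wd * theta)
    print(g, round(l2, 6), round(decoupled, 6))
# 0.01 -0.003 -0.0045
# 10.0 -0.003 -0.0045
```

For both gradient scales, the L2 version moves theta by essentially -lr, the same as with no regularization, because the normalization absorbs the penalty. The decoupled version always adds an extra shrinkage of exactly lr * wd * theta = 0.0015.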
AdamW
- optax.adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None, weight_decay=0.0001, mask=None, *, nesterov=False)
Adam with weight decay regularization.
AdamW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization to implement this as an additive loss term; however, L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam (see [Loshchilov et al, 2019]).
Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function. Let \(\lambda\) be the weight decay and \(\theta_t\) the parameter vector at time \(t\).
The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\), the previous optimizer state \(S_{t-1}\) and the parameters \(\theta_t\), and computes the updates \(u_t\) and the new state \(S_t\). Thus, for \(t > 0\), we have,
\[\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\
\hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\
\hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\
u_t &\leftarrow -\alpha_t \cdot \left( \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right) + \lambda \theta_{t} \right)\\
S_t &\leftarrow (m_t, v_t).
\end{align*}\]
This implementation can incorporate Nesterov-style momentum, as introduced by [Dozat 2016]; the resulting optimizer is often referred to as NAdamW. With the keyword argument nesterov=True, the optimizer uses Nesterov momentum, replacing the above \(\hat{m}_t\) with
\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}. \]
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.adamw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References
Loshchilov et al, Decoupled Weight Decay Regularization, 2019
Dozat, Incorporating Nesterov Momentum into Adam, 2016
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed, for instance, when computing (meta-)gradients through Adam.
mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.
nesterov (bool) – Whether to use Nesterov momentum. The solver with nesterov=True is equivalent to the optax.nadamw() optimizer. This modification is described in [Dozat 2016].
- Return type:
optax.GradientTransformation
- Returns:
The corresponding GradientTransformation.
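The mask argument decides which leaves receive the \(\lambda \theta_t\) term in the update rule above. The plain-Python sketch below mimics that for a dict of scalars in which a bias entry is excluded from decay. The dict keys, the helper `adamw_step` and its `decay` flag are all hypothetical; with Optax itself one would pass a pytree of booleans, or a callable returning one, as mask.

```python
import math


# Hypothetical helper: one scalar AdamW step following the update equations
# above (eps_root = 0); `decay` plays the role of one leaf of `mask`.
def adamw_step(g, theta, m, v, t, lr=0.003, b1=0.9, b2=0.999, eps=1e-8,
               weight_decay=1e-4, decay=True):
    """Returns (u_t, m_t, v_t)."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    direction = m_hat / (math.sqrt(v_hat) + eps)
    if decay:
        direction += weight_decay * theta   # the lambda * theta_t term
    return -lr * direction, m, v


# Hypothetical parameters: decay the weight 'w' but not the bias 'b'.
params = {'w': 2.0, 'b': 2.0}
grads = {'w': 0.5, 'b': 0.5}
mask = {'w': True, 'b': False}
updates = {k: adamw_step(grads[k], params[k], 0.0, 0.0, 1, decay=mask[k])[0]
           for k in params}
# The only difference is the decoupled decay, -lr * weight_decay * theta.
print(round(updates['w'] - updates['b'], 10))  # -6e-07
```

Since the decay term is added after the adaptive scaling, masking it out leaves the Adam part of the update untouched, matching the note that the Adam gradient transformations are applied to all parameters.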
AMSGrad
- optax.amsgrad(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)
The AMSGrad optimiser.
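The original Adam can fail to converge to the optimal solution in some cases. AMSGrad addresses this with a long-term memory of past gradients: it normalizes updates by the running maximum of past second-moment estimates instead of the current estimate, and Reddi et al. prove that this restores convergence in the online convex setting they analyze.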
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.amsgrad(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References
Reddi et al, 2018: https://openreview.net/forum?id=ryQu7f-RZ
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed, for instance, when computing (meta-)gradients through Adam.
mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
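A plain-Python sketch of the long-term memory described above, for a single scalar. It applies Adam-style bias correction and then keeps the running maximum of the bias-corrected second moment. Where the bias correction sits relative to the maximum is an assumption of this sketch, and `amsgrad_step` is a hypothetical helper, not Optax code.

```python
import math


# Hypothetical helper: a scalar AMSGrad step. Applying Adam-style bias
# correction before taking the running maximum is an assumption here.
def amsgrad_step(g, m, v, v_max, t, lr=0.003, b1=0.9, b2=0.999, eps=1e-8):
    """Returns (u_t, m_t, v_t, v_max_t)."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    v_max = max(v_max, v_hat)     # long-term memory: never decreases
    return -lr * m_hat / (math.sqrt(v_max) + eps), m, v, v_max


# One large gradient, then fifty small ones.
grads = [10.0] + [0.1] * 50
m = v = v_max = 0.0
for t, g in enumerate(grads, start=1):
    u, m, v, v_max = amsgrad_step(g, m, v, v_max, t)
v_hat = v / (1 - 0.999 ** len(grads))   # what Adam would divide by
print(v_max > 99.0, v_hat < v_max)      # True True
```

After the early spike, Adam's bias-corrected second moment has already decayed, whereas AMSGrad's denominator still reflects the peak, so its steps stay conservative.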
Fromage
- optax.fromage(learning_rate, min_norm=1e-06)
The Frobenius matched gradient descent (Fromage) optimizer.
Fromage is a learning algorithm that does not require learning rate tuning. The optimizer is based on modeling neural network gradients via deep relative trust (a distance function on deep neural networks). Fromage is similar to the LARS optimizer and can work on a range of standard neural network benchmarks, such as natural language Transformers and generative adversarial networks.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.fromage(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.37E+01
Objective function: 1.36E+01
References
Bernstein et al, 2020: https://arxiv.org/abs/2002.03432
- Parameters:
learning_rate (float) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
min_norm (float) – A minimum value that the norm of the gradient updates and the norm of the layer parameters can be clipped to, to avoid dividing by zero when computing the trust ratio (as in the LARS paper).
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
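Bernstein et al. (2020) give the Fromage update for a layer with weights W and gradient g as W <- (W - lr * ||W|| / ||g|| * g) / sqrt(1 + lr^2), with Frobenius norms taken per layer. The plain-Python sketch below transcribes that rule for a single layer stored as a list; `fromage_step` is a hypothetical helper, and clipping both norms from below by min_norm is a simplification.

```python
import math


# Hypothetical helper: the Fromage rule for one layer stored as a list.
# Clipping both norms from below by min_norm is a simplification.
def fromage_step(w, g, lr=0.003, min_norm=1e-6):
    """Returns the new weights of the layer."""
    w_norm = max(math.sqrt(sum(x * x for x in w)), min_norm)
    g_norm = max(math.sqrt(sum(x * x for x in g)), min_norm)
    scale = lr * w_norm / g_norm              # step relative to layer size
    shrink = math.sqrt(1 + lr ** 2)           # keeps the weight norm in check
    return [(x - scale * gx) / shrink for x, gx in zip(w, g)]


# On f(x) = sum(x ** 2) the gradient is 2x, as in the example above.
w = [1.0, 2.0, 3.0]
w_new = fromage_step(w, [2 * x for x in w])
print('{:.2E}'.format(sum(x ** 2 for x in w_new)))  # 1.39E+01
```

When the gradient is orthogonal to the weights, the division by sqrt(1 + lr^2) leaves the layer norm exactly unchanged: for instance, w = [1, 0] and g = [0, 5] with lr = 0.1 give a new weight vector of norm 1.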
Lamb
- optax.lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-06, eps_root=0.0, weight_decay=0.0, mask=None)
The LAMB optimizer.
LAMB is a general purpose layer-wise adaptive large batch optimizer designed to provide consistent training performance across a wide range of tasks, including those that use attention-based models (such as Transformers) and ResNet-50. The optimizer is able to work with small and large batch sizes. LAMB was inspired by the LARS learning algorithm.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lamb(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.36E+01
References
You et al, 2019: https://arxiv.org/abs/1904.00962
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to the denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed, for instance, when computing (meta-)gradients through Adam.
weight_decay (float) – Strength of the weight decay regularization.
mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
- Return type:
optax.GradientTransformation
- Returns:
The corresponding GradientTransformation.
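LAMB rescales an Adam-style direction r, including the weight decay term, by a layer-wise trust ratio ||θ|| / ||r|| before applying the learning rate (You et al., 2019). The plain-Python sketch below shows only the first step for one layer, where the bias-corrected Adam direction reduces to g / (|g| + eps). The helper `lamb_first_step` is hypothetical and omits the moment bookkeeping needed for later steps.

```python
import math


# Hypothetical helper: the first LAMB step for one layer stored as a list.
def lamb_first_step(w, g, lr=0.003, eps=1e-6, weight_decay=0.0):
    """Returns the updates for the layer at t = 1."""
    # At t = 1, bias-corrected Adam gives m_hat = g and v_hat = g ** 2.
    r = [gx / (abs(gx) + eps) + weight_decay * x for x, gx in zip(w, g)]
    w_norm = math.sqrt(sum(x * x for x in w))
    r_norm = math.sqrt(sum(x * x for x in r))
    # Layer-wise trust ratio; fall back to 1 if either norm is zero.
    trust_ratio = w_norm / r_norm if w_norm > 0 and r_norm > 0 else 1.0
    return [-lr * trust_ratio * rx for rx in r]


w = [1.0, 2.0, 3.0]
u = lamb_first_step(w, [2 * x for x in w])
w_new = [x + ux for x, ux in zip(w, u)]
print('{:.2E}'.format(sum(x ** 2 for x in w_new)))  # 1.39E+01
```

Because the update is rescaled to have norm lr * ||θ||, each layer moves by the same relative amount regardless of the scale of its gradients, which is what makes the method usable across very different batch sizes.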
Lars
- optax.lars(learning_rate, weight_decay=0.0, weight_decay_mask=True, trust_coefficient=0.001, eps=0.0, trust_ratio_mask=True, momentum=0.9, nesterov=False)
The LARS optimizer.
LARS is a layer-wise adaptive optimizer introduced to help scale SGD to larger batch sizes. LARS later inspired the LAMB optimizer.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lars(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
References
You et al, 2017: https://arxiv.org/abs/1708.03888
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
weight_decay (float) – Strength of the weight decay regularization.
weight_decay_mask (MaskOrFn) – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
trust_coefficient (float) – A multiplier for the trust ratio.
eps (float) – Optional additive constant in the trust ratio denominator.
trust_ratio_mask (MaskOrFn) – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
momentum (float) – Decay rate for momentum.
nesterov (bool) – Whether to use Nesterov momentum.
- Return type:
optax.GradientTransformation
- Returns:
The corresponding GradientTransformation.
Lion#
- optax.lion(learning_rate, b1=0.9, b2=0.99, mu_dtype=None, weight_decay=0.001, mask=None)[source]#
The Lion optimizer.
Lion was discovered by symbolic program search. Unlike most adaptive optimizers such as AdamW, Lion only tracks momentum, making it more memory-efficient. The update of Lion is produced through the sign operation, resulting in a larger norm than the updates produced by other optimizers such as SGD and AdamW. A suitable learning rate for Lion is therefore typically 3-10x smaller than that for AdamW, and the weight decay for Lion should in turn be 3-10x larger than that for AdamW to maintain a similar strength (lr * wd).
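The sign-based update described above can be sketched in NumPy as follows. The sketch assumes the update rule from Chen et al. (2023): interpolate momentum and gradient with `b1`, take the sign, add decoupled weight decay, and only then refresh the single momentum buffer with `b2`. The helper `lion_step` is hypothetical and is not optax's code.

```python
import numpy as np

# Illustrative sketch (hypothetical helper), not optax's implementation.


def lion_step(w, g, m, learning_rate=1e-4, b1=0.9, b2=0.99, weight_decay=0.0):
    """One Lion step following Chen et al. (2023); returns (w, m)."""
    # Interpolate momentum and gradient with b1, then take the sign:
    # every coordinate of the direction is -1, 0 or +1.
    direction = np.sign(b1 * m + (1.0 - b1) * g)
    # Decoupled weight decay, scaled by the learning rate like the update.
    new_w = w - learning_rate * (direction + weight_decay * w)
    # The only optimizer state (momentum) is updated with b2.
    new_m = b2 * m + (1.0 - b2) * g
    return new_w, new_m


w = np.array([1.0, -2.0, 3.0])
g = np.array([0.5, -0.1, 0.0])
new_w, new_m = lion_step(w, g, np.zeros_like(w), learning_rate=0.1)
```

Only one buffer `m` is stored per parameter, versus two for Adam, which is where the memory saving comes from; and because the direction has unit-magnitude entries, the step size is set entirely by the learning rate.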
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.lion(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References
Chen et al, 2023: https://arxiv.org/abs/2302.06675
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – Rate to combine the momentum and the current gradient.
b2 (float) – Exponential decay rate to track the momentum of past gradients.
mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Lion gradient transformation itself is applied to all parameters.
- Return type:
optax.GradientTransformation
- Returns:
The corresponding GradientTransformation.
Nadam#
- optax.nadam(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None, *, nesterov: bool = True) → base.GradientTransformation#
The NAdam optimizer.
Nadam is a variant of optax.adam() with Nesterov’s momentum. With \(\alpha_t\), \(\beta_1\), \(\beta_2\), \(\varepsilon\) and \(\bar{\varepsilon}\) denoting learning_rate, b1, b2, eps and eps_root, the update rule of this solver is as follows:
\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}\\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}\]
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.nadam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
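A direct NumPy transcription of the NAdam update rule for a single step t ≥ 1 might look like the sketch below. It uses the convention that the returned update u_t is added to the parameters, so it carries a minus sign. The helper `nadam_step` is hypothetical and is not optax's implementation.

```python
import numpy as np

# Illustrative sketch (hypothetical helper), not optax's implementation.


def nadam_step(g, m, v, t, learning_rate=1e-3, b1=0.9, b2=0.999,
               eps=1e-8, eps_root=0.0):
    """Return (u_t, m_t, v_t) for NAdam; t counts steps starting at 1."""
    m = b1 * m + (1.0 - b1) * g
    v = b2 * v + (1.0 - b2) * g ** 2
    # Nesterov-style bias-corrected first moment: looks one step ahead.
    m_hat = b1 * m / (1.0 - b1 ** (t + 1)) + (1.0 - b1) * g / (1.0 - b1 ** t)
    v_hat = v / (1.0 - b2 ** t)
    # Update to be added to the parameters (hence the minus sign).
    u = -learning_rate * m_hat / (np.sqrt(v_hat + eps_root) + eps)
    return u, m, v


g = np.array([2.0])
u, m, v = nadam_step(g, np.zeros(1), np.zeros(1), t=1)
```

On the first step from a zero state, m̂ is 0.18 / 0.19 + 2 and v̂ is 4, so the step is about 1.47 times the learning rate.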
References
Dozat, Incorporating Nesterov Momentum into Adam, 2016
Added in version 0.1.9.
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.
mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
- Returns:
The corresponding GradientTransformation.
NadamW#
- optax.nadamw(learning_rate: base.ScalarOrSchedule, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-08, eps_root: float = 0.0, mu_dtype: Any | None = None, weight_decay: float = 0.0001, mask: Any | Callable[[base.Params], Any] | None = None, *, nesterov: bool = True) → base.GradientTransformation#
NAdamW optimizer, implemented as part of the AdamW optimizer.
NadamW is a variant of optax.adamw() with Nesterov’s momentum. Compared to AdamW, this optimizer replaces the assignment
\[\hat{m}_t \leftarrow m_t / {(1-\beta_1^t)}\]
with
\[\hat{m}_t \leftarrow \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}.\]
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.nadamw(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.38E+01
References
Loshchilov et al, Decoupled Weight Decay Regularization, 2019
Dozat, Incorporating Nesterov Momentum into Adam, 2016
Added in version 0.1.9.
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.
- Returns:
The corresponding GradientTransformation.
Noisy SGD#
- optax.noisy_sgd(learning_rate, eta=0.01, gamma=0.55, seed=0)[source]#
A variant of SGD with added noise.
Noisy SGD is a variant of optax.sgd() that incorporates Gaussian noise into the updates. Neelakantan et al. found that adding noise to the gradients can improve both the training error and the generalization error in very deep networks.
The update \(u_t\) is modified to include this noise as follows:
\[u_t \leftarrow -\alpha_t (g_t + N(0, \sigma_t^2)), \]
where \(N(0, \sigma_t^2)\) represents Gaussian noise with zero mean and a variance of \(\sigma_t^2\).
The variance of this noise decays over time according to the formula
\[\sigma_t^2 = \frac{\eta}{(1+t)^\gamma}, \]
where \(\gamma\) is the decay rate parameter gamma and \(\eta\) represents the initial variance eta.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.noisy_sgd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01
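The annealed noise schedule σ_t² = η / (1 + t)^γ can be sketched in NumPy as below, assuming t counts from 0. The helpers `noise_variance` and `noisy_sgd_step` are hypothetical; the sketch illustrates the formula and is not optax's implementation (in particular it uses NumPy's random generator, so it will not reproduce the numbers printed above).

```python
import numpy as np

# Illustrative sketch (hypothetical helpers), not optax's implementation.


def noise_variance(t, eta=0.01, gamma=0.55):
    """sigma_t^2 = eta / (1 + t)^gamma: equals eta at t = 0, then decays."""
    return eta / (1.0 + t) ** gamma


def noisy_sgd_step(w, g, t, rng, learning_rate=0.003, eta=0.01, gamma=0.55):
    """Apply u_t = -alpha_t * (g_t + N(0, sigma_t^2)) to the parameters w."""
    std = np.sqrt(noise_variance(t, eta, gamma))
    noise = rng.normal(loc=0.0, scale=std, size=np.shape(g))
    return w - learning_rate * (g + noise)


rng = np.random.default_rng(0)
w = np.array([1.0, 2.0, 3.0])
for t in range(5):
    w = noisy_sgd_step(w, 2.0 * w, t, rng)  # 2 * w is the gradient of sum(w ** 2)
```

With the default `eta` the noise standard deviation starts at 0.1, small compared with the gradient here, so the objective still decreases much as in plain SGD.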
References
Neelakantan et al, 2015: https://arxiv.org/abs/1511.06807
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
eta (float) – Initial variance for the Gaussian noise added to gradients.
gamma (float) – A parameter controlling the annealing of noise over time t; the variance decays according to (1+t)**(-gamma).
seed (int) – Seed for the pseudo-random generation process.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
Novograd#
- optax.novograd(learning_rate, b1=0.9, b2=0.25, eps=1e-06, eps_root=0.0, weight_decay=0.0)[source]#
NovoGrad optimizer.
According to Ginsburg et al. (2019), NovoGrad is more robust to the initial learning rate and weight initialization than other methods. For example, NovoGrad works well without learning-rate warm-up, while other methods require it. NovoGrad performs exceptionally well for large-batch training; for example, it outperforms other methods for ResNet-50 at all batch sizes up to 32K. In addition, NovoGrad requires half the memory of Adam, because it keeps its second moment per layer rather than per weight. It was introduced together with the Jasper ASR model.
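The memory saving comes from storing the second moment as one scalar per layer. The NumPy sketch below follows our reading of the per-layer update in Ginsburg et al. (2019): v is initialized to ‖g‖² on the first step, the gradient is normalized by √v + ε, weight decay is added, and the result is accumulated into m with `b1` (no (1 − b1) factor). The helper `novograd_layer_step` is hypothetical, and optax's implementation may differ in details such as initialization.

```python
import numpy as np

# Illustrative sketch (hypothetical helper), not optax's implementation.


def novograd_layer_step(w, g, m, v, first_step, learning_rate=0.003,
                        b1=0.9, b2=0.25, eps=1e-6, weight_decay=0.0):
    """One NovoGrad step for a single layer; returns (w, m, v).

    m has the same shape as w, but v is a single scalar for the whole layer.
    """
    grad_sq_norm = float(np.sum(g ** 2))
    if first_step:
        v = grad_sq_norm  # initialize with the first squared gradient norm
    else:
        v = b2 * v + (1.0 - b2) * grad_sq_norm
    # Normalize by the layer-wise second moment, then add weight decay.
    normalized = g / (np.sqrt(v) + eps) + weight_decay * w
    if first_step:
        m = normalized
    else:
        m = b1 * m + normalized
    return w - learning_rate * m, m, v


w = np.array([3.0, 4.0])
g = np.array([3.0, 4.0])  # ||g||^2 = 25
new_w, m, v = novograd_layer_step(w, g, np.zeros_like(w), 0.0, first_step=True)
```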
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.novograd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
References
Ginsburg et al, 2019: https://arxiv.org/abs/1905.11286
Li et al, 2019: https://arxiv.org/abs/1904.03288
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – An exponential decay rate to track the first moment of past gradients.
b2 (float) – An exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
weight_decay (float) – Strength of the weight decay regularization.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
Optimistic GD#
- optax.optimistic_gradient_descent(learning_rate, alpha=1.0, beta=1.0)[source]#
An Optimistic Gradient Descent optimizer.
Optimistic gradient descent is an approximation of extra-gradient methods, which require multiple gradient calls to compute the next update. It has strong formal guarantees for last-iterate convergence in min-max games, for which standard gradient descent can oscillate or even diverge.
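The oscillation claim can be seen on the bilinear game min over x, max over y of f(x, y) = x·y. The NumPy sketch below compares plain simultaneous gradient descent-ascent with an optimistic step. It assumes the generalized update x_{t+1} = x_t − lr·((alpha + beta)·g_t − beta·g_{t−1}) with g_{−1} = 0, which is our reading of how `alpha` and `beta` enter; with alpha = beta = 1 this is the classic optimistic step x_t − 2·lr·g_t + lr·g_{t−1}. All helper names are hypothetical; this is not optax's code.

```python
import numpy as np

# Illustrative sketch (hypothetical helpers), not optax's implementation.
# Toy game: min over x, max over y of f(x, y) = x * y (equilibrium at 0, 0).


def game_gradient(z):
    """Direction each player descends: x descends on y, y descends on -x."""
    x, y = z
    return np.array([y, -x])


def run_gda(z0, lr, steps):
    """Plain simultaneous gradient descent-ascent."""
    z = np.array(z0, dtype=float)
    for _ in range(steps):
        z = z - lr * game_gradient(z)
    return z


def run_ogd(z0, lr, steps, alpha=1.0, beta=1.0):
    """Assumed generalized optimistic step: (alpha + beta) * g_t - beta * g_{t-1}."""
    z = np.array(z0, dtype=float)
    prev_g = np.zeros_like(z)  # no previous gradient before the first step
    for _ in range(steps):
        g = game_gradient(z)
        z = z - lr * ((alpha + beta) * g - beta * prev_g)
        prev_g = g
    return z


z0 = [1.0, 1.0]
gda_final = run_gda(z0, lr=0.1, steps=500)  # spirals outwards
ogd_final = run_ogd(z0, lr=0.1, steps=500)  # spirals in towards (0, 0)
```

Each plain step multiplies the distance to the equilibrium by √(1 + lr²), so gradient descent-ascent drifts away, while the correction term −beta·g_{t−1} damps the rotation enough for the optimistic iterates to contract.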
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.optimistic_gradient_descent(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01
Objective function: 1.30E+01
References
Mokhtari et al, 2019: https://arxiv.org/abs/1901.08511v2
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
alpha (base.ScalarOrSchedule) – Coefficient for generalized OGD.
beta (base.ScalarOrSchedule) – Coefficient for generalized OGD negative momentum.
- Return type:
GradientTransformation
- Returns:
A GradientTransformation.
Polyak step-size SGD#
- optax.polyak_sgd(max_learning_rate=1.0, scaling=1.0, f_min=0.0, eps=0.0)[source]#
SGD with Polyak step-size.
This solver implements SGD with the Polyak step size of Loizou et al. (2021). It sets the step size as
\[s \min\left\{\frac{f(x) - f^\star}{\|\nabla f(x)\|^2 + \epsilon}, \gamma_{\max}\right\}\,, \]
where \(f\) is the function from which a gradient is computed, \(\gamma_{\max}\) is a maximal acceptable learning rate set by max_learning_rate, \(\epsilon\) is a constant preventing division by zero set with eps, \(s\) scales the formula by scaling, and \(f^\star\) is a guess of the minimum value of the function set with f_min.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.polyak_sgd()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: ', f(params))
Objective function:  3.5
Objective function:  0.875
Objective function:  0.21875
Objective function:  0.0546875
Objective function:  0.013671875
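The step-size formula can be checked by hand on the quadratic used in the example. The sketch below evaluates s · min{(f(x) − f*)/(‖∇f(x)‖² + ε), γ_max} in NumPy for one step; `polyak_step_size` is a hypothetical helper, not part of optax.

```python
import numpy as np

# Illustrative sketch (hypothetical helper), not optax's implementation.


def polyak_step_size(value, grad, f_min=0.0, max_learning_rate=1.0,
                     scaling=1.0, eps=0.0):
    """s * min((f(x) - f_min) / (||grad||^2 + eps), max_learning_rate)."""
    ratio = (value - f_min) / (np.sum(grad ** 2) + eps)
    return scaling * min(ratio, max_learning_rate)


x = np.array([1.0, 2.0, 3.0])
value, grad = np.sum(x ** 2), 2.0 * x  # f(x) = 14, ||grad||^2 = 56
step = polyak_step_size(value, grad)   # min(14 / 56, 1) = 0.25
x_next = x - step * grad               # equals 0.5 * x
```

On this quadratic the step size is 0.25 at every iterate, so each step halves x and divides f by 4, matching the sequence 3.5, 0.875, 0.21875, ... printed above.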
Warning
This method requires knowledge of an approximate value of the objective function minimum, passed through the f_min argument. For models that interpolate the data, this can be set to 0 (default value). Failing to set an appropriate value for f_min can lead to divergence or convergence to a suboptimal solution.
References
Loizou et al. Stochastic polyak step-size for SGD: An adaptive learning rate for fast convergence, 2021
Berrada et al., Training neural networks for and by interpolation, 2020
- Parameters:
max_learning_rate (float) – A maximum step size to use (defaults to 1).
scaling (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler (defaults to 1).
f_min (float) – A lower bound on the objective function (defaults to 0). Corresponds to \(f^\star\) in the formula above.
eps (float) – A value to add in the denominator of the update (defaults to 0).
- Return type:
GradientTransformationExtraArgs
- Returns:
A GradientTransformationExtraArgs, where the update function takes an additional keyword argument value containing the current value of the objective function.
RAdam#
- optax.radam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, threshold=5.0, *, nesterov=False)[source]#
The Rectified Adam optimizer.
The adaptive learning rate in Adam has undesirably large variance in early stages of training, due to the limited number of training samples used to estimate the optimizer’s statistics. Rectified Adam addresses this issue by analytically reducing the large variance.
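Concretely, Liu et al. (2020) measure tractability with the length ρ_t of the approximated simple moving average and apply a rectification factor r_t only once ρ_t exceeds a cutoff. The NumPy sketch below computes these quantities using the formulas from the paper, with `threshold` assumed to play the role of the cutoff; `radam_rectification` is a hypothetical helper, not optax's code.

```python
import numpy as np

# Illustrative sketch (hypothetical helper), not optax's implementation.


def radam_rectification(t, b2=0.999, threshold=5.0):
    """Return r_t, or None while the variance is deemed intractable (t >= 1)."""
    # Maximum length of the approximated simple moving average.
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    rho_t = rho_inf - 2.0 * t * b2 ** t / (1.0 - b2 ** t)
    if rho_t <= threshold:
        # Early steps: no adaptive denominator (un-adapted momentum step).
        return None
    return np.sqrt((rho_t - 4.0) * (rho_t - 2.0) * rho_inf
                   / ((rho_inf - 4.0) * (rho_inf - 2.0) * rho_t))


early = radam_rectification(1)       # rho_1 is about 1, below the cutoff
mid = radam_rectification(1000)      # partially rectified
late = radam_rectification(100_000)  # close to 1, i.e. plain Adam
```

The first steps therefore use an un-adapted momentum update, and as t grows r_t approaches 1 and the update approaches Adam's.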
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.radam(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01
References
Liu et al, 2020: https://arxiv.org/abs/1908.03265
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
threshold (float) – Threshold for variance tractability.
nesterov (bool) – Whether to use Nesterov momentum.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
RMSProp#
- optax.rmsprop(learning_rate, decay=0.9, eps=1e-08, initial_scale=0.0, centered=False, momentum=None, nesterov=False)[source]#
A flexible RMSProp optimizer.
RMSProp is an SGD variant with learning rate adaptation. The learning rate used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy-to-configure RMSProp optimizer that can be used to switch between several of these variants.
Warning
PyTorch and optax’s RMSprop implementations differ, which could impact performance. In the denominator, optax uses \(\sqrt{v + \epsilon}\) whereas PyTorch uses \(\sqrt{v} + \epsilon\). See google-deepmind/optax#532 for more detail.
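To make the difference in the warning concrete, the sketch below computes one plain (non-centered, momentum-free) RMSProp step with ε placed either inside the square root, as the warning describes for optax, or outside it, as it describes for PyTorch. The helper `rmsprop_step` and its `eps_inside_sqrt` flag are hypothetical, not part of either library.

```python
import numpy as np

# Illustrative sketch (hypothetical helper), not optax's or PyTorch's code.


def rmsprop_step(g, nu, learning_rate=0.003, decay=0.9, eps=1e-8,
                 eps_inside_sqrt=True):
    """Return (update, new_nu) for a plain RMSProp step."""
    # Exponential moving average of squared gradients.
    nu = decay * nu + (1.0 - decay) * g ** 2
    if eps_inside_sqrt:
        denom = np.sqrt(nu + eps)  # placement described for optax
    else:
        denom = np.sqrt(nu) + eps  # placement described for PyTorch
    return -learning_rate * g / denom, nu


nu0 = np.zeros(1)
g_small = np.array([1e-4])  # RMS gradient near sqrt(eps)
g_large = np.array([1.0])
small_inside, _ = rmsprop_step(g_small, nu0)
small_outside, _ = rmsprop_step(g_small, nu0, eps_inside_sqrt=False)
large_inside, _ = rmsprop_step(g_large, nu0)
large_outside, _ = rmsprop_step(g_large, nu0, eps_inside_sqrt=False)
```

For gradients of order 1 the two placements agree closely, but here the small-gradient updates differ by a factor of about 3, so the choice matters mainly when gradient magnitudes approach √ε.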
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.rmsprop(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.37E+01
Objective function: 1.36E+01
References
Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
Graves, 2013: https://arxiv.org/abs/1308.0850
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
decay (float) – Decay used to track the magnitude of previous gradients.
eps (float) – A small numerical constant to avoid dividing by zero when rescaling.
initial_scale (float) – Initial value of accumulators tracking the magnitude of previous updates. PyTorch uses 0, TF1 uses 1. When reproducing results from a paper, verify the value used by the authors.
centered (bool) – Whether the latest gradients are rescaled by the variance of past gradients (True) or by their second moment (False).
momentum (Optional[float]) – Decay rate used by the momentum term; when it is set to None, momentum is not used at all.
nesterov (bool) – Whether Nesterov momentum is used.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
RProp#
- optax.rprop(learning_rate, eta_minus=0.5, eta_plus=1.2, min_step_size=1e-06, max_step_size=50.0)[source]#
The Rprop optimizer.
Rprop, short for resilient backpropagation, is a first-order variant of gradient descent. It responds only to the sign of the gradient, increasing or decreasing the step size of each parameter exponentially to speed up convergence and avoid oscillations.
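The sign-driven step-size adaptation can be sketched in NumPy as below. The sketch follows the iRprop− variant (Igel and Hüsken, 2003), in which a parameter whose gradient changed sign skips its update for that step. We have not verified that this matches optax's behaviour, and it is not expected to reproduce the numbers printed in the example below. `rprop_step` is a hypothetical helper.

```python
import numpy as np

# Illustrative iRprop- sketch (hypothetical helper), not optax's implementation.


def rprop_step(w, g, prev_g, step, eta_minus=0.5, eta_plus=1.2,
               min_step_size=1e-6, max_step_size=50.0):
    """One iRprop- style step; returns (new_w, g_to_remember, new_step)."""
    agreement = g * prev_g
    # Same sign: grow the step; sign flip: shrink it; otherwise keep it.
    step = np.where(agreement > 0, step * eta_plus,
                    np.where(agreement < 0, step * eta_minus, step))
    step = np.clip(step, min_step_size, max_step_size)
    # After a sign flip, forget the gradient so this coordinate does not move.
    g = np.where(agreement < 0, 0.0, g)
    # Only the sign of the gradient is used, never its magnitude.
    return w - np.sign(g) * step, g, step


w = np.array([1.0, 2.0, 3.0])
prev_g = np.zeros_like(w)
step = np.full_like(w, 0.003)  # learning_rate acts as the initial step size
for _ in range(3):
    w, prev_g, step = rprop_step(w, 2.0 * w, prev_g, step)
```

Because the gradient keeps its sign on this quadratic, the per-parameter step grows by `eta_plus` after the first iteration: 0.003, 0.0036, 0.00432.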
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.rprop(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.38E+01
References
PyTorch implementation: https://pytorch.org/docs/stable/generated/torch.optim.Rprop.html
Riedmiller and Braun, 1993: https://ieeexplore.ieee.org/document/298623
Igel and Hüsken, 2003
- Parameters:
learning_rate (float) – The initial step size.
eta_minus (float) – Multiplicative factor for decreasing step size. This is applied when the gradient changes sign from one step to the next.
eta_plus (float) – Multiplicative factor for increasing step size. This is applied when the gradient has the same sign from one step to the next.
min_step_size (float) – Minimum allowed step size. Smaller steps will be clipped to this value.
max_step_size (float) – Maximum allowed step size. Larger steps will be clipped to this value.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
SGD#
- optax.sgd(learning_rate, momentum=None, nesterov=False, accumulator_dtype=None)[source]#
A canonical Stochastic Gradient Descent optimizer.
This implements stochastic gradient descent. It also includes support for momentum and Nesterov acceleration, as these are standard practice when using stochastic gradient descent to train deep neural networks.
The canonical stochastic gradient descent returns an update \(u_t\) of the form
\[u_t \leftarrow -\alpha_t g_t, \]
where \(g_t\) is the gradient of the objective (potentially preprocessed by other transformations) and \(\alpha_t\) is the learning_rate at time \(t\) (constant or selected by an optax.Schedule).
Stochastic gradient descent with momentum takes two possible forms:
\[\begin{align*} m_t &\leftarrow g_t + \mu m_{t-1} \\ u_t &\leftarrow \begin{cases} -\alpha_t m_t & \text{ if } \texttt{nesterov = False} \\ -\alpha_t (g_t + \mu m_t) & \text{ if } \texttt{nesterov = True} \end{cases} \\ S_t &\leftarrow m_t, \end{align*}\]
where \(\mu\) is the momentum parameter and \(S_t\) is the state of the optimizer.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.sgd(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01
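The two momentum forms of SGD can be written out in NumPy as follows. `sgd_momentum_step` is a hypothetical helper that returns the update u_t and the new state m_t, starting from m_0 = 0; it mirrors the momentum equations rather than optax's code.

```python
import numpy as np

# Illustrative sketch (hypothetical helper), not optax's implementation.


def sgd_momentum_step(g, m, learning_rate=0.003, momentum=0.9,
                      nesterov=False):
    """Return (u_t, m_t) for SGD with heavy-ball or Nesterov momentum."""
    m = g + momentum * m  # m_t = g_t + mu * m_{t-1}
    if nesterov:
        u = -learning_rate * (g + momentum * m)  # look ahead along m_t
    else:
        u = -learning_rate * m
    return u, m


g = np.array([2.0, 4.0, 6.0])
m0 = np.zeros(3)
u_heavy, m1 = sgd_momentum_step(g, m0)
u_nesterov, _ = sgd_momentum_step(g, m0, nesterov=True)
```

On the first step (m_0 = 0) both forms set m_1 = g_1, but the Nesterov form already takes a step (1 + μ) times larger, because it looks ahead along the freshly updated momentum.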
References
Sutskever et al, On the importance of initialization and momentum in deep learning, 2013
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
momentum (Optional[float]) – Decay rate used by the momentum term; when it is set to None, momentum is not used at all.
nesterov (bool) – Whether Nesterov momentum is used.
accumulator_dtype (Optional[Any]) – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
SM3#
- optax.sm3(learning_rate, momentum=0.9)[source]#
The SM3 optimizer.
SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a memory-efficient adaptive optimizer designed to decrease memory overhead when training very large models, such as the Transformer for machine translation, BERT for language modeling, and AmoebaNet-D for image classification. SM3: 1) applies to tensors of arbitrary dimensions and any predefined cover of the parameters; 2) adapts the learning rates in an adaptive and data-driven manner (like Adagrad and unlike Adafactor); and 3) comes with rigorous convergence guarantees in stochastic convex optimization settings.
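For a matrix parameter, the cover can be taken to be its rows and columns, so only m + n accumulators are stored instead of m·n. The NumPy sketch below implements one step of this row/column scheme, following our reading of the SM3-II update in Anil et al. (2019) and omitting momentum. `sm3_matrix_step` is a hypothetical helper, and optax's sm3 may differ, for example in how momentum is applied.

```python
import numpy as np

# Illustrative SM3-II sketch (hypothetical helper), not optax's implementation.


def sm3_matrix_step(w, g, row_acc, col_acc, learning_rate=0.1, eps=1e-30):
    """One SM3-II step for a 2-D parameter with a row/column cover.

    row_acc has shape (m,) and col_acc has shape (n,): O(m + n) memory.
    """
    # Per-entry estimate: the tighter of the two covering accumulators,
    # plus the new squared gradient.
    nu = np.minimum(row_acc[:, None], col_acc[None, :]) + g ** 2
    # Each accumulator keeps the maximum over the entries it covers.
    row_acc = nu.max(axis=1)
    col_acc = nu.max(axis=0)
    # Adagrad-style rescaling; eps guards entries whose gradient is zero.
    new_w = w - learning_rate * g / (np.sqrt(nu) + eps)
    return new_w, row_acc, col_acc


w = np.zeros((2, 3))
g = np.array([[1.0, 2.0, 0.0],
              [3.0, 0.0, 1.0]])
new_w, row_acc, col_acc = sm3_matrix_step(w, g, np.zeros(2), np.zeros(3))
```

After this first step the row accumulators hold [4, 9] and the column accumulators hold [9, 4, 1], which is 5 numbers in place of the 6 per-entry values Adagrad would store; the saving grows with the matrix size.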
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.sm3(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
References
Anil et al, 2019: https://arxiv.org/abs/1901.11150
- Parameters:
learning_rate (float) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
momentum (float) – Decay rate used by the momentum term (when it is set to None, momentum is not used at all).
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.
Yogi#
- optax.yogi(learning_rate, b1=0.9, b2=0.999, eps=0.001)[source]#
The Yogi optimizer.
Yogi is an adaptive optimizer, which provides control in tuning the effective learning rate to prevent it from increasing. By doing so, it focuses on addressing the issues of convergence and generalization in exponential moving average-based adaptive methods (such as Adam and RMSprop). Yogi is a modification of Adam and uses the same parameters.
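The key change relative to Adam is the second-moment update. Following Zaheer et al. (2018), Adam's v_t = v_{t−1} − (1 − β₂)(v_{t−1} − g_t²) is replaced by v_t = v_{t−1} − (1 − β₂)·sign(v_{t−1} − g_t²)·g_t², so the size of each change depends only on g_t² and not on v_{t−1}. The small NumPy comparison below uses hypothetical helper names and is not optax's code.

```python
import numpy as np

# Illustrative sketch (hypothetical helpers), not optax's implementation.


def adam_second_moment(v, g, b2=0.999):
    # v_t = v_{t-1} - (1 - b2) * (v_{t-1} - g_t^2): change scales with v.
    return v - (1.0 - b2) * (v - g ** 2)


def yogi_second_moment(v, g, b2=0.999):
    # v_t = v_{t-1} - (1 - b2) * sign(v_{t-1} - g_t^2) * g_t^2: change
    # depends only on the current squared gradient.
    return v - (1.0 - b2) * np.sign(v - g ** 2) * g ** 2


v = np.array([100.0])  # large accumulated second moment
g = np.array([1.0])    # a much smaller current gradient
adam_change = abs(adam_second_moment(v, g) - v)  # (1 - b2) * 99, about 0.099
yogi_change = abs(yogi_second_moment(v, g) - v)  # (1 - b2) * 1, about 0.001
```

When v is much larger than g², Adam shrinks v quickly, so the effective learning rate grows quickly; Yogi shrinks it only by (1 − β₂)·g² per step, which is the control over increases in the effective learning rate described above.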
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.yogi(learning_rate=0.002)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
Objective function: 1.39E+01
References
Zaheer et al, 2018: https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf
- Parameters:
learning_rate (base.ScalarOrSchedule) – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
- Return type:
GradientTransformation
- Returns:
The corresponding GradientTransformation.