Common Optimizers

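optax.adabelief – The AdaBelief optimizer.
optax.adafactor – The Adafactor optimizer.
optax.adagrad – The Adagrad optimizer.
optax.adam – The classic Adam optimizer.
optax.adamw – Adam with weight decay regularization.
optax.adamax – A variant of the Adam optimizer that uses the infinity norm.
optax.adamaxw – Adamax with weight decay regularization.
optax.amsgrad – The AMSGrad optimizer.
optax.fromage – The Frobenius matched gradient descent (Fromage) optimizer.
optax.lamb – The LAMB optimizer.
optax.lars – The LARS optimizer.
optax.lion – The Lion optimizer.
optax.noisy_sgd – A variant of SGD with added noise.
optax.novograd – NovoGrad optimizer.
optax.optimistic_gradient_descent – An Optimistic Gradient Descent optimizer.
optax.dpsgd – The DP-SGD optimizer.
optax.radam – The Rectified Adam optimizer.
optax.rmsprop – A flexible RMSProp optimizer.
optax.sgd – A canonical Stochastic Gradient Descent optimizer.
optax.sm3 – The SM3 optimizer.
optax.yogi – The Yogi optimizer.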
AdaBelief
 optax.adabelief(learning_rate, b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16)
The AdaBelief optimizer.
AdaBelief is an adaptive learning rate optimizer that focuses on fast convergence, generalization, and stability. It adapts the step size depending on its “belief” in the gradient direction — the optimizer adaptively scales the step size by the difference between the predicted and observed gradients. AdaBelief is a modified version of Adam and contains the same number of parameters.
References
Zhuang et al, 2020: https://arxiv.org/abs/2010.07468
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – Term added to the denominator to improve numerical stability.
eps_root (float) – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
AdaGrad
 optax.adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-07)
The Adagrad optimizer.
Adagrad is an algorithm for gradient based optimization that anneals the learning rate for each parameter during the course of training.
WARNING: Adagrad’s main limitation is the monotonic accumulation of squared gradients in the denominator: since all terms are >0, the sum keeps growing during training and the learning rate eventually becomes vanishingly small.
References
Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
initial_accumulator_value (float) – Initial value for the accumulator.
eps (float) – A small constant applied to denominator inside of the square root (as in RMSProp) to avoid dividing by zero when rescaling.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
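The shrinking step size described in the warning above can be seen in a minimal scalar sketch. This is plain Python written for illustration, not Optax's implementation; the helper name adagrad_step is hypothetical, and placing eps inside the square root follows the description of the eps argument.

```python
import math


def adagrad_step(grad, accumulator, learning_rate, eps=1e-7):
    """One scalar Adagrad step; returns (update, new_accumulator)."""
    # The accumulator only ever grows: squared gradients are summed.
    accumulator = accumulator + grad * grad
    # eps sits inside the square root, as described for the eps argument.
    update = -learning_rate * grad / math.sqrt(accumulator + eps)
    return update, accumulator


acc = 0.1  # plays the role of initial_accumulator_value
step_sizes = []
for _ in range(1000):
    update, acc = adagrad_step(grad=1.0, accumulator=acc, learning_rate=0.1)
    step_sizes.append(abs(update))

# Even with a constant gradient, the step size decays monotonically.
print(step_sizes[0], step_sizes[-1])
```

With a constant gradient the effective step shrinks roughly like 1/sqrt(t), which is the behaviour the warning refers to.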
AdaFactor
 optax.adafactor(learning_rate=None, min_dim_size_to_factor=128, decay_rate=0.8, decay_offset=0, multiply_by_parameter_scale=True, clipping_threshold=1.0, momentum=None, dtype_momentum=<class 'jax.numpy.float32'>, weight_decay_rate=None, eps=1e-30, factored=True, weight_decay_mask=None)
The Adafactor optimizer.
Adafactor is an adaptive learning rate optimizer that focuses on fast training of large scale neural networks. It saves memory by using a factored estimate of the second order moments used to scale gradients.
References
Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], None]) – A fixed global scaling factor. Note: the natural scale for Adafactor’s LR is markedly different from Adam’s; one doesn’t use the 1/sqrt(hidden) correction for this optimizer with attention-based models.
min_dim_size_to_factor (int) – Only factor the statistics if two array dimensions have at least this size.
decay_rate (float) – Controls the second-moment exponential decay schedule.
decay_offset (int) – For fine-tuning, one may set this to the starting step number of the fine-tuning phase.
multiply_by_parameter_scale (bool) – If True, then scale learning_rate by parameter norm. If False, the provided learning_rate is the absolute step size.
clipping_threshold (Optional[float]) – Optional clipping threshold. Must be >= 1. If None, clipping is disabled.
momentum (Optional[float]) – Optional value between 0 and 1; enables momentum and uses extra memory if non-None. None by default.
dtype_momentum (Any) – Data type of momentum buffers.
weight_decay_rate (Optional[float]) – Optional rate at which to decay weights.
eps (float) – Regularization constant for root mean squared gradient.
factored (bool) – Whether to use factored second-moment estimates.
weight_decay_mask (Union[Any, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Any], None]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
Adam
 optax.adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)
The classic Adam optimizer.
Adam is an SGD variant with gradient scaling adaptation. The scaling used for each parameter is computed from estimates of first- and second-order moments of the gradients (using suitable exponential moving averages).
Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.
The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,
\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*}\]
References
Kingma et al, 2014: https://arxiv.org/abs/1412.6980
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.
mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
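As an illustration of the Adam update rule, here is a scalar sketch in plain Python. It is not Optax's implementation: the names adam_init and adam_update are hypothetical, the step counter is passed explicitly, and the sign convention assumes that updates are added to the parameters.

```python
import math


def adam_init():
    """Initial state (m_0, v_0) = (0, 0)."""
    return 0.0, 0.0


def adam_update(grad, state, step, learning_rate,
                b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """One scalar Adam step (step counts from 1); returns (update, new_state)."""
    m, v = state
    m = b1 * m + (1 - b1) * grad          # first-moment estimate
    v = b2 * v + (1 - b2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - b1 ** step)          # bias correction
    v_hat = v / (1 - b2 ** step)
    # eps_root goes inside the square root, eps outside.
    update = -learning_rate * m_hat / (math.sqrt(v_hat + eps_root) + eps)
    return update, (m, v)


# Minimize f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
x = 0.0
state = adam_init()
for t in range(1, 2001):
    update, state = adam_update(2 * (x - 3), state, t, learning_rate=0.05)
    x += update
print(x)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the very first update has magnitude close to the learning rate regardless of the gradient scale.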
Adamax
 optax.adamax(learning_rate, b1=0.9, b2=0.999, eps=1e-08)
A variant of the Adam optimizer that uses the infinity norm.
References
Kingma et al, 2014: https://arxiv.org/abs/1412.6980
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the maximum of past gradients.
eps (float) – A small constant applied to denominator to avoid dividing by zero when rescaling.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
AdamaxW
 optax.adamaxw(learning_rate, b1=0.9, b2=0.999, eps=1e-08, weight_decay=0.0001, mask=None)
Adamax with weight decay regularization.
AdamaxW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization to implement this as an additive loss term; however, L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam.
WARNING: Sometimes you may want to skip weight decay for BatchNorm scale or for the bias parameters. You can use optax.masked to make your own AdamaxW variant where additive_weight_decay is applied only to a subset of params.
References
Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the maximum of past gradients.
eps (float) – A small constant applied to denominator to avoid dividing by zero when rescaling.
weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask (Union[Any, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Any], None]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adamax gradient transformations are applied to all parameters.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
AdamW
 optax.adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None, weight_decay=0.0001, mask=None)
Adam with weight decay regularization.
AdamW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization to implement this as an additive loss term; however, L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam.
References
Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask (Union[Any, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Any], None]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
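The note that weight_decay is multiplied with the learning rate can be sketched for a single scalar parameter. This is an illustrative plain-Python sketch rather than Optax code; adamw_step is a hypothetical helper, and adam_update is assumed to be an already-computed Adam update that includes the factor -learning_rate.

```python
def adamw_step(param, adam_update, learning_rate, weight_decay=1e-4):
    """Applies an Adam update plus decoupled weight decay to a scalar parameter.

    The decay term is scaled by the learning rate, so the effective shrinkage
    per step is learning_rate * weight_decay.
    """
    return param + adam_update - learning_rate * weight_decay * param


# With a zero Adam update, the parameter shrinks geometrically
# by a factor (1 - learning_rate * weight_decay) per step.
param = 1.0
for _ in range(100):
    param = adamw_step(param, adam_update=0.0,
                       learning_rate=0.1, weight_decay=0.01)
print(param)
```

Because the decay is applied directly to the parameter rather than through the loss, it is not rescaled by the adaptive second-moment estimate, which is the difference from L2 regularization mentioned above.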
AMSGrad
 optax.amsgrad(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)
The AMSGrad optimizer.
The original Adam can fail to converge to the optimal solution in some cases. AMSGrad guarantees convergence by using a long-term memory of past gradients.
References
Reddi et al, 2018: https://openreview.net/forum?id=ryQu7f-RZ
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
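One way to picture the "long-term memory" is as a running maximum of the second-moment estimate, so that the denominator never shrinks. The scalar sketch below is an assumption-laden illustration in plain Python (amsgrad_update is a hypothetical name, and taking the maximum over bias-corrected estimates is one common formulation), not a statement of how Optax implements it.

```python
import math


def amsgrad_update(grad, state, step, learning_rate,
                   b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """One scalar AMSGrad-style step (step counts from 1)."""
    m, v, v_max = state
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    m_hat = m / (1 - b1 ** step)
    v_hat = v / (1 - b2 ** step)
    # Long-term memory: keep the largest second-moment estimate seen so far.
    v_max = max(v_max, v_hat)
    update = -learning_rate * m_hat / (math.sqrt(v_max + eps_root) + eps)
    return update, (m, v, v_max)


# One large gradient followed by small ones: v_max remembers the large one.
state = (0.0, 0.0, 0.0)
history = []
for t, g in enumerate([10.0] + [0.1] * 20, start=1):
    _, state = amsgrad_update(g, state, t, learning_rate=1e-3)
    history.append(state[2])
print(history[0], history[-1])
```

In plain Adam the denominator would decay after the large gradient; here it stays at its peak, which keeps the effective step size from growing again.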
Fromage
 optax.fromage(learning_rate, min_norm=1e-06)
The Frobenius matched gradient descent (Fromage) optimizer.
Fromage is a learning algorithm that does not require learning rate tuning. The optimizer is based on modeling neural network gradients via deep relative trust (a distance function on deep neural networks). Fromage is similar to the LARS optimizer and can work on a range of standard neural network benchmarks, such as natural language Transformers and generative adversarial networks.
References
Bernstein et al, 2020: https://arxiv.org/abs/2002.03432
 Parameters
learning_rate (float) – A fixed global scaling factor.
min_norm (float) – A minimum value that the norm of the gradient updates and the norm of the layer parameters can be clipped to, to avoid dividing by zero when computing the trust ratio (as in the LARS paper).
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
Lamb
 optax.lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-06, eps_root=0.0, weight_decay=0.0, mask=None)
The LAMB optimizer.
LAMB is a general purpose layer-wise adaptive large batch optimizer designed to provide consistent training performance across a wide range of tasks, including those that use attention-based models (such as Transformers) and ResNet-50. The optimizer is able to work with small and large batch sizes. LAMB was inspired by the LARS learning algorithm.
References
You et al, 2019: https://arxiv.org/abs/1904.00962
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
weight_decay (float) – Strength of the weight decay regularization.
mask (Union[Any, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Any], None]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
Lars
 optax.lars(learning_rate, weight_decay=0.0, weight_decay_mask=True, trust_coefficient=0.001, eps=0.0, trust_ratio_mask=True, momentum=0.9, nesterov=False)
The LARS optimizer.
LARS is a layer-wise adaptive optimizer introduced to help scale SGD to larger batch sizes. LARS later inspired the LAMB optimizer.
References
You et al, 2017: https://arxiv.org/abs/1708.03888
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
weight_decay (float) – Strength of the weight decay regularization.
weight_decay_mask (Union[Any, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Any], None]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
trust_coefficient (float) – A multiplier for the trust ratio.
eps (float) – Optional additive constant in the trust ratio denominator.
trust_ratio_mask (Union[Any, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Any], None]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
momentum (float) – Decay rate for momentum.
nesterov (bool) – Whether to use Nesterov momentum.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
Lion
 optax.lion(learning_rate, b1=0.9, b2=0.99, mu_dtype=None, weight_decay=0.001, mask=None)
The Lion optimizer.
Lion was discovered by symbolic program search. Unlike most adaptive optimizers such as AdamW, Lion only tracks momentum, making it more memory-efficient. The update of Lion is produced through the sign operation, resulting in a larger norm compared to updates produced by other optimizers such as SGD and AdamW. A suitable learning rate for Lion is typically 3-10x smaller than that for AdamW. The weight decay for Lion should in turn be 3-10x larger than that for AdamW to maintain a similar strength (lr * wd).
References
Chen et al, 2023: https://arxiv.org/abs/2302.06675
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – Rate to combine the momentum and the current gradient.
b2 (float) – Exponential decay rate to track the momentum of past gradients.
mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask (Union[Any, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Any], None]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Lion gradient transformations are applied to all parameters.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
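To make the sign-based update concrete, here is a scalar sketch in plain Python following the description of b1 and b2 above. It is an illustration under stated assumptions, not Optax's code: lion_update and sign are hypothetical helpers, and the weight decay is assumed to be added to the sign direction before scaling by the learning rate.

```python
def sign(x):
    """Returns -1, 0 or 1."""
    return (x > 0) - (x < 0)


def lion_update(grad, momentum, param, learning_rate,
                b1=0.9, b2=0.99, weight_decay=1e-3):
    """One scalar Lion-style step; returns (update, new_momentum)."""
    # b1 combines the stored momentum with the current gradient ...
    direction = sign(b1 * momentum + (1 - b1) * grad)
    # ... and only the sign of that combination is used for the step.
    update = -learning_rate * (direction + weight_decay * param)
    # b2 is the decay rate of the momentum that is carried forward.
    momentum = b2 * momentum + (1 - b2) * grad
    return update, momentum


# Gradients that differ by six orders of magnitude give the same step size.
u_small, _ = lion_update(grad=1e-3, momentum=0.0, param=0.0,
                         learning_rate=1e-4, weight_decay=0.0)
u_large, _ = lion_update(grad=1e3, momentum=0.0, param=0.0,
                         learning_rate=1e-4, weight_decay=0.0)
print(u_small, u_large)
```

Since every coordinate moves by the full learning rate, the update norm is larger than for SGD or AdamW, which is why a smaller learning rate is suggested.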
SM3
 optax.sm3(learning_rate, momentum=0.9)
The SM3 optimizer.
SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a memory-efficient adaptive optimizer designed to decrease memory overhead when training very large models, such as the Transformer for machine translation, BERT for language modeling, and AmoebaNet-D for image classification. SM3: 1) applies to tensors of arbitrary dimensions and any predefined cover of the parameters; 2) adapts the learning rates in an adaptive and data-driven manner (like Adagrad and unlike Adafactor); and 3) comes with rigorous convergence guarantees in stochastic convex optimization settings.
References
Anil et al, 2019: https://arxiv.org/abs/1901.11150
 Parameters
learning_rate (float) – A fixed global scaling factor.
momentum (float) – Decay rate used by the momentum term (when it is set to None, then momentum is not used at all).
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
Noisy SGD
 optax.noisy_sgd(learning_rate, eta=0.01, gamma=0.55, seed=0)
A variant of SGD with added noise.
It has been found that adding noise to the gradients can improve both the training error and the generalization error in very deep networks.
References
Neelakantan et al, 2015: https://arxiv.org/abs/1511.06807
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
eta (float) – Initial variance for the Gaussian noise added to gradients.
gamma (float) – A parameter controlling the annealing of noise over time; the variance decays according to (1+t)^-gamma.
seed (int) – Seed for the pseudo-random generation process.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
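The annealing schedule implied by eta and gamma can be sketched as follows. This is plain Python for illustration only; noise_variance and noisy_sgd_update are hypothetical names, steps are assumed to count from 0, and Python's random module stands in for the actual random number generation.

```python
import random


def noise_variance(step, eta=0.01, gamma=0.55):
    """Variance of the gradient noise at a given step: eta * (1 + t)^-gamma."""
    return eta / (1 + step) ** gamma


def noisy_sgd_update(grad, step, learning_rate, rng, eta=0.01, gamma=0.55):
    """Scalar SGD update with annealed Gaussian noise added to the gradient."""
    std = noise_variance(step, eta, gamma) ** 0.5
    return -learning_rate * (grad + rng.gauss(0.0, std))


rng = random.Random(0)  # plays the role of the seed argument
update = noisy_sgd_update(grad=1.0, step=0, learning_rate=0.1, rng=rng)

# The noise starts with variance eta and decays as training proceeds.
print(noise_variance(0), noise_variance(99))
```

With the default gamma=0.55, the variance after 100 steps is a bit more than ten times smaller than at the start.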
Novograd
 optax.novograd(learning_rate, b1=0.9, b2=0.25, eps=1e-06, eps_root=0.0, weight_decay=0.0)
NovoGrad optimizer.
NovoGrad is more robust to the initial learning rate and weight initialization than other methods. For example, NovoGrad works well without LR warm-up, while other methods require it. NovoGrad performs exceptionally well for large batch training, e.g. it outperforms other methods for ResNet-50 for all batches up to 32K. In addition, NovoGrad requires half the memory compared to Adam. It was introduced together with the Jasper ASR model.
References
Ginsburg et al, 2019: https://arxiv.org/abs/1905.11286
Li et al, 2019: https://arxiv.org/abs/1904.03288
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – An exponential decay rate to track the first moment of past gradients.
b2 (float) – An exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
weight_decay (float) – Strength of the weight decay regularization.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
Optimistic GD
 optax.optimistic_gradient_descent(learning_rate, alpha=1.0, beta=1.0)
An Optimistic Gradient Descent optimizer.
Optimistic gradient descent is an approximation of extragradient methods which require multiple gradient calls to compute the next update. It has strong formal guarantees for last-iterate convergence in min-max games, for which standard gradient descent can oscillate or even diverge.
References
Mokhtari et al, 2019: https://arxiv.org/abs/1901.08511v2
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
alpha (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – Coefficient for generalized OGD.
beta (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – Coefficient for generalized OGD negative momentum.
 Return type
GradientTransformation
 Returns
A GradientTransformation.
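The difference between plain gradient descent and the optimistic variant on a min-max game can be sketched on the bilinear problem min over x, max over y of x * y. The plain-Python code below is an illustration under assumptions: optimistic_update and play are hypothetical names, the update uses a common generalized form (alpha + beta) * g_t - beta * g_{t-1}, and the previous gradient is taken to be zero on the first step.

```python
def optimistic_update(grad, prev_grad, learning_rate, alpha=1.0, beta=1.0):
    """Generalized optimistic step with negative momentum on the last gradient."""
    return -learning_rate * ((alpha + beta) * grad - beta * prev_grad)


def play(optimistic, steps=2000, learning_rate=0.1):
    """Runs the game f(x, y) = x * y and returns the final distance to (0, 0)."""
    x, y = 1.0, 1.0
    prev_gx, prev_gy = 0.0, 0.0
    for _ in range(steps):
        # x minimizes x * y; y maximizes it, i.e. minimizes -x * y.
        gx, gy = y, -x
        if optimistic:
            ux = optimistic_update(gx, prev_gx, learning_rate)
            uy = optimistic_update(gy, prev_gy, learning_rate)
        else:
            ux, uy = -learning_rate * gx, -learning_rate * gy
        x, y = x + ux, y + uy
        prev_gx, prev_gy = gx, gy
    return (x * x + y * y) ** 0.5


print("gradient descent:", play(optimistic=False))
print("optimistic GD:   ", play(optimistic=True))
```

On this game the plain simultaneous updates spiral away from the equilibrium at the origin, while the optimistic updates spiral in towards it.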
Differentially Private SGD
 optax.dpsgd(learning_rate, l2_norm_clip, noise_multiplier, seed, momentum=None, nesterov=False)
The DP-SGD optimizer.
Differential privacy is a standard for privacy guarantees of algorithms learning from aggregate databases including potentially sensitive information. DP-SGD offers protection against a strong adversary with full knowledge of the training mechanism and access to the model’s parameters.
WARNING: This GradientTransformation expects input updates to have a batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap).
References
Abadi et al, 2016: https://arxiv.org/abs/1607.00133
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
l2_norm_clip (float) – Maximum L2 norm of the per-example gradients.
noise_multiplier (float) – Ratio of standard deviation to the clipping norm.
seed (int) – Initial seed used for the jax.random.PRNGKey.
momentum (Optional[float]) – Decay rate used by the momentum term; when it is set to None, then momentum is not used at all.
nesterov (bool) – Whether Nesterov momentum is used.
 Return type
GradientTransformation
 Returns
A GradientTransformation.
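The role of l2_norm_clip and noise_multiplier can be sketched with the usual DP-SGD aggregation: clip each per-example gradient, sum, add Gaussian noise, and average. The code below is a plain-Python illustration rather than Optax's implementation; dp_aggregate is a hypothetical helper, gradients are plain lists of floats, and the noise standard deviation is assumed to be l2_norm_clip * noise_multiplier.

```python
import math
import random


def dp_aggregate(per_example_grads, l2_norm_clip, noise_multiplier, rng):
    """Clips per-example gradients, sums them, adds noise, and averages."""
    dim = len(per_example_grads[0])
    total = [0.0] * dim
    for grad in per_example_grads:
        norm = math.sqrt(sum(v * v for v in grad))
        # Scale down only the gradients whose L2 norm exceeds the clip value.
        scale = min(1.0, l2_norm_clip / max(norm, 1e-12))
        for i, v in enumerate(grad):
            total[i] += scale * v
    noise_std = l2_norm_clip * noise_multiplier
    batch_size = len(per_example_grads)
    return [(t + rng.gauss(0.0, noise_std)) / batch_size for t in total]


rng = random.Random(0)
grads = [[3.0, 4.0], [0.3, 0.4]]  # L2 norms 5.0 and 0.5
# With the noise switched off, only the clipping is visible.
clipped_mean = dp_aggregate(grads, l2_norm_clip=1.0,
                            noise_multiplier=0.0, rng=rng)
noisy_mean = dp_aggregate(grads, l2_norm_clip=1.0,
                          noise_multiplier=1.1, rng=rng)
print(clipped_mean, noisy_mean)
```

Clipping bounds how much any single example can influence the aggregate, and the noise is calibrated to that bound, which is why per-example gradients are required as input.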
RAdam
 optax.radam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, threshold=5.0)
The Rectified Adam optimizer.
The adaptive learning rate in Adam has undesirably large variance in early stages of training, due to the limited number of training samples used to estimate the optimizer’s statistics. Rectified Adam addresses this issue by analytically reducing the large variance.
References
Liu et al, 2019: https://arxiv.org/abs/1908.03265
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
threshold (float) – Threshold for variance tractability.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
RMSProp
 optax.rmsprop(learning_rate, decay=0.9, eps=1e-08, initial_scale=0.0, centered=False, momentum=None, nesterov=False)
A flexible RMSProp optimizer.
RMSProp is an SGD variant with learning rate adaptation. The learning_rate used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy to configure RMSProp optimizer that can be used to switch between several of these variants.
References
Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf
Graves, 2013: https://arxiv.org/abs/1308.0850
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
decay (float) – Decay used to track the magnitude of previous gradients.
eps (float) – A small numerical constant to avoid dividing by zero when rescaling.
initial_scale (float) – Initial value of accumulators tracking the magnitude of previous updates. PyTorch uses 0, TF1 uses 1. When reproducing results from a paper, verify the value used by the authors.
centered (bool) – Whether the second moment or the variance of the past gradients is used to rescale the latest gradients.
momentum (Optional[float]) – Decay rate used by the momentum term; when it is set to None, then momentum is not used at all.
nesterov (bool) – Whether Nesterov momentum is used.
 Return type
GradientTransformation
 Returns
The corresponding GradientTransformation.
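The effect of the centered flag can be illustrated with a scalar sketch: the uncentered variant rescales by the running second moment, the centered one by a running variance estimate. This is plain Python written for illustration, not Optax's implementation; rmsprop_update is a hypothetical name, the state is a (mean, mean of squares) pair starting at zero, and placing eps inside the square root is an assumption.

```python
import math


def rmsprop_update(grad, state, learning_rate,
                   decay=0.9, eps=1e-8, centered=False):
    """One scalar RMSProp-style step; returns (update, new_state)."""
    mean, mean_sq = state
    mean_sq = decay * mean_sq + (1 - decay) * grad * grad
    if centered:
        # Rescale by the variance estimate E[g^2] - E[g]^2.
        mean = decay * mean + (1 - decay) * grad
        scale = mean_sq - mean * mean
    else:
        # Rescale by the second moment E[g^2].
        scale = mean_sq
    # eps inside the root also guards against tiny negative rounding errors.
    update = -learning_rate * grad / math.sqrt(scale + eps)
    return update, (mean, mean_sq)


u_plain, _ = rmsprop_update(1.0, (0.0, 0.0), learning_rate=0.1)
u_centered, _ = rmsprop_update(1.0, (0.0, 0.0), learning_rate=0.1,
                               centered=True)
print(u_plain, u_centered)
```

Because the variance is never larger than the second moment, the centered variant takes steps that are at least as large for the same gradient history.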
SGD
 optax.sgd(learning_rate, momentum=None, nesterov=False, accumulator_dtype=None)
A canonical Stochastic Gradient Descent optimizer.
This implements stochastic gradient descent. It also includes support for momentum and Nesterov acceleration, as these are standard practice when using stochastic gradient descent to train deep neural networks.
References
Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
momentum (Optional[float]) – Decay rate used by the momentum term; when it is set to None, then momentum is not used at all.
nesterov (bool) – Whether Nesterov momentum is used.
accumulator_dtype (Optional[Any]) – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.
 Return type
GradientTransformation
 Returns
A GradientTransformation.
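How the momentum and nesterov arguments change the update can be sketched for a scalar parameter. The plain-Python code below is illustrative only; sgd_update is a hypothetical helper, and the accumulator form trace = grad + momentum * trace is one common convention that is assumed here.

```python
def sgd_update(grad, trace, learning_rate, momentum=None, nesterov=False):
    """One scalar SGD step; returns (update, new_trace)."""
    if momentum is None:
        # Plain SGD: no accumulator is used at all.
        return -learning_rate * grad, trace
    trace = grad + momentum * trace
    # Nesterov looks one step ahead along the accumulated direction.
    direction = grad + momentum * trace if nesterov else trace
    return -learning_rate * direction, trace


u_plain, _ = sgd_update(1.0, 0.0, learning_rate=0.1)

# Two steps with a constant gradient and classical momentum.
u1, trace = sgd_update(1.0, 0.0, learning_rate=0.1, momentum=0.9)
u2, trace = sgd_update(1.0, trace, learning_rate=0.1, momentum=0.9)

# First step with Nesterov momentum.
u_nesterov, _ = sgd_update(1.0, 0.0, learning_rate=0.1,
                           momentum=0.9, nesterov=True)
print(u_plain, u1, u2, u_nesterov)
```

With a constant gradient the momentum accumulator builds up over steps, so the second update is larger than the first; Nesterov momentum applies part of that build-up already on the first step.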
Yogi#
 optax.yogi(learning_rate, b1=0.9, b2=0.999, eps=0.001)[source]#
The Yogi optimizer.
Yogi is an adaptive optimizer, which provides control in tuning the effective learning rate to prevent it from increasing. By doing so, it focuses on addressing the issues of convergence and generalization in exponential moving averagebased adaptive methods (such as Adam and RMSprop). Yogi is a modification of Adam and uses the same parameters.
References
Zaheer et al, 2018: https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf
 Parameters
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A fixed global scaling factor.
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
 Return type
 Returns
The corresponding GradientTransformation.
Optax Transformations#

Optax Types#
 class optax.GradientTransformation(init: TransformInitFn, update: TransformUpdateFn)[source]#
A pair of pure functions implementing a gradient transformation.
Optax optimizers are all implemented as _gradient transformations_. A gradient transformation is defined to be a pair of pure functions, which are combined together in a NamedTuple so that they can be referred to by name.
Note that an extended API is provided for users wishing to build optimizers that take additional arguments during the update step. For more details, see GradientTransformationExtraArgs.
Since gradient transformations do not contain any internal state, all stateful optimizer properties (such as the current step count when using optimizer schedules, or momentum values) are passed through optax gradient transformations by using the optimizer _state_ pytree. Each time a gradient transformation is applied, a new state is computed and returned, ready to be passed to the next call to the gradient transformation.
Since gradient transformations are pure, idempotent functions, the only way to change the behaviour of a gradient transformation between steps is to change the values in the optimizer state. To see an example of mutating the optimizer state in order to control the behaviour of an optax gradient transformation, see the meta-learning example in the optax documentation.
 init#
A pure function which, when called with an example instance of the parameters whose gradients will be transformed, returns a pytree containing the initial value for the optimizer state.
 Type
 update#
A pure function which takes as input a pytree of updates (with the same tree structure as the original params pytree passed to init), the previous optimizer state (which may have been initialized using the init function), and optionally the current params. The update function then returns the computed gradient updates, and a new optimizer state.
 Type
 init: TransformInitFn#
Alias for field number 0
 update: TransformUpdateFn#
Alias for field number 1
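The init/update pair described above can be sketched in plain Python. The `GradientTransformation` class below is a hypothetical stand-in defined locally (it is not imported from optax), and `scale_by_inverse_count` is an invented example transform whose only state is a step counter. Lists of floats stand in for pytrees.

```python
from typing import Any, Callable, NamedTuple


class GradientTransformation(NamedTuple):
    """Hypothetical stand-in mirroring the (init, update) pair described above."""
    init: Callable[[Any], Any]
    update: Callable[..., Any]


def scale_by_inverse_count() -> GradientTransformation:
    """Illustrative transform: divides each update by the number of steps taken."""

    def init(params):
        # The state here is just a step counter; params is unused.
        del params
        return 0

    def update(updates, state, params=None):
        del params
        count = state + 1
        new_updates = [u / count for u in updates]
        # A new state is returned rather than mutated in place.
        return new_updates, count

    return GradientTransformation(init, update)


tx = scale_by_inverse_count()
state = tx.init([0.0, 0.0])
updates, state = tx.update([2.0, 4.0], state)  # step 1
updates, state = tx.update([2.0, 4.0], state)  # step 2
```

Because both functions are pure, the caller is responsible for threading `state` from one call to the next, which is the pattern the text above describes.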
 class optax.TransformInitFn(*args, **kwargs)[source]#
A callable type for the init step of a GradientTransformation.
The init step takes a tree of params and uses these to construct an arbitrarily structured initial state for the gradient transformation. This may hold statistics of the past updates or any other non-static information.
 __call__(params)[source]#
The init function.
 Parameters
params (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – The initial value of the parameters.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]
 Returns
The initial state of the gradient transformation.
 __subclasshook__()[source]#
Abstract classes can override this to customize issubclass().
This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
 class optax.TransformUpdateFn(*args, **kwargs)[source]#
A callable type for the update step of a GradientTransformation.
The update step takes a tree of candidate parameter updates (e.g. their gradient with respect to some loss), an arbitrarily structured state, and the current params of the model being optimised. The params argument is optional; it must, however, be provided when using transformations that require access to the current values of the parameters.
For the case where additional arguments are required, an alternative interface may be used; see TransformUpdateExtraArgsFn for details.
 __call__(updates, state, params=None)[source]#
The update function.
 Parameters
updates (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – A tree of candidate updates.
state (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – The state of the gradient transformation.
params (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef], None]) – (Optionally) the current value of the parameters.
 Return type
Tuple[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]]
 Returns
The transformed updates, and the updated state.
 __subclasshook__()[source]#
Abstract classes can override this to customize issubclass().
This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).
 optax.OptState#
alias of Union[jax.Array, numpy.ndarray, jax._src.interpreters.batching.BatchTracer, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase, Iterable[ArrayTree], Mapping[Any, ArrayTree]]
Optax Transforms and States#
 optax.adaptive_grad_clip(clipping, eps=0.001)[source]#
Clips updates to be at most clipping * parameter_norm, unit-wise.
References
[Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image Recognition Without Normalization. (https://arxiv.org/abs/2102.06171)
 Parameters
clipping (float) – The maximum allowed ratio of update norm to parameter norm.
eps (float) – An epsilon term to prevent clipping of zero-initialized params.
 Return type
 Returns
A GradientTransformation object.
 optax.AdaptiveGradClipState[source]#
alias of
optax._src.base.EmptyState
 optax.add_decayed_weights(weight_decay=0.0, mask=None)[source]#
Add parameter scaled by weight_decay.
 Parameters
weight_decay (Union[float, Array]) – A scalar weight decay rate.
mask (Union[Any, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Any], None]) – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
 Return type
 Returns
A GradientTransformation object.
 optax.add_noise(eta, gamma, seed)[source]#
Add gradient noise.
References
[Neelakantan et al, 2015](https://arxiv.org/abs/1511.06807)
 Parameters
eta (float) – Base variance of the Gaussian noise added to the gradient.
gamma (float) – Decay exponent for annealing of the variance.
seed (int) – Seed for random number generation.
 Return type
 Returns
A GradientTransformation object.
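To make the roles of eta and gamma concrete, here is a small plain-Python sketch of annealed Gaussian gradient noise. It assumes the variance at a 1-based step `t` is `eta / t**gamma`; that schedule, and the helper names `noise_std` and `add_noise_step`, are assumptions made for illustration rather than a description of optax's internals.

```python
import math
import random


def noise_std(eta, gamma, step):
    """Standard deviation of the noise at a 1-based step (assumed schedule)."""
    variance = eta / step ** gamma
    return math.sqrt(variance)


def add_noise_step(grads, eta, gamma, step, rng):
    # Add independent Gaussian noise to every gradient entry.
    std = noise_std(eta, gamma, step)
    return [g + std * rng.gauss(0.0, 1.0) for g in grads]


rng = random.Random(0)  # seeded so that runs are reproducible
noisy = add_noise_step([1.0, -2.0], eta=0.01, gamma=0.55, step=1, rng=rng)
```

Under this schedule the noise shrinks as training proceeds, so early steps explore more than late ones.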
 optax.AddDecayedWeightsState[source]#
alias of
optax._src.base.EmptyState
 optax.additive_weight_decay(weight_decay=0.0, mask=None)[source]#
Add parameter scaled by weight_decay.
 Parameters
weight_decay (Union[float, Array]) – A scalar weight decay rate.
mask (Union[Any, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Any], None]) – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
 Return type
 Returns
A GradientTransformation object.
 optax.AdditiveWeightDecayState[source]#
alias of
optax._src.base.EmptyState
 class optax.AddNoiseState(count: chex.Array, rng_key: chex.PRNGKey)[source]#
State for adding gradient noise. Contains a count for annealing.
 count: chex.Array#
Alias for field number 0
 rng_key: chex.PRNGKey#
Alias for field number 1
 optax.apply_every(k=1)[source]#
Accumulate gradients and apply them every k steps.
Note that if this transformation is part of a chain, the states of the other transformations will still be updated at every step. In particular, using apply_every with a batch size of N/2 and k=2 is not necessarily equivalent to not using apply_every with a batch size of N. If this equivalence is important for you, consider using optax.MultiSteps.
 Parameters
k (int) – Emit non-zero gradients every k steps, otherwise accumulate them.
 Return type
 Returns
A GradientTransformation object.
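The accumulate-then-emit behaviour can be sketched on a scalar gradient as follows. This is a plain-Python illustration, not optax code: the helper `apply_every_step` is hypothetical, and emitting the accumulated sum (rather than a mean) on the last step of each window is an assumption.

```python
def apply_every_step(grad, count, grad_acc, k):
    """One accumulation step on a scalar gradient (illustrative only)."""
    # Start a fresh accumulator at the beginning of each window of k steps.
    grad_acc = grad if count == 0 else grad_acc + grad
    # Emit the accumulated sum on the last step of the window, zero otherwise.
    update = grad_acc if count == k - 1 else 0.0
    return update, (count + 1) % k, grad_acc


count, acc, emitted = 0, 0.0, []
for g in [1.0, 2.0, 3.0, 4.0]:
    update, count, acc = apply_every_step(g, count, acc, k=2)
    emitted.append(update)
```

Any transformation placed after this one in a chain would still see the zero updates on the intermediate steps, which is why its state advances at every step.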
 class optax.ApplyEvery(count: chex.Array, grad_acc: base.Updates)[source]#
Contains a counter and a gradient accumulator.
 count: chex.Array#
Alias for field number 0
 grad_acc: base.Updates#
Alias for field number 1
 optax.centralize()[source]#
Centralize gradients.
References
[Yong et al, 2020](https://arxiv.org/abs/2004.01461)
 Return type
 Returns
A GradientTransformation object.
 optax.clip(max_delta)[source]#
Clips updates element-wise, to be in [-max_delta, +max_delta].
 Parameters
max_delta (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]) – The maximum absolute value for each element in the update.
 Return type
 Returns
A GradientTransformation object.
 optax.clip_by_block_rms(threshold)[source]#
Clips updates to a max rms for the gradient of each param vector or matrix.
Here, a block is a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
 Parameters
threshold (float) – The maximum rms for the gradient of each param vector or matrix.
 Return type
 Returns
A GradientTransformation object.
 optax.clip_by_global_norm(max_norm)[source]#
Clips updates using their global norm.
References
[Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)
 Parameters
max_norm (float) – The maximum global norm for an update.
 Return type
 Returns
A GradientTransformation object.
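A minimal plain-Python sketch of global-norm clipping follows, using a list of lists as a stand-in for a pytree. The two functions are local illustrations that happen to share names with the optax functions; they are not the library implementations.

```python
import math


def global_norm(tree):
    """L2 norm over all entries of a list of lists (a stand-in for a pytree)."""
    return math.sqrt(sum(x * x for leaf in tree for x in leaf))


def clip_by_global_norm(tree, max_norm):
    norm = global_norm(tree)
    if norm <= max_norm:
        return tree  # already within the allowed norm
    scale = max_norm / norm
    return [[x * scale for x in leaf] for leaf in tree]


updates = [[3.0], [0.0, 4.0]]  # global norm is 5
clipped = clip_by_global_norm(updates, max_norm=1.0)
```

Every leaf is scaled by the same factor, so the direction of the overall update is preserved and only its length changes.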
 optax.ClipByGlobalNormState[source]#
alias of
optax._src.base.EmptyState
 optax.ClipState[source]#
alias of
optax._src.base.EmptyState
 optax.ema(decay, debias=True, accumulator_dtype=None)[source]#
Compute an exponential moving average of past updates.
Note: trace and ema have very similar but distinct updates; ema = decay * ema + (1 - decay) * t, while trace = decay * trace + t. Both are frequently found in the optimization literature.
 Parameters
decay (float) – Decay rate for the exponential moving average.
debias (bool) – Whether to debias the transformed gradient.
accumulator_dtype (Optional[Any]) – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.
 Return type
 Returns
A GradientTransformation object.
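The difference between the two update rules quoted in the note is easiest to see with a constant input. The sketch below applies both rules to scalars in plain Python; the helper names are hypothetical and debiasing is deliberately left out.

```python
def ema_step(ema, t, decay):
    return decay * ema + (1 - decay) * t


def trace_step(trace, t, decay):
    return decay * trace + t


ema, trace = 0.0, 0.0
for _ in range(200):
    ema = ema_step(ema, 1.0, decay=0.9)
    trace = trace_step(trace, 1.0, decay=0.9)
# With a constant input of 1, ema approaches 1 while trace approaches 1 / (1 - decay).
```

In other words, ema keeps the scale of its input, whereas trace amplifies it by roughly `1 / (1 - decay)` in the long run.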
 class optax.EmaState(count: chex.Array, ema: base.Params)[source]#
Holds an exponential moving average of past updates.
 count: chex.Array#
Alias for field number 0
 ema: base.Params#
Alias for field number 1
 class optax.EmptyState[source]#
An empty state for the simplest stateless transformations.
 static __new__(_cls)#
Create new instance of EmptyState()
 class optax.FactoredState(count: chex.Array, v_row: chex.ArrayTree, v_col: chex.ArrayTree, v: chex.ArrayTree)[source]#
Overall state of the gradient transformation.
 count: chex.Array#
Alias for field number 0
 v_row: chex.ArrayTree#
Alias for field number 1
 v_col: chex.ArrayTree#
Alias for field number 2
 v: chex.ArrayTree#
Alias for field number 3
 optax.global_norm(updates)[source]#
Compute the global norm across a nested structure of tensors.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 optax.identity()[source]#
Stateless identity transformation that leaves input gradients untouched.
This function passes through the gradient updates unchanged.
Note that this should not be confused with set_to_zero, which maps the input updates to zero; that is the transform required for the model parameters to be left unchanged when the updates are applied to them.
 Return type
 Returns
A GradientTransformation object.
 optax.keep_params_nonnegative()[source]#
Modifies the updates to keep parameters nonnegative, i.e. >= 0.
This transformation ensures that parameters after the update will be larger than or equal to zero. In a chain of transformations, this should be the last one.
WARNING: the transformation expects input params to be non-negative. When params are negative, the transformed update will move them to 0.
 Return type
 Returns
A GradientTransformation object.
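One way to picture this transformation is as an element-wise rule: if applying an update would push a parameter below zero, replace that update with one that lands exactly on zero. The sketch below shows this on flat lists in plain Python; the helper `keep_nonnegative` is hypothetical and the rule is an assumption consistent with the description above.

```python
def keep_nonnegative(params, updates):
    """Adjust updates so that params + updates is never below zero."""
    return [-p if p + u < 0.0 else u for p, u in zip(params, updates)]


params = [1.0, 0.5, 2.0]
updates = [-0.3, -0.8, 1.0]
safe = keep_nonnegative(params, updates)
new_params = [p + u for p, u in zip(params, safe)]
```

Only the second entry is changed here, since it is the only one whose update would have made the parameter negative.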
 optax.NonNegativeParamsState[source]#
alias of
optax._src.base.EmptyState
 optax.scale(step_size)[source]#
Scale updates by some fixed scalar step_size.
 Parameters
step_size (float) – A scalar corresponding to a fixed scaling factor for updates.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)[source]#
Rescale updates according to the Adam algorithm.
References
[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
 Parameters
b1 (float) – Decay rate for the exponentially weighted average of grads.
b2 (float) – Decay rate for the exponentially weighted average of squared grads.
eps (float) – Term added to the denominator to improve numerical stability.
eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
mu_dtype (Optional[_ScalarMeta]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
 Return type
 Returns
A GradientTransformation object.
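The parameter descriptions above place eps outside the square root and eps_root inside it. A scalar sketch in plain Python makes that placement explicit. It follows the standard Adam update with bias correction; the function `adam_rescale` is a hypothetical illustration, not optax's code.

```python
import math


def adam_rescale(grad, mu, nu, count, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """One Adam rescaling step for a scalar gradient (illustrative only)."""
    count += 1
    mu = b1 * mu + (1 - b1) * grad         # first-moment estimate
    nu = b2 * nu + (1 - b2) * grad * grad  # second-moment estimate
    mu_hat = mu / (1 - b1 ** count)        # bias correction
    nu_hat = nu / (1 - b2 ** count)
    # eps_root sits inside the square root, eps outside it.
    update = mu_hat / (math.sqrt(nu_hat + eps_root) + eps)
    return update, mu, nu, count


update, mu, nu, count = adam_rescale(grad=0.5, mu=0.0, nu=0.0, count=0)
```

On the very first step the rescaled update has magnitude close to 1 whatever the gradient's scale, because the bias-corrected moments reduce to the gradient and its square.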
 optax.scale_by_adamax(b1=0.9, b2=0.999, eps=1e-08)[source]#
Rescale updates according to the Adamax algorithm.
References
[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)
 Parameters
b1 (float) – Decay rate for the exponentially weighted average of grads.
b2 (float) – Decay rate for the exponentially weighted maximum of grads.
eps (float) – Term added to the denominator to improve numerical stability.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_amsgrad(b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)[source]#
Rescale updates according to the AMSGrad algorithm.
References
[Reddi et al, 2018](https://openreview.net/forum?id=ryQu7f-RZ)
 Parameters
b1 (float) – Decay rate for the exponentially weighted average of grads.
b2 (float) – Decay rate for the exponentially weighted average of squared grads.
eps (float) – Term added to the denominator to improve numerical stability.
eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
mu_dtype (Optional[_ScalarMeta]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_belief(b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16)[source]#
Rescale updates according to the AdaBelief algorithm.
References
[Zhuang et al, 2020](https://arxiv.org/abs/2010.07468)
 Parameters
b1 (float) – Decay rate for the exponentially weighted average of grads.
b2 (float) – Decay rate for the exponentially weighted average of variance of grads.
eps (float) – Term added to the denominator to improve numerical stability.
eps_root (float) – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_factored_rms(factored=True, decay_rate=0.8, step_offset=0, min_dim_size_to_factor=128, epsilon=1e-30, decay_rate_fn=<function _decay_rate_pow>)[source]#
Scaling by a factored estimate of the gradient rms (as in Adafactor).
This is a so-called “1+epsilon” scaling algorithm that is extremely memory-efficient compared to RMSProp/Adam, and has had wide success when applied to large-scale training of attention-based models.
References
[Shazeer et al, 2018](https://arxiv.org/abs/1804.04235)
 Parameters
factored (bool) – Whether to use factored second-moment estimates.
decay_rate (float) – Controls the second-moment exponential decay schedule.
step_offset (int) – For fine-tuning, one may set this to the starting step number of the fine-tuning phase.
min_dim_size_to_factor (int) – Only factor the accumulator if two array dimensions are at least this size.
epsilon (float) – Regularization constant for squared gradient.
decay_rate_fn (Callable[[int, float], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]]) – A function that accepts the current step and the decay rate parameter, and controls the schedule for the second momentum. Defaults to the original Adafactor’s power decay schedule. One potential shortcoming of the original schedule is that the second momentum converges to 1, which effectively freezes the second momentum. To prevent this, the user can opt for a custom schedule that sets an upper bound for the second momentum, as in [Zhai et al., 2021](https://arxiv.org/abs/2106.04560).
 Returns
the corresponding GradientTransformation.
 optax.scale_by_lion(b1=0.9, b2=0.99, mu_dtype=None)[source]#
Rescale updates according to the Lion algorithm.
References
[Chen et al, 2023](https://arxiv.org/abs/2302.06675)
 Parameters
b1 (float) – Rate for combining the momentum and the current grad.
b2 (float) – Decay rate for the exponentially weighted average of grads.
mu_dtype (Optional[_ScalarMeta]) – Optional dtype to be used for the momentum; if None then the dtype is inferred from params and updates.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_novograd(b1=0.9, b2=0.25, eps=1e-08, eps_root=0.0, weight_decay=0.0, mu_dtype=None)[source]#
Computes NovoGrad updates.
References
[Ginsburg et al, 2019](https://arxiv.org/abs/1905.11286)
 Parameters
b1 (float) – A decay rate for the exponentially weighted average of grads.
b2 (float) – A decay rate for the exponentially weighted average of squared grads.
eps (float) – A term added to the denominator to improve numerical stability.
eps_root (float) – A term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
weight_decay (float) – A scalar weight decay rate.
mu_dtype (Optional[_ScalarMeta]) – An optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
 Return type
 Returns
The corresponding GradientTransformation.
 optax.scale_by_param_block_norm(min_scale=0.001)[source]#
Scale updates for each param block by the norm of that block’s parameters.
Here, a block is a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
 Parameters
min_scale (float) – Minimum scaling factor.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_param_block_rms(min_scale=0.001)[source]#
Scale updates by rms of the gradient for each param vector or matrix.
Here, a block is a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.
 Parameters
min_scale (float) – Minimum scaling factor.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_radam(b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, threshold=5.0)[source]#
Rescale updates according to the Rectified Adam algorithm.
References
[Liu et al, 2020](https://arxiv.org/abs/1908.03265)
 Parameters
b1 (float) – Decay rate for the exponentially weighted average of grads.
b2 (float) – Decay rate for the exponentially weighted average of squared grads.
eps (float) – Term added to the denominator to improve numerical stability.
eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
threshold (float) – Threshold for variance tractability.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_rms(decay=0.9, eps=1e-08, initial_scale=0.0)[source]#
Rescale updates by the root of the exp. moving avg of the square.
References
[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 Parameters
decay (float) – Decay rate for the exponentially weighted average of squared grads.
eps (float) – Term added to the denominator to improve numerical stability.
initial_scale (float) – Initial value for second moment.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_rss(initial_accumulator_value=0.1, eps=1e-07)[source]#
Rescale updates by the root of the sum of all squared gradients to date.
References
[Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
[McMahan et al, 2010](https://arxiv.org/abs/1002.4908)
 Parameters
initial_accumulator_value (float) – Starting value for accumulators, must be >= 0.
eps (float) – A small floating point value to avoid zero denominator.
 Return type
 Returns
A GradientTransformation object.
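The following scalar sketch in plain Python shows how rescaling by the root of the running sum of squares makes successive steps shrink, as in Adagrad. The helper `rss_rescale` is hypothetical; placing eps inside the square root and returning zero for an empty accumulator are assumptions made for this illustration.

```python
import math


def rss_rescale(grad, sum_of_squares, eps=1e-7):
    """One step of root-sum-of-squares rescaling for a scalar gradient."""
    sum_of_squares += grad * grad
    # Guard against a zero accumulator so the division is always defined.
    scale = 1.0 / math.sqrt(sum_of_squares + eps) if sum_of_squares > 0 else 0.0
    return grad * scale, sum_of_squares


acc = 0.1  # plays the role of initial_accumulator_value
steps = []
for g in [1.0, 1.0, 1.0]:
    update, acc = rss_rescale(g, acc)
    steps.append(update)
```

Because the accumulator only grows, the rescaled update for a constant gradient decreases monotonically.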
 optax.scale_by_schedule(step_size_fn)[source]#
Scale updates using a custom schedule for the step_size.
 Parameters
step_size_fn (Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]) – A function that takes an update count as input and proposes the step_size to multiply the updates by.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_sm3(b1=0.9, b2=1.0, eps=1e-08)[source]#
Scale updates by SM3.
References
[Anil et al, 2019](https://arxiv.org/abs/1901.11150)
 Parameters
b1 (float) – Decay rate for the exponentially weighted average of grads.
b2 (float) – Decay rate for the exponentially weighted average of squared grads.
eps (float) – Term added to the denominator to improve numerical stability.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_stddev(decay=0.9, eps=1e-08, initial_scale=0.0)[source]#
Rescale updates by the root of the centered exp. moving average of squares.
References
[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
 Parameters
decay (float) – Decay rate for the exponentially weighted average of squared grads.
eps (float) – Term added to the denominator to improve numerical stability.
initial_scale (float) – Initial value for second moment.
 Return type
 Returns
A GradientTransformation object.
 optax.scale_by_trust_ratio(min_norm=0.0, trust_coefficient=1.0, eps=0.0)[source]#
Scale updates by trust ratio.
References
[You et al, 2020](https://arxiv.org/abs/1904.00962)
 Parameters
min_norm (float) – Minimum norm for params and gradient norms; by default is zero.
trust_coefficient (float) – A multiplier for the trust ratio.
eps (float) – Additive constant added to the denominator for numerical stability.
 Return type
 Returns
A GradientTransformation object.
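As a sketch of what the three parameters control, the plain-Python function below scales one block of updates by the ratio of the parameter norm to the update norm. The helper `trust_ratio_scale` is hypothetical. Clamping both norms from below by min_norm, and falling back to a ratio of 1 when either norm is zero, are assumptions made for illustration.

```python
import math


def trust_ratio_scale(params, updates, min_norm=0.0, trust_coefficient=1.0, eps=0.0):
    """Scale one block of updates by its trust ratio (illustrative only)."""
    param_norm = max(math.sqrt(sum(p * p for p in params)), min_norm)
    update_norm = max(math.sqrt(sum(u * u for u in updates)), min_norm)
    if param_norm == 0.0 or update_norm == 0.0:
        ratio = 1.0  # fall back to leaving the update unchanged
    else:
        ratio = trust_coefficient * param_norm / (update_norm + eps)
    return [u * ratio for u in updates]


scaled = trust_ratio_scale(params=[3.0, 4.0], updates=[0.0, 10.0])
```

Here the parameter norm is 5 and the update norm is 10, so the update is halved; layers with large weights relative to their updates would instead have their updates enlarged.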
 optax.scale_by_yogi(b1=0.9, b2=0.999, eps=0.001, eps_root=0.0, initial_accumulator_value=1e-06)[source]#
Rescale updates according to the Yogi algorithm.
Supports complex numbers, see https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29
References
[Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html)
 Parameters
b1 (float) – Decay rate for the exponentially weighted average of grads.
b2 (float) – Decay rate for the exponentially weighted average of variance of grads.
eps (float) – Term added to the denominator to improve numerical stability.
eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.
initial_accumulator_value (float) – The starting value for accumulators. Only positive values are allowed.
 Return type
 Returns
A GradientTransformation object.
 class optax.ScaleByAdamState(count: chex.Array, mu: base.Updates, nu: base.Updates)[source]#
State for the Adam algorithm.
 count: chex.Array#
Alias for field number 0
 mu: base.Updates#
Alias for field number 1
 nu: base.Updates#
Alias for field number 2
 class optax.ScaleByAmsgradState(count: chex.Array, mu: base.Updates, nu: base.Updates, nu_max: base.Updates)[source]#
State for the AMSGrad algorithm.
 count: chex.Array#
Alias for field number 0
 mu: base.Updates#
Alias for field number 1
 nu: base.Updates#
Alias for field number 2
 nu_max: base.Updates#
Alias for field number 3
 class optax.ScaleByLionState(count: chex.Array, mu: base.Updates)[source]#
State for the Lion algorithm.
 count: chex.Array#
Alias for field number 0
 mu: base.Updates#
Alias for field number 1
 class optax.ScaleByNovogradState(count: chex.Array, mu: base.Updates, nu: base.Updates)[source]#
State for Novograd.
 count: chex.Array#
Alias for field number 0
 mu: base.Updates#
Alias for field number 1
 nu: base.Updates#
Alias for field number 2
 class optax.ScaleByRmsState(nu: base.Updates)[source]#
State for exponential root mean-squared (RMS)-normalized updates.
 nu: base.Updates#
Alias for field number 0
 class optax.ScaleByRssState(sum_of_squares: base.Updates)[source]#
State holding the sum of gradient squares to date.
 sum_of_squares: base.Updates#
Alias for field number 0
 class optax.ScaleByRStdDevState(mu: base.Updates, nu: base.Updates)[source]#
State for centered exponential moving average of squares of updates.
 mu: base.Updates#
Alias for field number 0
 nu: base.Updates#
Alias for field number 1
 class optax.ScaleByScheduleState(count: chex.Array)[source]#
Maintains count for scale scheduling.
 count: chex.Array#
Alias for field number 0
 class optax.ScaleBySM3State(mu: base.Updates, nu: base.Updates)[source]#
State for the SM3 algorithm.
 mu: base.Updates#
Alias for field number 0
 nu: base.Updates#
Alias for field number 1
 class optax.ScaleByTrustRatioState[source]#
The scale and decay trust ratio transformation is stateless.
 static __new__(_cls)#
Create new instance of ScaleByTrustRatioState()
 optax.ScaleState[source]#
alias of
optax._src.base.EmptyState
 optax.set_to_zero()[source]#
Stateless transformation that maps input gradients to zero.
The resulting update function, when called, will return a tree of zeros matching the shape of the input gradients. This means that when the updates returned from this transformation are applied to the model parameters, the model parameters will remain unchanged.
This can be used in combination with multi_transform or masked to freeze (i.e. keep fixed) some parts of the tree of model parameters while applying gradient updates to other parts of the tree.
When updates are set to zero inside the same jit-compiled function as the calculation of gradients, optax transformations, and application of updates to parameters, unnecessary computations will in general be dropped.
 Return type
 Returns
A GradientTransformation object.
 optax.stateless(f)[source]#
Creates a stateless transformation from an update-like function.
This wrapper eliminates the boilerplate needed to create a transformation that does not require saved state between iterations.
 Parameters
f (Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef], None]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]]) – Update function that takes in updates (e.g. gradients) and parameters and returns updates. The parameters may be None.
 Return type
 Returns
An optax.GradientTransformation.
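The boilerplate that such a wrapper removes can be sketched in plain Python. The `Transformation` tuple and the local `stateless` function below are hypothetical stand-ins written for this illustration (nothing is imported from optax), and an empty tuple plays the role of the empty state.

```python
from typing import Any, Callable, NamedTuple, Optional, Tuple


class Transformation(NamedTuple):
    """Hypothetical stand-in for an (init, update) pair."""
    init: Callable[[Any], Tuple]
    update: Callable[..., Tuple[Any, Tuple]]


def stateless(f: Callable[[Any, Optional[Any]], Any]) -> Transformation:
    """Wrap f(updates, params) so it fits the init/update interface."""

    def init(params):
        del params
        return ()  # empty state: nothing is carried between steps

    def update(updates, state, params=None):
        del state
        return f(updates, params), ()

    return Transformation(init, update)


# Example: halve every update; params is accepted but ignored.
halve = stateless(lambda updates, params: [0.5 * u for u in updates])
state = halve.init([1.0, 2.0])
new_updates, state = halve.update([4.0, 6.0], state)
```

The user only writes the one-line function; the wrapper supplies the init function and the state plumbing.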
 optax.stateless_with_tree_map(f)[source]#
Creates a stateless transformation from an update-like function for arrays.
This wrapper eliminates the boilerplate needed to create a transformation that does not require saved state between iterations, just like optax.stateless. In addition, this function will apply the tree_map over update/params for you.
 Parameters
f (Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, None]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]]) – Update function that takes in an update array (e.g. gradients) and parameter array and returns an update array. The parameter array may be None.
 Return type
 Returns
An optax.GradientTransformation.
 optax.trace(decay, nesterov=False, accumulator_dtype=None)[source]#
Compute a trace of past updates.
Note: trace and ema have very similar but distinct updates; trace = decay * trace + t, while ema = decay * ema + (1 - decay) * t. Both are frequently found in the optimization literature.
 Parameters
decay (float) – Decay rate for the trace of past updates.
nesterov (bool) – Whether to use Nesterov momentum.
accumulator_dtype (Optional[Any]) – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.
 Return type
 Returns
A GradientTransformation object.
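The difference between the two update rules in the note above can be seen numerically. The following is a plain-Python sketch of the two formulas only (no Nesterov term, scalar inputs); the function names are illustrative and not part of optax.

```python
# Plain-Python sketch of the two accumulation rules quoted in the note above.

def trace_step(trace, t, decay):
    """trace = decay * trace + t"""
    return decay * trace + t


def ema_step(ema, t, decay):
    """ema = decay * ema + (1 - decay) * t"""
    return decay * ema + (1 - decay) * t


decay = 0.9
trace = 0.0
ema = 0.0
for _ in range(200):            # feed a constant update t = 1.0
    trace = trace_step(trace, 1.0, decay)
    ema = ema_step(ema, 1.0, decay)

# With a constant input, the trace approaches 1 / (1 - decay) while the ema
# approaches the input itself.
```

For a constant input, the trace therefore amplifies the update by roughly 1 / (1 - decay), which is one reason the two accumulators call for different learning rates.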
 class optax.TraceState(trace: base.Params)[source]#
Holds an aggregation of past updates.
 trace: base.Params#
Alias for field number 0
 optax.zero_nans()[source]#
A transformation which replaces NaNs with 0.
Zeroing values in gradients is guaranteed to produce a direction of non-increasing loss.
The state of the transformation has the same tree structure as that of the parameters. Each leaf is a single boolean which contains True iff a NaN was detected in the corresponding parameter array at the last call to update. This state is not used by the transformation internally, but lets users be aware when NaNs have been zeroed out.
 Return type
 Returns
A GradientTransformation.
 class optax.ZeroNansState(found_nan: Any)[source]#
Contains a tree.
The entry found_nan has the same tree structure as that of the parameters. Each leaf is a single boolean which contains True iff a NaN was detected in the corresponding parameter array at the last call to update.
 found_nan: Any#
Alias for field number 0
Apply Updates#

Applies an update to the corresponding parameters. 

Incrementally update parameters via polyak averaging. 

Periodically update all parameters with new values. 
apply_updates#
 optax.apply_updates(params, updates)[source]#
Applies an update to the corresponding parameters.
This is a utility function that applies an update to a set of parameters, and then returns the updated parameters to the caller. As an example, the update may be a gradient transformed by a sequence of GradientTransformations. This function is exposed for convenience, but it just adds updates and parameters; you may also apply updates to parameters manually, using tree_map (e.g. if you want to manipulate updates in custom ways before applying them).
 Parameters
params (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – a tree of parameters.
updates (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – a tree of updates; the tree structure and the shape of the leaf nodes must match that of params.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]
 Returns
Updated parameters, with same structure, shape and type as params.
incremental_update#
 optax.incremental_update(new_tensors, old_tensors, step_size)[source]#
Incrementally update parameters via polyak averaging.
Polyak averaging tracks an (exponential moving) average of the past parameters of a model, for use at test/evaluation time.
References
[Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046)
 Parameters
new_tensors (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – the latest value of the tensors.
old_tensors (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – a moving average of the values of the tensors.
step_size (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]) – the step_size used to update the polyak average on each step.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]
 Returns
an updated moving average step_size*new+(1-step_size)*old of the params.
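As a worked example of the returned quantity, here is a plain-Python sketch on scalars. The name `polyak_step` is illustrative; the real function operates on whole trees of tensors.

```python
# Plain-Python sketch of one polyak-averaging step on a scalar.

def polyak_step(new, old, step_size):
    """Return step_size * new + (1 - step_size) * old."""
    return step_size * new + (1 - step_size) * old


average = 0.0
for latest_value in [1.0, 1.0, 1.0]:
    # Each step moves the average half of the remaining way towards the new value.
    average = polyak_step(latest_value, average, step_size=0.5)
```

With step_size=0.5 the average passes through 0.5, 0.75 and 0.875; a smaller step_size makes it track the parameters more slowly.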
periodic_update#
 optax.periodic_update(new_tensors, old_tensors, steps, update_period)[source]#
Periodically update all parameters with new values.
A slow copy of a model’s parameters, updated every K actual updates, can be used to implement forms of self-supervision (in supervised learning), or to stabilise temporal difference learning updates (in reinforcement learning).
References
[Grill et al., 2020](https://arxiv.org/abs/2006.07733)
[Mnih et al., 2015](https://arxiv.org/abs/1312.5602)
 Parameters
new_tensors (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – the latest value of the tensors.
old_tensors (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – a slow copy of the model’s parameters.
steps (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – number of update steps on the “online” network.
update_period (int) – every how many steps to update the “target” network.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]
 Returns
a slow copy of the model’s parameters, updated every update_period steps.
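A minimal plain-Python sketch of the "online" and "target" pattern follows. It assumes the copy happens whenever steps is a multiple of update_period, which matches the description above but is stated here as an assumption; `periodic_copy` is an illustrative name, not the library function.

```python
# Plain-Python sketch of a periodically refreshed "target" copy.

def periodic_copy(new, old, steps, update_period):
    """Return `new` on every update_period-th step, otherwise keep `old`.

    Assumption: the refresh is triggered when steps % update_period == 0.
    """
    return new if steps % update_period == 0 else old


target = 0
history = []
for steps in range(1, 7):
    online = steps * 10                      # stand-in for the online parameters
    target = periodic_copy(online, target, steps, update_period=3)
    history.append(target)
```

The target stays fixed between refreshes, so values derived from it change only every update_period steps.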
Combining Optimizers#

Applies a list of chainable update transformations. 

Partitions params and applies a different transformation to each subset. 
chain#
 optax.chain(*args)[source]#
Applies a list of chainable update transformations.
Given a sequence of chainable transforms, chain returns an init_fn that constructs a state by concatenating the states of the individual transforms, and returns an update_fn which chains the update transformations feeding the appropriate state to each.
 Parameters
*args – a sequence of chainable (init_fn, update_fn) tuples.
 Return type
optax.GradientTransformationExtraArgs
 Returns
A GradientTransformationExtraArgs, created by chaining the input transformations. Note that independent of the argument types, the resulting transformation always supports extra args. Any extra arguments passed to the returned transformation will be passed only to those transformations in the chain that support extra args.
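The mechanism described above (one combined state, updates threaded through each transform in order) can be sketched in plain Python. This is an illustration of the idea under simplifying assumptions (list-valued updates, no extra args), not the library's implementation; `scale`, `count_calls` and `chain_sketch` are hypothetical names.

```python
# Plain-Python sketch of chaining (init_fn, update_fn) pairs.

def scale(factor):
    """A stateless transform that multiplies every update by `factor`."""
    def init_fn(params):
        return ()                                # nothing to remember
    def update_fn(updates, state, params=None):
        return [factor * u for u in updates], state
    return init_fn, update_fn


def count_calls():
    """A transform that passes updates through and counts its calls."""
    def init_fn(params):
        return 0
    def update_fn(updates, state, params=None):
        return updates, state + 1
    return init_fn, update_fn


def chain_sketch(*transforms):
    init_fns = [t[0] for t in transforms]
    update_fns = [t[1] for t in transforms]

    def init_fn(params):
        # The combined state is the tuple of the individual states.
        return tuple(f(params) for f in init_fns)

    def update_fn(updates, state, params=None):
        new_state = []
        for f, s in zip(update_fns, state):
            # Each transform receives the output of the previous one
            # together with its own slice of the state.
            updates, s = f(updates, s, params)
            new_state.append(s)
        return updates, tuple(new_state)

    return init_fn, update_fn


init_fn, update_fn = chain_sketch(scale(-0.5), count_calls(), scale(2.0))
state = init_fn([1.0, 2.0])
updates, state = update_fn([10.0, 20.0], state)
```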
Multi Transform#
 optax.multi_transform(transforms, param_labels)[source]#
Partitions params and applies a different transformation to each subset.
Below is an example where we apply Adam to the weights and SGD to the biases of a 2-layer neural network:
    import optax
    import jax
    import jax.numpy as jnp

    def map_nested_fn(fn):
      '''Recursively apply `fn` to the key-value pairs of a nested dict'''
      def map_fn(nested_dict):
        return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
                for k, v in nested_dict.items()}
      return map_fn

    params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
              'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
    gradients = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy gradients

    # Label every leaf by its own key, so 'w' leaves use Adam and 'b' leaves use SGD.
    label_fn = map_nested_fn(lambda k, _: k)
    tx = optax.multi_transform({'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, label_fn)
    state = tx.init(params)
    updates, new_state = tx.update(gradients, state, params)
    new_params = optax.apply_updates(params, updates)
Instead of providing a label_fn, you may provide a PyTree of labels directly. Also, this PyTree may be a prefix of the parameters PyTree. This is demonstrated in the GAN pseudocode below:
    generator_params = ...
    discriminator_params = ...
    all_params = (generator_params, discriminator_params)
    param_labels = ('generator', 'discriminator')

    tx = optax.multi_transform(
        {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
        param_labels)
If you would like to not optimize some parameters, you may wrap optax.multi_transform with optax.masked().
 Parameters
transforms (Mapping[Hashable, optax.GradientTransformation]) – A mapping from labels to transformations. Each transformation will only be applied to parameters with the same label.
param_labels (Union[Any, Callable[[Any], Any]]) – A PyTree that is the same shape or a prefix of the parameters/updates (or a function that returns one given the parameters as input). The leaves of this PyTree correspond to the keys of the transforms (therefore the values at the leaves must be a subset of the keys).
 Return type
 Returns
An optax.GradientTransformation.
Optimizer Wrappers#

A function that wraps an optimizer to make it robust to a few NaNs or Infs. 

State of the GradientTransformation returned by apply_if_finite. 

Flattens parameters and gradients for init and update of inner transform. 

Lookahead optimizer. 

Holds a pair of slow and fast parameters for the lookahead optimizer. 

State of the GradientTransformation returned by lookahead. 

Mask updates so only some are transformed, the rest are passed through. 

Maintains inner transform state for masked transformations. 

Calls the inner update function only at certain steps. 

Maintains inner transform state and adds a step counter. 

An optimizer wrapper to accumulate gradients over multiple steps. 

State of the GradientTransformation returned by MultiSteps. 



Returns True if the global norm square of updates is small enough. 

Returns True iff any of the updates contains an inf or a NaN. 
Apply if Finite#
 optax.apply_if_finite(inner, max_consecutive_errors)[source]#
A function that wraps an optimizer to make it robust to a few NaNs or Infs.
The purpose of this function is to prevent any optimization from happening if the gradients contain NaNs or Infs. That is, when a NaN or Inf is detected in the gradients, the wrapped optimizer ignores that gradient update. If the NaNs or Infs persist after a given number of updates, the wrapped optimizer gives up and accepts the update.
 Parameters
inner (optax.GradientTransformation) – Inner transformation to be wrapped.
max_consecutive_errors (int) – Maximum number of consecutive gradient updates containing NaNs or Infs that the wrapped optimizer will ignore. After that many ignored updates, the optimizer will give up and accept.
 Return type
 Returns
New GradientTransformationExtraArgs.
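The ignore-then-give-up behaviour can be sketched in plain Python as follows. This is an illustration, not the library code: it assumes that a rejected update is replaced by zeros and that an update is accepted once the consecutive-error count exceeds max_consecutive_errors. The name `apply_if_finite_step` is hypothetical.

```python
import math

# Plain-Python sketch of the accept/reject decision for one gradient update.

def apply_if_finite_step(grads, notfinite_count, max_consecutive_errors):
    """Return (updates, new_notfinite_count, accepted) for one step."""
    is_finite = all(math.isfinite(g) for g in grads)
    # The consecutive counter resets as soon as a finite update arrives.
    notfinite_count = 0 if is_finite else notfinite_count + 1
    # Assumption: give up and accept once the limit has been exceeded.
    accepted = is_finite or notfinite_count > max_consecutive_errors
    updates = list(grads) if accepted else [0.0 for _ in grads]
    return updates, notfinite_count, accepted


nan = float("nan")
count = 0
accepted_history = []
for grads in ([1.0, 2.0], [nan, 2.0], [nan, 2.0], [nan, 2.0]):
    _, count, accepted = apply_if_finite_step(grads, count, max_consecutive_errors=2)
    accepted_history.append(accepted)
```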
 class optax.ApplyIfFiniteState(notfinite_count: jnp.array, last_finite: jnp.array, total_notfinite: jnp.array, inner_state: Any)[source]#
State of the GradientTransformation returned by apply_if_finite.
 Fields:
notfinite_count: Number of consecutive gradient updates containing an Inf or a NaN. This number is reset to 0 whenever a gradient update without an Inf or a NaN is done.
last_finite: Whether or not the last gradient update contained an Inf or a NaN.
total_notfinite: Total number of gradient updates containing an Inf or a NaN since this optimizer was initialised. This number is never reset.
inner_state: The state of the inner GradientTransformation.
 notfinite_count: jnp.array#
Alias for field number 0
 last_finite: jnp.array#
Alias for field number 1
 total_notfinite: jnp.array#
Alias for field number 2
 inner_state: Any#
Alias for field number 3
flatten#
 optax.flatten(inner)[source]#
Flattens parameters and gradients for init and update of inner transform.
This can reduce the overhead of performing many calculations on lots of small variables, at the cost of slightly increased memory usage.
 Parameters
inner (optax.GradientTransformation) – Inner transformation to flatten inputs for.
 Return type
optax.GradientTransformationExtraArgs
 Returns
New GradientTransformationExtraArgs
Lookahead#
 optax.lookahead(fast_optimizer, sync_period, slow_step_size, reset_state=False)[source]#
Lookahead optimizer.
Performs steps with a fast optimizer and periodically updates a set of slow parameters. Optionally resets the fast optimizer state after synchronization by calling the init function of the fast optimizer.
Updates returned by the lookahead optimizer should not be modified before they are applied, otherwise fast and slow parameters are not synchronized correctly.
References
[Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)
 Parameters
fast_optimizer (optax.GradientTransformation) – The optimizer to use in the inner loop of lookahead.
sync_period (int) – Number of fast optimizer steps to take before synchronizing parameters. Must be >= 1.
slow_step_size (float) – Step size of the slow parameter updates.
reset_state (bool) – Whether to reset the optimizer state of the fast optimizer after each synchronization.
 Return type
 Returns
A GradientTransformation with init and update functions. The updates passed to the update function should be calculated using the fast lookahead parameters only.
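The interaction between fast and slow parameters can be sketched on a scalar in plain Python. The synchronization rule used here (move the slow parameters a fraction slow_step_size of the way towards the fast ones, then restart the fast parameters from the slow ones) follows the referenced paper and is an assumption about the details; plain gradient descent stands in for the fast optimizer.

```python
# Plain-Python sketch of lookahead on a scalar parameter.

def lookahead_run(grad_fn, params, lr, sync_period, slow_step_size, num_steps):
    fast = slow = params
    steps_since_sync = 0
    for _ in range(num_steps):
        # Fast step: gradients are always evaluated at the fast parameters.
        fast = fast - lr * grad_fn(fast)
        steps_since_sync += 1
        if steps_since_sync == sync_period:
            # Slow step: interpolate towards the fast parameters, then resync.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow
            steps_since_sync = 0
    return fast, slow


# Minimise f(x) = x**2, whose gradient is 2 * x.
fast, slow = lookahead_run(lambda x: 2 * x, params=1.0, lr=0.25,
                           sync_period=2, slow_step_size=0.5, num_steps=2)
```

Here the fast parameter moves 1.0 → 0.5 → 0.25 in two steps, after which the slow parameter moves halfway from 1.0 to 0.25 and both are set to 0.625.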
 class optax.LookaheadParams(fast: base.Params, slow: base.Params)[source]#
Holds a pair of slow and fast parameters for the lookahead optimizer.
Gradients should always be calculated with the fast parameters. The slow parameters should be used for testing and inference as they generalize better. See the reference for a detailed discussion.
References
[Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)
 fast#
Fast parameters.
 Type
optax.Params
 slow#
Slow parameters.
 Type
optax.Params
 fast: base.Params#
Alias for field number 0
 slow: base.Params#
Alias for field number 1
 class optax.LookaheadState(fast_state: base.OptState, steps_since_sync: jnp.ndarray)[source]#
State of the GradientTransformation returned by lookahead.
 fast_state#
Optimizer state of the fast optimizer.
 Type
optax.OptState
 steps_since_sync#
Number of fast optimizer steps taken since slow and fast parameters were synchronized.
 Type
jnp.ndarray
 fast_state: base.OptState#
Alias for field number 0
 steps_since_sync: jnp.ndarray#
Alias for field number 1
Masked Update#
 optax.masked(inner, mask)[source]#
Mask updates so only some are transformed, the rest are passed through.
For example, it is common to skip weight decay for BatchNorm scale and all bias parameters. In many networks, these are the only parameters with only one dimension. So, you may create a mask function to mask these out as follows:
mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim != 1, p) weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn)
You may alternatively create the mask pytree upfront:
mask = jax.tree_util.tree_map(lambda x: x.ndim != 1, params) weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask)
For the inner transform, state will only be stored for the parameters that have a mask value of True.
 Parameters
inner (optax.GradientTransformation) – Inner transformation to mask.
mask (Union[Any, Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]], Any]]) – a PyTree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip. The mask must be static for the gradient transformation to be jit-compilable.
 Return type
optax.GradientTransformationExtraArgs
 Returns
New GradientTransformationExtraArgs wrapping inner.
Maybe Update#
 optax.maybe_update(inner, should_update_fn)[source]#
Calls the inner update function only at certain steps.
Creates a transformation wrapper which counts the number of times the update function has been called. This counter is passed to the should_update_fn to decide when to call the inner update function.
When not calling the inner update function, the updates and the inner state are left untouched and just passed through. The step counter is increased regardless.
 Parameters
inner (optax.GradientTransformation) – the inner transformation.
should_update_fn (Callable[[Array], Array]) – this function takes in a step counter (array of shape [] and dtype int32), and returns a boolean array of shape [].
 Return type
optax.GradientTransformationExtraArgs
 Returns
A new GradientTransformationExtraArgs.
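The pass-through behaviour can be sketched in plain Python. This is an illustration only: the state is modelled as an (inner_state, step) tuple, and `maybe_update_sketch` and `negate` are hypothetical names.

```python
# Plain-Python sketch: call the inner update only on selected steps.

def maybe_update_sketch(inner_update, should_update_fn):
    def update_fn(updates, state):
        inner_state, step = state
        if should_update_fn(step):
            updates, inner_state = inner_update(updates, inner_state)
        # Otherwise updates and inner state pass through untouched.
        # The step counter increases in both cases.
        return updates, (inner_state, step + 1)
    return update_fn


def negate(updates, inner_state):
    """Stand-in inner transform: flips the sign and counts its own calls."""
    return [-u for u in updates], inner_state + 1


update_fn = maybe_update_sketch(negate, lambda step: step % 2 == 0)
state = (0, 0)
outputs = []
for _ in range(4):
    out, state = update_fn([1.0], state)
    outputs.append(out[0])
```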
Multistep Update#
 class optax.MultiSteps(opt, every_k_schedule, use_grad_mean=True, should_skip_update_fn=None)[source]#
An optimizer wrapper to accumulate gradients over multiple steps.
This wrapper collects together the updates passed to its update function over consecutive steps until a given number of scheduled steps is reached. In each of these intermediate steps, the returned value from the optimizer is a tree of zeros of the same shape of the updates passed as input.
Once the scheduled number of intermediate ‘mini-steps’ has been reached, the gradients accumulated to the current time will be passed to the wrapped optimizer’s update function (with the inner optimizer’s state being updated appropriately) and then returned to the caller. The wrapper’s accumulated gradients are then set back to zero and the process starts again.
The number of mini-steps per gradient update is controlled by a function, and it can vary over training. This offers a means of varying batch size over training.
 __init__(opt, every_k_schedule, use_grad_mean=True, should_skip_update_fn=None)[source]#
Initialiser.
 Parameters
opt (optax.GradientTransformation) – the wrapped optimizer.
every_k_schedule (Union[int, Callable[[Array], Array]]) – an int or a function.
* As a function, it returns how many mini-steps should be accumulated in a single gradient step. Its only argument is the current gradient step count. By varying the returned value, users can vary the overall training batch size.
* If an int, this is the constant number of mini-steps per gradient update.
use_grad_mean (bool) – if True (the default), gradients accumulated over multiple mini-steps are averaged. Otherwise, they are summed.
should_skip_update_fn (Optional[ShouldSkipUpdateFunction]) – if provided, this function is used to decide when to accept or reject the updates from a mini-step. When a mini-step is rejected, the inner state of MultiSteps is not updated. In other words, it is as if this mini-step never happened. For example:
* to ignore updates containing inf or NaN, do should_skip_update_fn=skip_not_finite;
* to ignore updates with a norm square larger than 42, do `should_skip_update_fn=functools.partial(skip_large_updates, max_norm_sq=42.)`.
Note that the optimizer’s state MultiStepsState contains a field skip_state in which debugging and monitoring information returned by should_skip_update_fn is written.
 update(updates, state, params=None, **extra_args)[source]#
Accumulates gradients and proposes nonzero updates every k_steps.
 Return type
Tuple[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], MultiStepsState]
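The accumulation schedule can be sketched on scalar gradients in plain Python. This is an illustration under simplifying assumptions: k is a constant int, no updates are skipped, and the wrapped optimizer is the identity, so the emitted value is just the accumulated gradient. `multi_steps_sketch` is a hypothetical name.

```python
# Plain-Python sketch of gradient accumulation over k mini-steps.

def multi_steps_sketch(grad_stream, k, use_grad_mean=True):
    """Return the update emitted at every mini-step for a stream of gradients."""
    accumulated, mini_step, emitted = 0.0, 0, []
    for grad in grad_stream:
        accumulated += grad
        mini_step += 1
        if mini_step == k:
            # Enough mini-steps collected: emit the mean (or sum) and reset.
            emitted.append(accumulated / k if use_grad_mean else accumulated)
            accumulated, mini_step = 0.0, 0
        else:
            # Intermediate mini-step: emit zeros so parameters do not move.
            emitted.append(0.0)
    return emitted


mean_updates = multi_steps_sketch([1.0, 3.0, 2.0, 6.0], k=2)
sum_updates = multi_steps_sketch([1.0, 3.0, 2.0, 6.0], k=2, use_grad_mean=False)
```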
 class optax.MultiStepsState(mini_step: Array, gradient_step: Array, inner_opt_state: Any, acc_grads: Any, skip_state: chex.ArrayTree = ())[source]#
State of the GradientTransformation returned by MultiSteps.
 Fields:
mini_step: current mini-step counter. At an update, this either increases by 1 or is reset to 0.
gradient_step: gradient step counter. This only increases after enough mini-steps have been accumulated.
inner_opt_state: the state of the wrapped optimiser.
acc_grads: accumulated gradients over multiple mini-steps.
skip_state: an arbitrarily nested tree of arrays. This is only relevant when passing a should_skip_update_fn to MultiSteps. This structure will then contain values for debugging and/or monitoring. The actual structure will vary depending on the choice of ShouldSkipUpdateFunction.
 mini_step: Array#
Alias for field number 0
 gradient_step: Array#
Alias for field number 1
 inner_opt_state: Any#
Alias for field number 2
 acc_grads: Any#
Alias for field number 3
 skip_state: chex.ArrayTree#
Alias for field number 4
Common Losses#

Computes a convex version of the Kullback-Leibler divergence loss. 

Computes the cosine distance between targets and predictions. 

Computes the cosine similarity between targets and predictions. 

Computes CTC loss. 

Computes CTC loss and CTC forward-probabilities. 

Computes the hinge loss for binary classification. 

Huber loss, similar to L2 loss close to zero, L1 loss away from zero. 

Computes the Kullback-Leibler divergence (relative entropy) loss. 

Calculates the L2 loss for a set of predictions. 

Calculates the log-cosh loss for a set of predictions. 

Computes element-wise sigmoid cross entropy given logits and labels. 

Apply label smoothing. 

Computes the softmax cross entropy between sets of logits and labels. 

Computes softmax cross entropy between sets of logits and integer labels. 

Calculates the squared error for a set of predictions. 
Losses#
 optax.convex_kl_divergence(log_predictions, targets)[source]#
Computes a convex version of the Kullback-Leibler divergence loss.
Measures the information gain achieved if target probability distribution would be used instead of predicted probability distribution. This version is jointly convex in p (targets) and q (log_predictions).
References
[Kullback, Leibler, 1951](https://www.jstor.org/stable/2236703)
 Parameters
log_predictions (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.
targets (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – Probabilities of target distribution with shape […, dim]. Expected to be strictly positive.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
Kullback-Leibler divergence of predicted distribution from target distribution with shape […].
 optax.cosine_distance(predictions, targets, epsilon=0.0)[source]#
Computes the cosine distance between targets and predictions.
The cosine distance, implemented here, measures the dissimilarity of two vectors as the opposite of cosine similarity: 1 - cos(theta).
References
[Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)
 Parameters
predictions (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – The predicted vectors, with shape […, dim].
targets (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – Ground truth target vectors, with shape […, dim].
epsilon (float) – minimum norm for terms in the denominator of the cosine similarity.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
cosine distances, with shape […].
 optax.cosine_similarity(predictions, targets, epsilon=0.0)[source]#
Computes the cosine similarity between targets and predictions.
The cosine similarity is a measure of similarity between vectors defined as the cosine of the angle between them, which is also the inner product of those vectors normalized to have unit norm.
References
[Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)
 Parameters
predictions (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – The predicted vectors, with shape […, dim].
targets (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – Ground truth target vectors, with shape […, dim].
epsilon (float) – minimum norm for terms in the denominator of the cosine similarity.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
cosine similarity measures, with shape […].
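The two cosine measures can be sketched for a single pair of vectors in plain Python. Treating epsilon as a lower bound on each norm is an assumption based on the parameter description; the function names are illustrative.

```python
import math

# Plain-Python sketch of cosine similarity and cosine distance for one vector pair.

def cosine_similarity_sketch(predictions, targets, epsilon=0.0):
    dot = sum(p * t for p, t in zip(predictions, targets))
    # Assumption: epsilon clamps each norm from below. With epsilon=0.0 a
    # zero vector would still cause a division by zero.
    pred_norm = max(math.sqrt(sum(p * p for p in predictions)), epsilon)
    target_norm = max(math.sqrt(sum(t * t for t in targets)), epsilon)
    return dot / (pred_norm * target_norm)


def cosine_distance_sketch(predictions, targets, epsilon=0.0):
    # 1 - cos(theta)
    return 1.0 - cosine_similarity_sketch(predictions, targets, epsilon)


orthogonal = cosine_similarity_sketch([1.0, 0.0], [0.0, 1.0])
parallel = cosine_similarity_sketch([3.0, 4.0], [6.0, 8.0])
opposite_distance = cosine_distance_sketch([1.0, 0.0], [-1.0, 0.0])
```

The similarity ignores vector length, so [3, 4] and [6, 8] count as identical directions, while opposite vectors reach the maximum distance of 2.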
 optax.ctc_loss(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)[source]#
Computes CTC loss.
See docstring for ctc_loss_with_forward_probs for details.
 Parameters
logits (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – (B, T, K)-array containing logits of each class where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.
logit_paddings (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.
labels (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – (B, N)-array containing reference integer labels where N denotes the max time frames in the label sequence.
label_paddings (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a repetition of zeroes, followed by a repetition of ones.
blank_id (int) – Id for blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.
log_epsilon (float) – Numerically-stable approximation of log(+0).
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
(B,)-array containing loss values for each sequence in the batch.
 optax.ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)[source]#
Computes CTC loss and CTC forward-probabilities.
The CTC loss is a loss function based on log-likelihoods of the model that introduces a special blank symbol \(\phi\) to represent variable-length output sequences.
Forward probabilities returned by this function, as auxiliary results, are grouped into two parts: blank alpha-probability and non-blank alpha-probability. Those are defined as follows:
\[\alpha_{\mathrm{BLANK}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ \alpha_{\mathrm{LABEL}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). \]
Here, \(\pi\) denotes the alignment sequence in the reference [Graves et al, 2006], that is, the blank-inserted representation of labels. The return values are the logarithms of the above probabilities.
References
[Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891)
 Parameters
logits (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – (B, T, K)-array containing logits of each class where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.
logit_paddings (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.
labels (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – (B, N)-array containing reference integer labels where N denotes the max time frames in the label sequence.
label_paddings (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a repetition of zeroes, followed by a repetition of ones.
blank_id (int) – Id for blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.
log_epsilon (float) – Numerically-stable approximation of log(+0).
 Return type
Tuple[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]]
 Returns
A tuple (loss_value, logalpha_blank, logalpha_nonblank). Here, loss_value is a (B,)-array containing the loss values for each sequence in the batch; logalpha_blank and logalpha_nonblank are (T, B, N+1)-arrays where the (t, b, n)-th element denotes log alpha_B(t, n) and log alpha_L(t, n), respectively, for the b-th sequence in the batch.
 optax.hinge_loss(predictor_outputs, targets)[source]#
Computes the hinge loss for binary classification.
 Parameters
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
Binary Hinge Loss.
 optax.huber_loss(predictions, targets=None, delta=1.0)[source]#
Huber loss, similar to L2 loss close to zero, L1 loss away from zero.
If gradient descent is applied to the huber loss, it is equivalent to clipping gradients of an l2_loss to [-delta, delta] in the backward pass.
References
[Huber, 1964](www.projecteuclid.org/download/pdf_1/euclid.aoms/1177703732)
 Parameters
predictions (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – a vector of arbitrary shape […].
targets (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.
delta (float) – the bounds for the huber loss transformation, defaults at 1.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
elementwise huber losses, with the same shape of predictions.
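The gradient-clipping interpretation above can be checked numerically with a scalar plain-Python sketch. The quadratic-plus-linear decomposition used here is one common way to write the Huber loss and is an assumption about the form, not a copy of the library code.

```python
# Plain-Python sketch of the Huber loss on a scalar error.

def huber_sketch(prediction, target=0.0, delta=1.0):
    abs_error = abs(prediction - target)
    quadratic = min(abs_error, delta)   # part of the error treated like L2
    linear = abs_error - quadratic      # remainder treated like L1
    return 0.5 * quadratic ** 2 + delta * linear


def numerical_slope(f, x, h=1e-3):
    """Central finite difference, used only to inspect the gradient."""
    return (f(x + h) - f(x - h)) / (2 * h)


inside = huber_sketch(0.5)                          # quadratic region: 0.5 * 0.5**2
outside = huber_sketch(3.0)                         # linear region: 0.5 + 1.0 * 2.0
slope_inside = numerical_slope(huber_sketch, 0.5)   # equals the error itself
slope_outside = numerical_slope(huber_sketch, 3.0)  # clipped to delta
slope_negative = numerical_slope(huber_sketch, -3.0)
```

Inside the [-delta, delta] band the slope equals the error, as for an L2 loss with the 0.5 factor; outside it the slope saturates at ±delta.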
 optax.kl_divergence(log_predictions, targets)[source]#
Computes the Kullback-Leibler divergence (relative entropy) loss.
Measures the information gain achieved if target probability distribution would be used instead of predicted probability distribution.
References
[Kullback, Leibler, 1951](https://www.jstor.org/stable/2236703)
 Parameters
log_predictions (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.
targets (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – Probabilities of target distribution with shape […, dim]. Expected to be strictly positive.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
Kullback-Leibler divergence of predicted distribution from target distribution with shape […].
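The argument convention (predictions in log-space, targets as plain probabilities) is easy to get wrong, so here is a plain-Python sketch of the standard definition sum(p * (log p - log q)) for one distribution. It assumes strictly positive targets, as the parameter description requires; the function name is illustrative.

```python
import math

# Plain-Python sketch of KL(targets || predictions) over the last dimension.

def kl_divergence_sketch(log_predictions, targets):
    # Targets are probabilities; predictions are already log-probabilities.
    return sum(t * (math.log(t) - log_q)
               for log_q, t in zip(log_predictions, targets))


targets = [0.5, 0.5]
matching = kl_divergence_sketch([math.log(0.5), math.log(0.5)], targets)
mismatched = kl_divergence_sketch([math.log(0.25), math.log(0.75)], targets)
```

The divergence is zero when the two distributions agree and positive otherwise; for the mismatched pair it equals 0.5 * log(4/3).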
 optax.l2_loss(predictions, targets=None)[source]#
Calculates the L2 loss for a set of predictions.
Note: the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not “The Elements of Statistical Learning” by Tibshirani.
References
[Chris Bishop, 2006](https://bit.ly/3eeP0ga)
 Parameters
predictions (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – a vector of arbitrary shape […].
targets (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
elementwise squared differences, with same shape as predictions.
 optax.log_cosh(predictions, targets=None)[source]#
Calculates the log-cosh loss for a set of predictions.
log(cosh(x)) is approximately (x**2) / 2 for small x and abs(x) - log(2) for large x. It is a twice differentiable alternative to the Huber loss.
References
[Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym)
 Parameters
predictions (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – a vector of arbitrary shape […].
targets (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
the log-cosh loss, with same shape as predictions.
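The two approximations quoted above can be checked with a scalar plain-Python sketch. The rewritten form x + log(1 + exp(-2x)) - log(2) is a standard identity for log(cosh(x)) with x >= 0, chosen here to avoid overflow; whether the library uses the same form is not assumed.

```python
import math

# Plain-Python sketch of a numerically stable log-cosh loss on a scalar error.

def log_cosh_sketch(prediction, target=0.0):
    x = abs(prediction - target)
    # log(cosh(x)) = x + log(1 + exp(-2x)) - log(2); exp never overflows here.
    return x + math.log1p(math.exp(-2.0 * x)) - math.log(2.0)


small = log_cosh_sketch(0.01)     # close to 0.01**2 / 2
large = log_cosh_sketch(50.0)     # close to 50 - log(2)
huge = log_cosh_sketch(1000.0)    # math.cosh(1000.0) itself would overflow
```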
 optax.sigmoid_binary_cross_entropy(logits, labels)[source]#
Computes elementwise sigmoid cross entropy given logits and labels.
This function can be used for binary or multiclass classification, where each class is an independent binary prediction and different classes are not mutually exclusive (e.g. predicting that an image contains both a cat and a dog).
Because this function is overloaded, please ensure your logits and labels are compatible with each other. If you’re passing in binary labels (values in {0, 1}), ensure your logits correspond to class 1 only. If you’re passing in per-class target probabilities or one-hot labels, please ensure your logits are also multiclass. Be particularly careful if you’re relying on implicit broadcasting to reshape logits or labels.
References
[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)
 Parameters
logits – Each element is the unnormalized log probability of a binary prediction. See note about compatibility with labels above.
labels – Binary labels whose values are {0,1} or multiclass target probabilities. See note about compatibility with logits above.
 Returns
cross entropy for each binary prediction, same shape as logits.
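The elementwise loss can be sketched for scalars as follows. This is an illustration only, assuming the standard numerically stable rewrite max(x, 0) - x * z + log(1 + exp(-|x|)) of the binary cross entropy; `sigmoid_bce_sketch` is a hypothetical helper and not the library's code.

```python
import math

def sigmoid_bce_sketch(logit, label):
    """Stable form of -label*log(sigmoid(logit)) - (1-label)*log(1-sigmoid(logit))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

# Each logit is an independent binary prediction, so the loss is elementwise.
logits = [2.0, -1.0, 0.0]
labels = [1.0, 0.0, 1.0]
losses = [sigmoid_bce_sketch(x, z) for x, z in zip(logits, labels)]
```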
 optax.smooth_labels(labels, alpha)[source]#
Apply label smoothing.
Label smoothing is often used in combination with a cross-entropy loss. Smoothed labels favour small logit gaps, and it has been shown that this can provide better model calibration by preventing overconfident predictions.
References
[Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf)
 Parameters
labels (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – one-hot labels to be smoothed.
alpha (float) – the smoothing factor; the greedy category will be assigned probability (1 - alpha) + alpha / num_categories.
 Return type
 Returns
a smoothed version of the one hot input labels.
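The smoothing rule stated for alpha can be written out directly. This is a minimal sketch on a single one-hot vector held in a Python list; `smooth_labels_sketch` is a hypothetical name, and spreading alpha uniformly over all categories is inferred from the probability given for the greedy category.

```python
def smooth_labels_sketch(labels, alpha):
    """Mix one-hot labels with a uniform distribution (illustrative only)."""
    num_categories = len(labels)
    return [(1.0 - alpha) * y + alpha / num_categories for y in labels]

smoothed = smooth_labels_sketch([0.0, 1.0, 0.0, 0.0], alpha=0.1)
# The greedy category receives (1 - alpha) + alpha / num_categories,
# and every other category receives alpha / num_categories.
```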
 optax.softmax_cross_entropy(logits, labels)[source]#
Computes the softmax cross entropy between sets of logits and labels.
Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.
References
[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)
 Parameters
logits (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – Unnormalized log probabilities, with shape […, num_classes].
labels (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – Valid probability distributions (non-negative, sum to 1), e.g. a one-hot encoding specifying the correct class for each input; must have a shape broadcastable to […, num_classes].
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
cross entropy between each prediction and the corresponding target distributions, with shape […].
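For a single example, the computation reduces to a log-softmax followed by a weighted sum. The sketch below assumes the usual definition, -sum(labels * log_softmax(logits)), and uses the max-subtraction trick for stability; `softmax_cross_entropy_sketch` is a hypothetical helper, not the library function.

```python
import math

def softmax_cross_entropy_sketch(logits, labels):
    """Cross entropy for one example with num_classes logits (illustrative)."""
    m = max(logits)
    # log of the softmax normaliser, shifted by the max logit for stability
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    log_probs = [v - log_z for v in logits]
    return -sum(p * lp for p, lp in zip(labels, log_probs))

loss_onehot = softmax_cross_entropy_sketch([1.0, 2.0, 3.0], [0.0, 0.0, 1.0])
loss_flat = softmax_cross_entropy_sketch([0.0, 0.0, 0.0], [0.0, 1.0, 0.0])
```

With all-equal logits the predicted distribution is uniform, so the loss equals log(num_classes).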
 optax.softmax_cross_entropy_with_integer_labels(logits, labels)[source]#
Computes softmax cross entropy between sets of logits and integer labels.
Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.
References
[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)
 Parameters
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
Cross entropy between each prediction and the corresponding target distributions, with shape […].
 optax.squared_error(predictions, targets=None)[source]#
Calculates the squared error for a set of predictions.
Mean Squared Error can be computed as squared_error(a, b).mean().
Note: l2_loss = 0.5 * squared_error, where the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not “The Elements of Statistical Learning” by Tibshirani.
References
[Chris Bishop, 2006](https://bit.ly/3eeP0ga)
 Parameters
predictions (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – a vector of arbitrary shape […].
targets (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
elementwise squared differences, with same shape as predictions.
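The relationship between squared_error, l2_loss, and the mean squared error can be illustrated on plain Python lists. The helper names below are hypothetical stand-ins and do not call the library.

```python
def squared_error_sketch(predictions, targets=None):
    """Elementwise (prediction - target) ** 2; targets default to zeros."""
    if targets is None:
        targets = [0.0] * len(predictions)
    return [(p - t) ** 2 for p, t in zip(predictions, targets)]

def l2_loss_sketch(predictions, targets=None):
    """l2_loss = 0.5 * squared_error, following the Bishop convention."""
    return [0.5 * e for e in squared_error_sketch(predictions, targets)]

errors = squared_error_sketch([1.0, 2.0, 4.0], [1.0, 0.0, 1.0])
mse = sum(errors) / len(errors)          # the mean of the squared errors
halved = l2_loss_sketch([1.0, 2.0, 4.0], [1.0, 0.0, 1.0])
```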
Linear Algebra Operators#

Computes matrix^(-1/p), where p is a positive integer. 



Power iteration algorithm. 
matrix_inverse_pth_root#
 optax.matrix_inverse_pth_root(matrix, p, num_iters=100, ridge_epsilon=1e-06, error_tolerance=1e-06, precision=<Precision.HIGHEST: 2>)[source]#
Computes matrix^(-1/p), where p is a positive integer.
This function uses the coupled Newton iteration algorithm to compute a matrix’s inverse p-th root.
References
[Functions of Matrices, Theory and Computation, Nicholas J Higham, Pg 184, Eq 7.18](https://epubs.siam.org/doi/book/10.1137/1.9780898717778)
 Parameters
matrix (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – the symmetric PSD matrix whose power is to be computed.
p (int) – exponent, for p a positive integer.
num_iters (int) – Maximum number of iterations.
ridge_epsilon (float) – Ridge epsilon added to make the matrix positive definite.
error_tolerance (float) – Error indicator, useful for early termination.
precision (Precision) – precision XLA related flag, the available options are: a) lax.Precision.DEFAULT (better step time, but not precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest).
 Returns
matrix^(-1/p)
Utilities for numerical stability#

Increments int32 counter by one. 

Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients. 

Returns maximum(sqrt(mean(abs_sq(x))), min_norm) with correct grads. 
Numerics#
 optax.safe_int32_increment(count)[source]#
Increments int32 counter by one.
Normally max_int + 1 would overflow to min_int. This function ensures that when max_int is reached the counter stays at max_int.
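The saturating behaviour can be sketched with ordinary Python integers. Python integers do not overflow, so this only models the int32 semantics; `safe_int32_increment_sketch` and `INT32_MAX` are names chosen for this illustration.

```python
INT32_MAX = 2 ** 31 - 1  # largest value representable as a signed 32-bit integer

def safe_int32_increment_sketch(count):
    """Increment by one, but stay at INT32_MAX instead of wrapping around."""
    return count + 1 if count < INT32_MAX else INT32_MAX

next_count = safe_int32_increment_sketch(5)
saturated = safe_int32_increment_sketch(INT32_MAX)
```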
 optax.safe_norm(x, min_norm, ord=None, axis=None, keepdims=False)[source]#
Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.
The gradient of jnp.maximum(jnp.linalg.norm(x), min_norm) at 0.0 is NaN, because JAX evaluates both branches of jnp.maximum. This function instead returns the correct gradient of 0.0 in that setting.
 Parameters
x (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – jax array.
min_norm (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]) – lower bound for the returned norm.
ord (Union[int, float, str, None]) – {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional. Order of the norm. inf means numpy’s inf object. The default is None.
axis (Union[None, Tuple[int, …], int]) – {None, int, 2-tuple of ints}, optional. If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The default is None.
keepdims (bool) – bool, optional. If this is set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original x.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
The safe norm of the input vector, accounting for correct gradient.
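The intended value and gradient can be illustrated without automatic differentiation. The sketch below covers the Euclidean norm of a flat vector and writes the gradient by hand; it is an assumption-laden illustration of the behaviour described above, not the library's implementation, and `safe_norm_and_grad_sketch` is a hypothetical name.

```python
import math

def safe_norm_and_grad_sketch(x, min_norm):
    """Return (max(||x||, min_norm), gradient with respect to x)."""
    norm = math.sqrt(sum(v * v for v in x))
    if norm <= min_norm:
        # The lower bound is active: the output is constant, so the gradient is 0.
        return min_norm, [0.0] * len(x)
    # Otherwise the gradient of the Euclidean norm is x / ||x||.
    return norm, [v / norm for v in x]

value_at_zero, grad_at_zero = safe_norm_and_grad_sketch([0.0, 0.0], 1e-3)
value, grad = safe_norm_and_grad_sketch([3.0, 4.0], 1e-3)
```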
 optax.safe_root_mean_squares(x, min_rms)[source]#
Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct grads.
The gradient of maximum(sqrt(mean(abs_sq(x))), min_rms) at 0.0 is NaN, because JAX evaluates both branches of jnp.maximum. This function instead returns the correct gradient of 0.0 in that setting.
 Parameters
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]
 Returns
The safe RMS of the input vector, accounting for correct gradient.
power_iteration#
 optax.power_iteration(matrix, num_iters=100, error_tolerance=1e-06, precision=<Precision.HIGHEST: 2>)[source]#
Power iteration algorithm.
The power iteration algorithm takes a symmetric PSD matrix A, and produces a scalar lambda, which is the greatest (in absolute value) eigenvalue of A, and a vector v, which is the corresponding eigenvector of A.
References
[Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)
 Parameters
matrix (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]) – the symmetric PSD matrix.
num_iters (int) – Number of iterations.
error_tolerance (float) – Iterative exit condition.
precision (Precision) – precision XLA related flag, the available options are: a) lax.Precision.DEFAULT (better step time, but not precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest).
 Returns
eigenvector, eigenvalue
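The algorithm itself is short enough to sketch on nested Python lists. This version is an illustration under stated assumptions: it starts from a fixed uniform vector (implementations commonly use a random start), it stops when the Rayleigh quotient changes by less than error_tolerance, and `power_iteration_sketch` is a hypothetical name.

```python
import math

def power_iteration_sketch(matrix, num_iters=100, error_tolerance=1e-6):
    """Return (eigenvector, eigenvalue) for the dominant eigenpair."""
    n = len(matrix)
    v = [1.0 / math.sqrt(n)] * n          # unit-norm start vector (assumption)
    eigenvalue = 0.0
    for _ in range(num_iters):
        mv = [sum(matrix[i][j] * v[j] for j in range(n)) for i in range(n)]
        new_eigenvalue = sum(v[i] * mv[i] for i in range(n))  # Rayleigh quotient
        norm = math.sqrt(sum(c * c for c in mv))
        if norm == 0.0:                   # matrix maps v to zero; nothing to refine
            return v, 0.0
        v = [c / norm for c in mv]
        converged = abs(new_eigenvalue - eigenvalue) < error_tolerance
        eigenvalue = new_eigenvalue
        if converged:
            break
    return v, eigenvalue

vector, value = power_iteration_sketch([[2.0, 0.0], [0.0, 1.0]])
```

For this diagonal matrix the iterate aligns with the first axis and the estimate approaches the largest eigenvalue, 2.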
Optimizer Schedules#

Constructs a constant schedule. 

Returns a function which implements cosine learning rate decay. 

Returns a function which implements the one-cycle learning rate schedule. 

Constructs a schedule with either continuous or discrete exponential decay. 

Sequentially apply multiple schedules. 

Returns a function which implements the one-cycle learning rate schedule. 



Returns a function which implements a piecewise constant schedule. 

Returns a function which implements a piecewise interpolated schedule. 

Constructs a schedule with polynomial transition from init to end value. 

SGD with warm restarts, from Loshchilov & Hutter (arXiv:1608.03983). 

Linear warmup followed by cosine decay. 

Linear warmup followed by exponential decay. 
The central part of internal API. 


Maintains inner transform state, hyperparameters, and step count. 

Wrapper that injects hyperparameters into the inner GradientTransformation. 
Schedules#
 optax.constant_schedule(value)[source]#
Constructs a constant schedule.
 Parameters
value (Union[float, int]) – value to be held constant throughout.
 Returns
A function that maps step counts to values.
 Return type
schedule
 optax.cosine_decay_schedule(init_value, decay_steps, alpha=0.0, exponent=1.0)[source]#
Returns a function which implements cosine learning rate decay.
The schedule does not restart when decay_steps has been reached. Instead, the learning rate remains constant afterwards. For a cosine schedule with restarts, optax.join_schedules() can be used to join several cosine decay schedules.
For more details see: https://arxiv.org/abs/1608.03983.
 Parameters
init_value (float) – An initial value init_v.
decay_steps (int) – Positive integer, the number of steps over which to apply the decay.
alpha (float) – Float. The minimum value of the multiplier used to adjust the learning rate.
exponent (float) – Float. The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is the decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.
 Returns
A function that maps step counts to values.
 Return type
schedule
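The decay formula given for exponent can be turned into a small standalone schedule. This is a sketch rather than the library's code: the way alpha enters, as (1 - alpha) * decay + alpha, is an assumption consistent with alpha being the minimum multiplier, and `cosine_decay_sketch` is a hypothetical name.

```python
import math

def cosine_decay_sketch(init_value, decay_steps, alpha=0.0, exponent=1.0):
    """Build a step -> value function following 0.5 * (1 + cos(pi * t / T))."""
    def schedule(count):
        t = min(count, decay_steps)  # no restart: hold the final value afterwards
        cosine = 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
        multiplier = (1.0 - alpha) * cosine ** exponent + alpha  # assumed form
        return init_value * multiplier
    return schedule

schedule = cosine_decay_sketch(init_value=1.0, decay_steps=100, alpha=0.1)
values = [schedule(step) for step in (0, 50, 100, 500)]
```

The value starts at init_value, passes the midpoint between the two extremes halfway through, and stays at alpha * init_value once decay_steps is reached.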
 optax.cosine_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, div_factor=25.0, final_div_factor=10000.0)[source]#
Returns a function which implements the one-cycle learning rate schedule.
This function uses a cosine annealing strategy. For more details see: https://arxiv.org/abs/1708.07120
 Parameters
transition_steps (int) – Number of steps over which annealing takes place.
peak_value (float) – Maximum value attained by schedule at pct_start percent of the cycle (in number of steps).
pct_start (float) – The percentage of the cycle (in number of steps) spent increasing the learning rate.
div_factor (float) – Determines the initial value via init_value = peak_value / div_factor.
final_div_factor (float) – Determines the final value via final_value = init_value / final_div_factor.
 Returns
A function that maps step counts to values.
 Return type
schedule
 optax.exponential_decay(init_value, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None)[source]#
Constructs a schedule with either continuous or discrete exponential decay.
This function applies an exponential decay function to a provided initial value. The function returns the decayed value as follows:
` decayed_value = init_value * decay_rate ^ (count / transition_steps) `
If the argument staircase is True, then count / transition_steps is an integer division and the decayed value follows a staircase function.
 Parameters
init_value (float) – the initial learning rate.
transition_steps (int) – must be positive. See the decay computation above.
decay_rate (float) – must not be zero. The decay rate.
transition_begin (int) – must be positive. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).
staircase (bool) – if True, decay the values at discrete intervals.
end_value (Optional[float]) – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.
 Returns
A function that maps step counts to values.
 Return type
schedule
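The decay rule and its options can be sketched in a few lines of plain Python. This is an illustration, not the library's implementation: measuring count from transition_begin is an assumption based on the parameter description, and `exponential_decay_sketch` is a hypothetical name.

```python
def exponential_decay_sketch(init_value, transition_steps, decay_rate,
                             transition_begin=0, staircase=False, end_value=None):
    """Build a step -> value function for init_value * decay_rate ** p."""
    def schedule(count):
        if count <= transition_begin:
            return init_value                       # held fixed before annealing
        p = (count - transition_begin) / transition_steps
        if staircase:
            p = float(int(p))                       # integer division: discrete steps
        value = init_value * decay_rate ** p
        if end_value is not None:
            # lower bound when decaying, upper bound when growing
            value = max(value, end_value) if decay_rate < 1.0 else min(value, end_value)
        return value
    return schedule

smooth = exponential_decay_sketch(1.0, 10, 0.5, end_value=0.2)
stepped = exponential_decay_sketch(1.0, 10, 0.5, staircase=True)
```

Here `smooth(15)` lies between the values at steps 10 and 20, while `stepped(15)` still equals the value at step 10.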
 optax.join_schedules(schedules, boundaries)[source]#
Sequentially apply multiple schedules.
 Parameters
schedules (Sequence[Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – A list of callables (expected to be optax schedules). Each schedule will receive a step count indicating the number of steps since the previous boundary transition.
boundaries (Sequence[int]) – A list of integers (of length one less than schedules) that indicate when to transition between schedules.
 Returns
A function that maps step counts to values.
 Return type
schedule
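The hand-off between schedules, including the step count being measured from the previous boundary, can be sketched as follows. Treating a step equal to a boundary as already belonging to the next schedule is an assumption of this sketch, and `join_schedules_sketch` is a hypothetical name.

```python
def join_schedules_sketch(schedules, boundaries):
    """Apply schedules[i] between boundaries[i - 1] and boundaries[i]."""
    def schedule(step):
        index, offset = 0, 0
        for i, boundary in enumerate(boundaries):
            if step >= boundary:          # assumed: the boundary step switches over
                index, offset = i + 1, boundary
        # Each schedule sees the number of steps since the previous boundary.
        return schedules[index](step - offset)
    return schedule

warmup = lambda s: 0.1 * s                # linear ramp over the first 10 steps
constant = lambda s: 1.0
joined = join_schedules_sketch([warmup, constant], boundaries=[10])
```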
 optax.linear_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, pct_final=0.85, div_factor=25.0, final_div_factor=10000.0)[source]#
Returns a function which implements the one-cycle learning rate schedule.
This function uses a linear annealing strategy. For more details see: https://arxiv.org/abs/1708.07120
 Parameters
transition_steps (int) – Number of steps over which annealing takes place.
peak_value (float) – Maximum value attained by schedule at pct_start percent of the cycle (in number of steps).
pct_start (float) – The percentage of the cycle (in number of steps) spent increasing the learning rate.
pct_final (float) – The percentage of the cycle (in number of steps) spent increasing to peak_value then decreasing back to init_value.
div_factor (float) – Determines the initial value via init_value = peak_value / div_factor.
final_div_factor (float) – Determines the final value via final_value = init_value / final_div_factor.
 Returns
A function that maps step counts to values.
 Return type
schedule
 optax.piecewise_constant_schedule(init_value, boundaries_and_scales=None)[source]#
Returns a function which implements a piecewise constant schedule.
 Parameters
init_value (float) – An initial value init_v.
boundaries_and_scales (Optional[Dict[int, float]]) – A map from boundaries b_i to non-negative scaling factors f_i. For any step count s, the schedule returns init_v scaled by the product of all factors f_i such that b_i < s.
 Returns
A function that maps step counts to values.
 Return type
schedule
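The rule for boundaries_and_scales translates almost word for word into code. The following sketch is a plain-Python illustration with a hypothetical helper name, `piecewise_constant_sketch`.

```python
def piecewise_constant_sketch(init_value, boundaries_and_scales=None):
    """Scale init_value by every factor f_i whose boundary b_i is below the step."""
    boundaries_and_scales = boundaries_and_scales or {}
    def schedule(step):
        value = init_value
        for boundary, scale in boundaries_and_scales.items():
            if boundary < step:           # strict inequality, as in b_i < s
                value *= scale
        return value
    return schedule

schedule = piecewise_constant_sketch(1.0, {100: 0.1, 200: 0.5})
```

Because the inequality is strict, the first scaling takes effect at step 101 rather than at step 100.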
 optax.piecewise_interpolate_schedule(interpolate_type, init_value, boundaries_and_scales=None)[source]#
Returns a function which implements a piecewise interpolated schedule.
 Parameters
interpolate_type (str) – ‘linear’ or ‘cosine’, specifying the interpolation strategy.
init_value (float) – An initial value init_v.
boundaries_and_scales (Optional[Dict[int, float]]) – A map from boundaries b_i to non-negative scaling factors f_i. At boundary step b_i, the schedule returns init_v scaled by the product of all factors f_j such that b_j <= b_i. The values in between each boundary will be interpolated as per type.
 Returns
A function that maps step counts to values.
 Return type
schedule
 optax.polynomial_schedule(init_value, end_value, power, transition_steps, transition_begin=0)[source]#
Constructs a schedule with polynomial transition from init to end value.
 Parameters
init_value (Union[float, int]) – initial value for the scalar to be annealed.
end_value (Union[float, int]) – end value of the scalar to be annealed.
power (Union[float, int]) – the power of the polynomial used to transition from init to end.
transition_steps (int) – number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.
transition_begin (int) – must be positive. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).
 Returns
A function that maps step counts to values.
 Return type
schedule
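The timing rules for transition_begin and transition_steps can be sketched as below. The interpolation formula, (init_value - end_value) * (1 - t / transition_steps) ** power + end_value, is an assumption of this sketch since the text above does not spell it out, and `polynomial_schedule_sketch` is a hypothetical name.

```python
def polynomial_schedule_sketch(init_value, end_value, power,
                               transition_steps, transition_begin=0):
    """Build a step -> value function with a polynomial transition."""
    def schedule(count):
        if transition_steps <= 0:
            return init_value                       # annealing disabled
        # Steps elapsed since the transition began, clipped to its length.
        t = min(max(count - transition_begin, 0), transition_steps)
        frac = 1.0 - t / transition_steps
        return (init_value - end_value) * frac ** power + end_value  # assumed form
    return schedule

linear = polynomial_schedule_sketch(1.0, 0.0, power=1,
                                    transition_steps=10, transition_begin=5)
```

With power=1 the transition is a straight line from init_value to end_value between steps 5 and 15.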
 optax.sgdr_schedule(cosine_kwargs)[source]#
SGD with warm restarts, from Loshchilov & Hutter (arXiv:1608.03983).
This learning rate schedule applies multiple joined cosine decay cycles. For more details see: https://arxiv.org/abs/1608.03983
 Parameters
cosine_kwargs (Iterable[Dict[str, Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, float, int]]]) – An Iterable of dicts, where each element specifies the arguments to pass to each cosine decay cycle. The decay_steps kwarg will specify how long each cycle lasts for, and therefore when to transition to the next cycle.
 Returns
A function that maps step counts to values.
 Return type
schedule
 optax.warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value=0.0, exponent=1.0)[source]#
Linear warmup followed by cosine decay.
 Parameters
init_value (float) – Initial value for the scalar to be annealed.
peak_value (float) – Peak value for scalar to be annealed at end of warmup.
warmup_steps (int) – Positive integer, the length of the linear warmup.
decay_steps (int) – Positive integer, the total length of the schedule. Note that this includes the warmup time, so the number of steps during which cosine annealing is applied is decay_steps - warmup_steps.
end_value (float) – End value of the scalar to be annealed.
exponent (float) – Float. The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is the decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.
 Returns
A function that maps step counts to values.
 Return type
schedule
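How the warmup and the cosine phase share decay_steps can be sketched as follows. This is an illustration under assumptions: the cosine phase is taken to run over decay_steps - warmup_steps steps and to interpolate from peak_value to end_value; `warmup_cosine_sketch` is a hypothetical name, not the library function.

```python
import math

def warmup_cosine_sketch(init_value, peak_value, warmup_steps, decay_steps,
                         end_value=0.0, exponent=1.0):
    """Linear ramp to peak_value, then cosine decay to end_value."""
    cosine_steps = decay_steps - warmup_steps   # decay_steps includes the warmup
    def schedule(count):
        if count < warmup_steps:
            return init_value + (peak_value - init_value) * count / warmup_steps
        t = min(count - warmup_steps, cosine_steps)
        cosine = (0.5 * (1.0 + math.cos(math.pi * t / cosine_steps))) ** exponent
        return end_value + (peak_value - end_value) * cosine
    return schedule

schedule = warmup_cosine_sketch(0.0, 1.0, warmup_steps=10, decay_steps=110,
                                end_value=0.1)
values = [schedule(step) for step in (0, 5, 10, 60, 110, 1000)]
```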
 optax.warmup_exponential_decay_schedule(init_value, peak_value, warmup_steps, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None)[source]#
Linear warmup followed by exponential decay.
 Parameters
init_value (float) – Initial value for the scalar to be annealed.
peak_value (float) – Peak value for scalar to be annealed at end of warmup.
warmup_steps (int) – Positive integer, the length of the linear warmup.
transition_steps (int) – must be positive. See exponential_decay for more details.
decay_rate (float) – must not be zero. The decay rate.
transition_begin (int) – must be positive. After how many steps to start annealing (before this many steps the scalar value is held fixed at peak_value).
staircase (bool) – if True, decay the values at discrete intervals.
end_value (Optional[float]) – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.
 Returns
A function that maps step counts to values.
 Return type
schedule
 optax.inject_hyperparams(inner_factory, static_args=(), hyperparam_dtype=None)[source]#
Wrapper that injects hyperparameters into the inner GradientTransformation.
This wrapper allows you to pass schedules (i.e. functions that return a numeric value given a step count) instead of constants for hyperparameters. You may only schedule numeric hyperparameters (i.e. boolean flags cannot be scheduled).
For example, to use scale_by_adam with a piecewise linear schedule for beta_1 and a constant for beta_2:
    scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)(
        b1=optax.piecewise_interpolate_schedule('linear', ...), b2=0.99)
You may manually change numeric hyperparameters that were not scheduled through the hyperparams dict in the InjectHyperparamsState:
    state = scheduled_adam.init(params)
    updates, state = scheduled_adam.update(grads, state)
    state.hyperparams['b2'] = 0.95
    updates, state = scheduled_adam.update(updates, state)  # uses b2 = 0.95
Manually overriding scheduled hyperparameters will have no effect (e.g. in the code sample above, you cannot manually adjust b1).
 Parameters
inner_factory (Callable[…, optax.GradientTransformation]) – a function that returns the inner optax.GradientTransformation given the hyperparameters.
static_args (Union[str, Iterable[str]]) – a string or iterable of strings specifying which callable parameters are not schedules. inject_hyperparams treats all callables as schedules by default, so if a hyperparameter is a non-schedule callable, you must specify that using this argument.
hyperparam_dtype (Optional[dtype]) – Optional datatype override. If specified, all float hyperparameters will be cast to this type.
 Return type
Callable[…, optax.GradientTransformation]
 Returns
A callable that returns an optax.GradientTransformation. This callable accepts the same arguments as inner_factory, except you may provide schedules in place of the constant arguments.
 optax.Schedule#
alias of Callable[[Union[jax.Array, numpy.ndarray, jax._src.interpreters.batching.BatchTracer, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase, float, int]], Union[jax.Array, numpy.ndarray, jax._src.interpreters.batching.BatchTracer, jaxlib.xla_extension.pmap_lib.ShardedDeviceArrayBase, float, int]]
 class optax.InjectHyperparamsState(count: jnp.ndarray, hyperparams: Dict[str, chex.Numeric], inner_state: base.OptState)[source]#
Maintains inner transform state, hyperparameters, and step count.
 count: jnp.ndarray#
Alias for field number 0
 hyperparams: Dict[str, chex.Numeric]#
Alias for field number 1
 inner_state: base.OptState#
Alias for field number 2
Second Order Optimization Utilities#

Computes the diagonal of the (observed) Fisher information matrix. 

Computes the diagonal hessian of loss at (inputs, targets). 

Performs an efficient vector-Hessian (of loss) product. 
fisher_diag#
hessian_diag#
Control Variates#

The control delta covariate method. 

Obtain jacobians using control variates. 

A moving average baseline. 
control_delta_method#
 optax.control_delta_method(function)[source]#
The control delta covariate method.
 Control variate obtained by performing a second order Taylor expansion
on the cost function f at the mean of the input distribution.
Only implemented for Gaussian random variables.
For details, see: https://icml.cc/2012/papers/687.pdf
 Parameters
function (Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]], float]) – The function for which to compute the control variate. The function takes in one argument (a sample from the distribution) and returns a floating point value.
 Return type
Tuple[Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase], Any], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]], Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Any], Any], Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase], Any], Any]]
 Returns
A tuple of three functions, to compute the control variate, the expected value of the control variate, and to update the control variate state.
control_variates_jacobians#
 optax.control_variates_jacobians(function, control_variate_from_function, grad_estimator, params, dist_builder, rng, num_samples, control_variate_state=None, estimate_cv_coeffs=False, estimate_cv_coeffs_num_samples=20)[source]#
Obtain jacobians using control variates.
We will compute each term individually. The first term will use stochastic gradient estimation. The second term will be computed using Monte Carlo estimation and automatic differentiation to compute nabla_{theta} h(x; theta). The third term will be computed using automatic differentiation, as we restrict ourselves to control variates which compute this expectation in closed form.
This function updates the state of the control variate (once), before computing the control variate coefficients.
 Parameters
function (Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.
control_variate_from_function (Callable[[Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]], float]], Tuple[Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase], Any], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]], Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Any], Any], Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase], Any], Any]]]) – The control variate to use to reduce variance. See control_delta_method and moving_avg_baseline examples.
grad_estimator (Callable[…, array]) – The gradient estimator to be used to compute the gradients. Note that not all control variates will reduce variance for all estimators. For example, the moving_avg_baseline will make no difference to the measure valued or pathwise estimators.
params (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – A tuple of jnp arrays. The parameters for which to construct the distribution and for which we want to compute the jacobians.
dist_builder (Callable[…, Any]) – a constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.
rng (PRNGKeyArray) – a PRNGKey key.
num_samples (int) – Int, the number of samples used to compute the grads.
control_variate_state (Optional[Any]) – The control variate state. This is used for control variates which keep states (such as the moving average baselines).
estimate_cv_coeffs (bool) – Boolean. Whether or not to estimate the optimal control variate coefficient via estimate_control_variate_coefficients.
estimate_cv_coeffs_num_samples (int) – The number of samples to use to estimate the optimal coefficient. These need to be new samples to ensure that the objective is unbiased.
 Returns
A tuple of size params, in which each element is a num_samples x param.shape jacobian vector containing the estimates of the gradients obtained for each sample. The mean of this vector is the gradient with respect to the parameters and can be used for learning. The entire jacobian vector can be used to assess estimator variance.
The updated CV state.
 Return type
A tuple of size two
moving_avg_baseline#
 optax.moving_avg_baseline(function, decay=0.99, zero_debias=True, use_decay_early_training_heuristic=True)[source]#
A moving average baseline.
It has no effect on the pathwise or measure valued estimator.
 Parameters
function (Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]], float]) – The function for which to compute the control variate. The function takes in one argument (a sample from the distribution) and returns a floating point value.
decay (float) – The decay rate for the moving average.
zero_debias (bool) – Whether or not to use zero debiasing for the moving average.
use_decay_early_training_heuristic – Whether or not to use a heuristic which overrides the decay value early in training based on min(decay, (1.0 + i) / (10.0 + i)). This stabilises training and was adapted from the TensorFlow codebase.
 Return type
Tuple[Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase], Any], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]], Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Any], Any], Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]], Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase], Any], Any]]
 Returns
A tuple of three functions, to compute the control variate, the expected value of the control variate, and to update the control variate state.
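The effect of the early-training heuristic on the decay can be seen in a small moving-average sketch. It is an illustration only: zero debiasing is left out, the running average starts at 0.0 by assumption, and `moving_average_sketch` is a hypothetical helper rather than part of the returned control variate functions.

```python
def moving_average_sketch(values, decay=0.99, use_heuristic=True):
    """Exponential moving average with the optional early-training decay override."""
    average = 0.0
    for i, value in enumerate(values):
        # Early in training the effective decay is much smaller than `decay`,
        # so the average follows the first observations closely.
        d = min(decay, (1.0 + i) / (10.0 + i)) if use_heuristic else decay
        average = d * average + (1.0 - d) * value
    return average

fast = moving_average_sketch([1.0, 1.0, 1.0])
slow = moving_average_sketch([1.0, 1.0, 1.0], use_heuristic=False)
```

After three identical observations the heuristic version is already close to the observed value, while the fixed-decay version has barely moved from its starting point.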
Stochastic Gradient Estimators#

Measure valued gradient estimation. 

Pathwise gradient estimation. 

Score function gradient estimation. 
measure_valued_jacobians#
 optax.measure_valued_jacobians(function, params, dist_builder, rng, num_samples, coupling=True)[source]#
Measure valued gradient estimation.
 Approximates:
nabla_{theta} E_{p(x; theta)} f(x)
 With:
1./ c (E_{p1(x; theta)} f(x) - E_{p2(x; theta)} f(x)) where p1 and p2 are measures which depend on p.
Currently only supports computing gradients of expectations of Gaussian RVs.
 Parameters
function (Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.
params (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – A tuple of jnp arrays. The parameters for which to construct the distribution.
dist_builder (Callable[…, Any]) – a constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.
rng (PRNGKeyArray) – a PRNGKey key.
num_samples (int) – Int, the number of samples used to compute the grads.
coupling (bool) – A boolean. Whether or not to use coupling for the positive and negative samples. Recommended: True, as this reduces variance.
 Return type
Sequence[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]]
 Returns
A tuple of size params, in which each element is a num_samples x param.shape jacobian vector containing the estimates of the gradients obtained for each sample. The mean of this vector is the gradient with respect to the parameters and can be used for learning. The entire jacobian vector can be used to assess estimator variance.
pathwise_jacobians#
 optax.pathwise_jacobians(function, params, dist_builder, rng, num_samples)[source]#
Pathwise gradient estimation.
 Approximates:
nabla_{theta} E_{p(x; theta)} f(x)
 With:
 E_{p(epsilon)} nabla_{theta} f(g(epsilon, theta))
where x = g(epsilon, theta). g depends on the distribution p.
Requires: p to be reparametrizable and the reparametrization to be implemented in tensorflow_probability. Applicable to continuous random variables. f needs to be differentiable.
 Parameters
function (Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.
params (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – A tuple of jnp arrays. The parameters for which to construct the distribution.
dist_builder (Callable[…, Any]) – A constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.
rng (PRNGKeyArray) – A PRNGKey key.
num_samples (int) – Int, the number of samples used to compute the grads.
 Return type
Sequence[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]]
 Returns
A tuple of size params, where each element is a num_samples x param.shape jacobian vector containing the estimates of the gradients obtained for each sample. The mean of this vector is the gradient with respect to the parameters and can be used for learning. The entire jacobian vector can be used to assess estimator variance.
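The pathwise estimator can be illustrated without tensorflow_probability for a diagonal Gaussian, where the reparametrization is x = mean + exp(log_scale) * epsilon. The following is a minimal sketch in plain JAX under that assumption. It is not the optax implementation, and the name pathwise_jacobians_sketch and the (mean, log_scale) parametrization are hypothetical choices made for illustration.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Differentiable test function: E[f(x)] = sum(mean**2 + scale**2).
    return jnp.sum(x ** 2)


def pathwise_jacobians_sketch(function, params, rng, num_samples):
    """Per-sample pathwise gradients for a diagonal Gaussian (illustrative only)."""
    mean, log_scale = params
    # Draw the base noise epsilon ~ N(0, I) once per sample.
    eps = jax.random.normal(rng, (num_samples,) + mean.shape)

    def per_sample(e):
        def objective(p):
            m, ls = p
            # Reparametrization x = g(epsilon, theta).
            return function(m + jnp.exp(ls) * e)

        return jax.grad(objective)((mean, log_scale))

    # Returns a tuple with one (num_samples, *param.shape) array per parameter.
    return jax.vmap(per_sample)(eps)


params = (jnp.array([1.0, -2.0]), jnp.zeros(2))
jacobians = pathwise_jacobians_sketch(f, params, jax.random.PRNGKey(0), 5000)
# Averaging over the sample axis gives the gradient estimate used for learning.
grads = [jnp.mean(j, axis=0) for j in jacobians]
```

For this choice of f the exact gradients are 2 * mean with respect to the mean and 2 * scale**2 with respect to log_scale, so the averaged estimates can be checked against a closed form.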
score_function_jacobians#
 optax.score_function_jacobians(function, params, dist_builder, rng, num_samples)[source]#
Score function gradient estimation.
 Approximates:
nabla_{theta} E_{p(x; theta)} f(x)
 With:
E_{p(x; theta)} f(x) nabla_{theta} log p(x; theta)
 Requires: p to be differentiable with respect to theta. Applicable to both continuous and discrete random variables. No requirements on f.
 Parameters
function (Callable[[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.
params (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – A tuple of jnp arrays. The parameters for which to construct the distribution.
dist_builder (Callable[…, Any]) – A constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.
rng (PRNGKeyArray) – A PRNGKey key.
num_samples (int) – Int, the number of samples used to compute the grads.
 Return type
Sequence[Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase]]
 Returns
A tuple of size params, where each element is a num_samples x param.shape jacobian vector containing the estimates of the gradients obtained for each sample. The mean of this vector is the gradient with respect to the parameters and can be used for learning. The entire jacobian vector can be used to assess estimator variance.
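Because the score function estimator only differentiates log p(x; theta), it also works when f is not differentiable. The sketch below shows this in plain JAX for a diagonal Gaussian and a step-function f. It is an illustration under stated assumptions, not the optax implementation: score_function_jacobians_sketch, gaussian_log_prob and the (mean, log_scale) parametrization are hypothetical names introduced here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def f(x):
    # Non-differentiable function: counts the positive coordinates of x.
    return jnp.sum(jnp.where(x > 0, 1.0, 0.0))


def gaussian_log_prob(params, x):
    mean, log_scale = params
    z = (x - mean) / jnp.exp(log_scale)
    return jnp.sum(-0.5 * z ** 2 - log_scale - 0.5 * jnp.log(2 * jnp.pi))


def score_function_jacobians_sketch(function, params, rng, num_samples):
    """Per-sample estimates f(x) * grad_theta log p(x; theta) (illustrative only)."""
    mean, log_scale = params
    eps = jax.random.normal(rng, (num_samples,) + mean.shape)
    # Samples are passed as a separate argument, so they act as constants below.
    samples = mean + jnp.exp(log_scale) * eps

    def per_sample(x):
        score = jax.grad(gaussian_log_prob)(params, x)
        return jax.tree_util.tree_map(lambda s: function(x) * s, score)

    return jax.vmap(per_sample)(samples)


mean = jnp.array([0.5, -1.0])
params = (mean, jnp.zeros(2))
jacobians = score_function_jacobians_sketch(f, params, jax.random.PRNGKey(0), 20000)
grad_mean = jnp.mean(jacobians[0], axis=0)
# With unit scale, d/d mean_i of P(x_i > 0) is the standard normal pdf at mean_i.
expected = norm.pdf(mean)
```

The per-sample array in jacobians can also be inspected directly (for example its variance along axis 0) to see how noisy the estimator is compared with the pathwise one.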
Privacy-Sensitive Optax Methods#

State containing PRNGKey for differentially_private_aggregate. 
Aggregates gradients based on the DPSGD algorithm. 
differentially_private_aggregate#
 optax.differentially_private_aggregate(l2_norm_clip, noise_multiplier, seed)[source]#
Aggregates gradients based on the DPSGD algorithm.
WARNING: Unlike other transforms, differentially_private_aggregate expects the input updates to have a batch dimension in the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap). It can still be composed with other transformations as long as it is the first in the chain.
References
Abadi et al, 2016: https://arxiv.org/abs/1607.00133
 Parameters
l2_norm_clip (float) – Maximum L2 norm of the per-example gradients.
noise_multiplier (float) – Ratio of standard deviation to the clipping norm.
seed (int) – Initial seed used for the jax.random.PRNGKey.
 Return type
GradientTransformation
 Returns
A GradientTransformation.
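To make the batch-dimension requirement concrete, the sketch below computes per-example gradients with jax.vmap and then applies the clip, sum, add-noise, average steps of DPSGD by hand. It is a simplified illustration for a single 1-D parameter array, not the optax implementation. The helper dp_aggregate_sketch and the toy loss are hypothetical, and the sketch takes a PRNGKey directly rather than an integer seed.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error for a single example.
    return (jnp.dot(w, x) - y) ** 2


def dp_aggregate_sketch(per_example_grads, l2_norm_clip, noise_multiplier, rng):
    """Clip each example's gradient, sum, add Gaussian noise, then average."""
    batch_size = per_example_grads.shape[0]
    norms = jnp.linalg.norm(per_example_grads, axis=1)
    # Rescale only the gradients whose L2 norm exceeds the clipping threshold.
    factors = jnp.minimum(1.0, l2_norm_clip / jnp.maximum(norms, 1e-12))
    clipped = per_example_grads * factors[:, None]
    # Noise standard deviation is noise_multiplier times the clipping norm.
    noise_std = l2_norm_clip * noise_multiplier
    noise = noise_std * jax.random.normal(rng, clipped.shape[1:])
    return (jnp.sum(clipped, axis=0) + noise) / batch_size


w = jnp.zeros(3)
xs = jnp.arange(12.0).reshape(4, 3)
ys = jnp.ones(4)

# vmap over the example axis of xs and ys; w is shared across examples.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)
update = dp_aggregate_sketch(
    per_example_grads, l2_norm_clip=1.0, noise_multiplier=1.1,
    rng=jax.random.PRNGKey(0))
```

The leading axis of per_example_grads (size 4 here) is the batch dimension the warning refers to. The aggregated update has the shape of the parameters.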
General Utilities#



Scales gradients for the backwards pass. 
scale_gradient#
 optax.scale_gradient(inputs, scale)[source]#
Scales gradients for the backwards pass.
 Parameters
inputs (Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]) – A nested array.
scale (float) – The scale factor for the gradient on the backwards pass.
 Return type
Union[Array, ndarray, BatchTracer, ShardedDeviceArrayBase, Iterable[ForwardRef], Mapping[Any, ForwardRef]]
 Returns
An array of the same structure as inputs, with scaled backward gradient.
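One common way to get this behaviour in plain JAX is the stop_gradient identity shown below, which leaves the forward value unchanged and multiplies the backward gradient by scale. This is a sketch of the idea only. The name scale_gradient_sketch is hypothetical, and optax may implement scale_gradient differently.

```python
import jax
import jax.numpy as jnp


def scale_gradient_sketch(inputs, scale):
    """Identity on the forward pass, gradient multiplied by `scale` on the backward pass."""
    return jax.tree_util.tree_map(
        # Only the first term carries gradient; both terms together sum to x.
        lambda x: scale * x + jax.lax.stop_gradient((1.0 - scale) * x),
        inputs)


def loss(x):
    return jnp.sum(scale_gradient_sketch(x, 0.5) ** 2)


x = jnp.array([1.0, 2.0])
value, grad = jax.value_and_grad(loss)(x)
# value is sum(x**2); grad is 0.5 * 2 * x instead of 2 * x.
```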
🚧 Experimental#

Splits the real and imaginary components of complex updates into two. 

Maintains the inner transformation state for split_real_and_imaginary. 
Complex-Valued Optimization#
 optax.experimental.split_real_and_imaginary(inner)[source]#
Splits the real and imaginary components of complex updates into two.
The inner transformation processes real parameters and updates, and the pairs of transformed real updates are merged into complex updates.
Parameters and updates that are real before splitting are passed through unmodified.
 Parameters
inner (optax.GradientTransformation) – The inner transformation.
 Return type
optax.GradientTransformation
 Returns
An optax.GradientTransformation.
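The split-and-merge idea can be sketched in plain JAX with a stateless, element-wise inner function standing in for a full GradientTransformation. This is a simplified illustration, not the optax implementation: apply_split_sketch is a hypothetical helper, and a real inner transformation would also carry state for the real and imaginary parts.

```python
import jax
import jax.numpy as jnp


def apply_split_sketch(inner_fn, updates):
    """Apply a real-valued function to the real and imaginary parts separately."""

    def leaf(u):
        if jnp.iscomplexobj(u):
            # Split into two real arrays, transform each, then merge back.
            return jax.lax.complex(inner_fn(u.real), inner_fn(u.imag))
        # Real leaves go to the inner function as they are.
        return inner_fn(u)

    return jax.tree_util.tree_map(leaf, updates)


def clip(x):
    # Stand-in for an inner transformation that only understands real arrays.
    return jnp.clip(x, -1.0, 1.0)


updates = {
    "w": jnp.array([3.0 + 0.5j, -2.0 - 4.0j]),
    "b": jnp.array([2.0, -0.5]),
}
out = apply_split_sketch(clip, updates)
```

Here the complex leaf "w" is clipped component-wise in its real and imaginary parts, while the real leaf "b" is clipped directly.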