Common Optimizers

adabelief(learning_rate[, b1, b2, eps, eps_root])

The AdaBelief optimizer.

adafactor([learning_rate, ...])

The Adafactor optimizer.

adagrad(learning_rate[, ...])

The Adagrad optimizer.

adam(learning_rate[, b1, b2, eps, eps_root, ...])

The classic Adam optimizer.

adamw(learning_rate[, b1, b2, eps, ...])

Adam with weight decay regularization.

adamax(learning_rate[, b1, b2, eps])

A variant of the Adam optimizer that uses the infinity norm.

adamaxw(learning_rate[, b1, b2, eps, ...])

Adamax with weight decay regularization.

amsgrad(learning_rate[, b1, b2, eps, ...])

The AMSGrad optimizer.

fromage(learning_rate[, min_norm])

The Frobenius matched gradient descent (Fromage) optimizer.

lamb(learning_rate[, b1, b2, eps, eps_root, ...])

The LAMB optimizer.

lars(learning_rate[, weight_decay, ...])

The LARS optimizer.

lion(learning_rate[, b1, b2, mu_dtype, ...])

The Lion optimizer.

noisy_sgd(learning_rate[, eta, gamma, seed])

A variant of SGD with added noise.

novograd(learning_rate[, b1, b2, eps, ...])

The NovoGrad optimizer.

optimistic_gradient_descent(learning_rate[, ...])

An Optimistic Gradient Descent optimizer.

dpsgd(learning_rate, l2_norm_clip, ...[, ...])

The DPSGD optimizer.

radam(learning_rate[, b1, b2, eps, ...])

The Rectified Adam optimizer.

rmsprop(learning_rate[, decay, eps, ...])

A flexible RMSProp optimizer.

sgd(learning_rate[, momentum, nesterov, ...])

A canonical Stochastic Gradient Descent optimizer.

sm3(learning_rate[, momentum])

The SM3 optimizer.

yogi(learning_rate[, b1, b2, eps])

The Yogi optimizer.

AdaBelief

optax.adabelief(learning_rate, b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16)

The AdaBelief optimizer.

AdaBelief is an adaptive learning rate optimizer that focuses on fast convergence, generalization, and stability. It adapts the step size according to its “belief” in the current gradient direction: the step size is scaled inversely to the (exponentially averaged, squared) difference between the predicted and observed gradients, so steps are large when the observed gradient matches the prediction and small when it does not. AdaBelief is a modified version of Adam and has the same hyperparameters.
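To make the "belief" mechanism concrete, here is a minimal NumPy sketch of one AdaBelief step. It is written from the update rule in Zhuang et al, 2020 rather than from optax's source; the helper name adabelief_step is hypothetical, and exactly where optax adds eps_root may differ. With the same gradient at two consecutive steps, the prediction error shrinks, so the second step is larger than the first.

```python
import numpy as np


def adabelief_step(params, grads, m, s, t, lr=1e-3, b1=0.9, b2=0.999,
                   eps=1e-16, eps_root=1e-16):
    """One AdaBelief step on NumPy arrays; t is the 1-based step count.

    Hypothetical helper following the paper's update rule, not optax's code.
    """
    # First moment: EMA of the gradients (the "prediction").
    m = b1 * m + (1 - b1) * grads
    # "Belief": EMA of the squared gap between observed and predicted gradient.
    s = b2 * s + (1 - b2) * (grads - m) ** 2 + eps_root
    # Bias correction, as in Adam.
    m_hat = m / (1 - b1 ** t)
    s_hat = s / (1 - b2 ** t)
    params = params - lr * m_hat / (np.sqrt(s_hat) + eps)
    return params, m, s


g = np.array([0.5, -2.0])             # the same gradient at both steps
x0, m, s = np.zeros(2), np.zeros(2), np.zeros(2)
x1, m, s = adabelief_step(x0, g, m, s, t=1)
x2, m, s = adabelief_step(x1, g, m, s, t=2)
print(np.abs(x1 - x0), np.abs(x2 - x1))  # the second step is larger
```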

References

Zhuang et al, 2020: https://arxiv.org/abs/2010.07468

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

AdaGrad

optax.adagrad(learning_rate, initial_accumulator_value=0.1, eps=1e-07)

The Adagrad optimizer.

Adagrad is an algorithm for gradient-based optimization that anneals the learning rate for each parameter during the course of training.

WARNING: Adagrad’s main limitation is the monotonic accumulation of squared gradients in the denominator: since every term is non-negative, the sum keeps growing during training and the learning rate eventually becomes vanishingly small.
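The shrinking step size is easy to see numerically. The plain-Python sketch below (the helper adagrad_effective_lrs is hypothetical) accumulates squared gradients as described above and reports the per-step scale lr / sqrt(acc + eps) for a constant gradient; the exact placement of eps inside optax may differ.

```python
import math


def adagrad_effective_lrs(grads, lr=0.1, initial_accumulator_value=0.1, eps=1e-7):
    """Per-step effective learning rates for a scalar parameter (sketch)."""
    acc = initial_accumulator_value
    lrs = []
    for g in grads:
        acc += g ** 2                        # never decreases: g**2 >= 0
        lrs.append(lr / math.sqrt(acc + eps))
    return lrs


lrs = adagrad_effective_lrs([1.0] * 1000)
print(lrs[0], lrs[-1])  # roughly 0.095 and 0.0032
```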

References

Duchi et al, 2011: https://jmlr.org/papers/v12/duchi11a.html

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • initial_accumulator_value (float) – Initial value for the accumulator.

  • eps (float) – A small constant applied to denominator inside of the square root (as in RMSProp) to avoid dividing by zero when rescaling.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

AdaFactor

optax.adafactor(learning_rate=None, min_dim_size_to_factor=128, decay_rate=0.8, decay_offset=0, multiply_by_parameter_scale=True, clipping_threshold=1.0, momentum=None, dtype_momentum=jax.numpy.float32, weight_decay_rate=None, eps=1e-30, factored=True, weight_decay_mask=None)

The Adafactor optimizer.

Adafactor is an adaptive learning rate optimizer that focuses on fast training of large scale neural networks. It saves memory by using a factored estimate of the second order moments used to scale gradients.
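The memory saving comes from storing only row and column statistics of the squared gradients of a matrix-shaped parameter and rebuilding a rank-1 estimate from them. The NumPy sketch below shows that reconstruction in isolation (no exponential moving average, update clipping, or parameter scaling); the helper name is hypothetical and this is not optax's code.

```python
import numpy as np


def factored_second_moment(sq_grads):
    """Rank-1 estimate of an (m, n) matrix of squared gradients.

    Only m + n numbers (row and column sums) need to be stored instead of m * n.
    """
    row = sq_grads.sum(axis=1)          # shape (m,)
    col = sq_grads.sum(axis=0)          # shape (n,)
    estimate = np.outer(row, col) / row.sum()
    return estimate, row.size + col.size


# For a rank-1 non-negative matrix the reconstruction is exact.
v = np.outer([1.0, 2.0, 3.0], [0.5, 4.0])
estimate, stored = factored_second_moment(v)
print(np.allclose(estimate, v), stored, v.size)
```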

References

Shazeer and Stern, 2018: https://arxiv.org/abs/1804.04235

Parameters
  • learning_rate (Optional[ScalarOrSchedule]) – A fixed global scaling factor. Note that the natural scale of Adafactor’s learning rate is markedly different from Adam’s: with attention-based models, one does not apply the 1/sqrt(hidden) correction when using this optimizer.

  • min_dim_size_to_factor (int) – Only factor the statistics if two array dimensions have at least this size.

  • decay_rate (float) – Controls second-moment exponential decay schedule.

  • decay_offset (int) – For fine-tuning, one may set this to the starting step number of the fine-tuning phase.

  • multiply_by_parameter_scale (bool) – If True, scale learning_rate by the parameter norm. If False, learning_rate is used as the absolute step size.

  • clipping_threshold (Optional[float]) – Optional clipping threshold. Must be >= 1. If None, clipping is disabled.

  • momentum (Optional[float]) – Optional value between 0 and 1. If not None, enables momentum, which uses extra memory. Defaults to None.

  • dtype_momentum (Any) – Data type of momentum buffers.

  • weight_decay_rate (Optional[float]) – Optional rate at which to decay weights.

  • eps (float) – Regularization constant for root mean squared gradient.

  • factored (bool) – Whether to use factored second-moment estimates.

  • weight_decay_mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

Adam

optax.adam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)

The classic Adam optimizer.

Adam is an SGD variant with gradient scaling adaptation. The scaling used for each parameter is computed from estimates of first and second-order moments of the gradients (using suitable exponential moving averages).

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.

The init function of this optimizer initializes an internal state \(S_0 := (m_0, v_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and the previous optimizer state \(S_{t-1}\), and computes updates \(u_t\) and the new state \(S_t\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ v_t &\leftarrow \beta_2 \cdot v_{t-1} + (1-\beta_2) \cdot {g_t}^2 \\ \hat{m}_t &\leftarrow m_t / {(1-\beta_1^t)} \\ \hat{v}_t &\leftarrow v_t / {(1-\beta_2^t)} \\ u_t &\leftarrow \alpha_t \cdot \hat{m}_t / \left({\sqrt{\hat{v}_t + \bar{\varepsilon}} + \varepsilon} \right)\\ S_t &\leftarrow (m_t, v_t). \end{align*} \]
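The equations above translate almost line by line into NumPy. This is an illustrative sketch with a hypothetical helper name, not optax's implementation. It returns \(u_t\) with the sign used in the equations; optax's adam instead returns updates that already include the minus sign, so they can be added to the parameters with optax.apply_updates.

```python
import numpy as np


def adam_update(g, state, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, eps_root=0.0):
    """Compute u_t and the new state (m_t, v_t) from g_t and (m_{t-1}, v_{t-1}).

    t is the 1-based step index used in the bias corrections.
    """
    m, v = state
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    u = lr * m_hat / (np.sqrt(v_hat + eps_root) + eps)
    return u, (m, v)


g = np.array([3.0, -0.5])
state = (np.zeros_like(g), np.zeros_like(g))   # S_0 = (0, 0)
u, state = adam_update(g, state, t=1)
params = np.array([1.0, 1.0]) - u               # descend along u_t
print(u)
```

Because of the bias correction, \(\hat{m}_1 = g_1\) and \(\hat{v}_1 = g_1^2\), so the first update has magnitude close to the learning rate for every coordinate, whatever the scale of the gradient.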

References

Kingma et al, 2014: https://arxiv.org/abs/1412.6980

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for example when computing (meta-)gradients through Adam.

  • mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

Adamax

optax.adamax(learning_rate, b1=0.9, b2=0.999, eps=1e-08)

A variant of the Adam optimizer that uses the infinity norm.
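Concretely, Adamax replaces Adam's second-moment EMA with an exponentially weighted infinity norm, u_t = max(b2 * u_{t-1}, |g_t|), which needs no bias correction. Below is a NumPy sketch following Kingma et al, 2014; the helper name and the placement of eps are assumptions, and this is not optax's code.

```python
import numpy as np


def adamax_step(params, g, m, u, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adamax step; t is the 1-based step count."""
    m = b1 * m + (1 - b1) * g
    u = np.maximum(b2 * u, np.abs(g))   # decaying running max of |g|
    params = params - (lr / (1 - b1 ** t)) * m / (u + eps)
    return params, m, u


x, m, u = np.zeros(1), np.zeros(1), np.zeros(1)
for t, g in enumerate([1.0, 10.0, 1.0], start=1):
    x, m, u = adamax_step(x, np.array([g]), m, u, t)
print(u)  # the spike of 10 still dominates (0.999 * 10)
```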

References

Kingma et al, 2014: https://arxiv.org/abs/1412.6980

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the maximum of past gradients.

  • eps (float) – A small constant applied to denominator to avoid dividing by zero when rescaling.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

AdamaxW

optax.adamaxw(learning_rate, b1=0.9, b2=0.999, eps=1e-08, weight_decay=0.0001, mask=None)

Adamax with weight decay regularization.

AdamaxW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization to implement this as an additive loss term; however, L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam.

WARNING: Sometimes you may want to skip weight decay for BatchNorm scale parameters or for biases. Use the mask argument, or build your own variant with optax.masked, so that weight decay is applied only to a subset of params.
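As an illustration of such a mask, the hypothetical function below marks only parameters with two or more dimensions (e.g. dense kernels) for weight decay and skips biases and normalization scales. It assumes a flat dict of NumPy arrays with made-up names; for nested pytrees, jax.tree_util.tree_map(lambda p: p.ndim > 1, params) builds the same kind of mask.

```python
import numpy as np


def decay_mask_fn(params):
    """Return a boolean tree with the same structure as `params`.

    True  -> apply weight decay (matrices / kernels, ndim >= 2).
    False -> skip it (biases, BatchNorm scales, ndim < 2).
    """
    return {name: np.ndim(p) >= 2 for name, p in params.items()}


params = {  # hypothetical parameter names and shapes
    "dense/kernel": np.ones((4, 3)),
    "dense/bias": np.zeros(3),
    "batchnorm/scale": np.ones(3),
}
print(decay_mask_fn(params))
```

Since mask accepts either a tree or a callable, the function itself can be passed, e.g. optax.adamaxw(learning_rate=1e-3, mask=decay_mask_fn).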

References

Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the maximum of past gradients.

  • eps (float) – A small constant applied to denominator to avoid dividing by zero when rescaling.

  • weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adamax gradient transformations are applied to all parameters.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

AdamW

optax.adamw(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None, weight_decay=0.0001, mask=None)

Adam with weight decay regularization.

AdamW uses weight decay to regularize learning towards small weights, as this leads to better generalization. In SGD you can also use L2 regularization to implement this as an additive loss term; however, L2 regularization does not behave as intended for adaptive gradient algorithms such as Adam.
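The difference is easiest to see at the first step, where Adam's bias-corrected update reduces to roughly lr * sign(g). The plain-Python sketch below (hypothetical helper names; one scalar parameter with zero loss gradient) shows that folding the decay into the gradient lets Adam's normalization cancel its strength, while the decoupled, AdamW-style term stays proportional to weight_decay. As in the weight_decay description below, the decoupled term is multiplied by the learning rate.

```python
import math


def adam_first_step(g, lr, eps=1e-8):
    # At t = 1 the bias-corrected moments are m_hat = g and v_hat = g**2.
    return lr * g / (math.sqrt(g ** 2) + eps)


def l2_coupled_step(w, g, lr, wd):
    # L2 regularization: the decay term is added to the gradient,
    # so it is also divided by the adaptive denominator.
    return w - adam_first_step(g + wd * w, lr)


def decoupled_step(w, g, lr, wd):
    # AdamW-style: the decay is applied outside the adaptive scaling.
    return w - adam_first_step(g, lr) - lr * wd * w


w, g, lr = 10.0, 0.0, 1e-3
for wd in (1e-2, 1e-4):
    # Shrinkage of w under each scheme.
    print(wd, w - l2_coupled_step(w, g, lr, wd), w - decoupled_step(w, g, lr, wd))
```

With L2 regularization the shrinkage is about lr for both values of wd, whereas the decoupled shrinkage scales with wd.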

References

Loshchilov et al, 2019: https://arxiv.org/abs/1711.05101

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

AMSGrad

optax.amsgrad(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)

The AMSGrad optimizer.

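The original Adam can fail to converge to the optimal solution in some cases, even on simple convex problems. AMSGrad guarantees convergence (in the convex setting analyzed in the paper) by using a long-term memory of past gradients: it normalizes updates by the running maximum of the second-moment estimates, so the adaptive scaling of each parameter's step never increases.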
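The long-term memory is a running maximum of the second-moment estimate. The plain-Python sketch below (hypothetical helper; bias correction and the first moment are omitted for brevity) feeds one large gradient followed by small ones: Adam's denominator sqrt(v) starts shrinking, which enlarges later steps, while AMSGrad's sqrt(v_max) never decreases.

```python
import math


def second_moment_denominators(grads, b2=0.999):
    """Return (adam, amsgrad) lists of sqrt(v_t) and sqrt(max_{s<=t} v_s)."""
    v, v_max = 0.0, 0.0
    adam, amsgrad = [], []
    for g in grads:
        v = b2 * v + (1 - b2) * g ** 2
        v_max = max(v_max, v)          # AMSGrad: never decreases
        adam.append(math.sqrt(v))
        amsgrad.append(math.sqrt(v_max))
    return adam, amsgrad


adam, amsgrad = second_moment_denominators([10.0] + [0.1] * 50)
print(adam[0], adam[-1], amsgrad[-1])
```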

References

Reddi et al, 2018: https://openreview.net/forum?id=ryQu7f-RZ

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

Fromage

optax.fromage(learning_rate, min_norm=1e-06)

The Frobenius matched gradient descent (Fromage) optimizer.

Fromage is a learning algorithm that does not require learning rate tuning. The optimizer is based on modeling neural network gradients via deep relative trust (a distance function on deep neural networks). Fromage is similar to the LARS optimizer and can work on a range of standard neural network benchmarks, such as natural language Transformers and generative adversarial networks.

References

Bernstein et al, 2020: https://arxiv.org/abs/2002.03432

Parameters
  • learning_rate (float) – A fixed global scaling factor.

  • min_norm (float) – A minimum value to which the norms of the gradient updates and of the layer parameters are clipped, to avoid dividing by zero when computing the trust ratio (as in the LARS paper).

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

Lamb

optax.lamb(learning_rate, b1=0.9, b2=0.999, eps=1e-06, eps_root=0.0, weight_decay=0.0, mask=None)

The LAMB optimizer.

LAMB is a general purpose layer-wise adaptive large batch optimizer designed to provide consistent training performance across a wide range of tasks, including those that use attention-based models (such as Transformers) and ResNet-50. The optimizer is able to work with small and large batch sizes. LAMB was inspired by the LARS learning algorithm.

References

You et al, 2019: https://arxiv.org/abs/1904.00962

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • weight_decay (float) – Strength of the weight decay regularization.

  • mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

Lars

optax.lars(learning_rate, weight_decay=0.0, weight_decay_mask=True, trust_coefficient=0.001, eps=0.0, trust_ratio_mask=True, momentum=0.9, nesterov=False)

The LARS optimizer.

LARS is a layer-wise adaptive optimizer introduced to help scale SGD to larger batch sizes. LARS later inspired the LAMB optimizer.

References

You et al, 2017: https://arxiv.org/abs/1708.03888

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • weight_decay (float) – Strength of the weight decay regularization.

  • weight_decay_mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

  • trust_coefficient (float) – A multiplier for the trust ratio.

  • eps (float) – Optional additive constant in the trust ratio denominator.

  • trust_ratio_mask (MaskOrFn) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

  • momentum (float) – Decay rate for momentum.

  • nesterov (bool) – Whether to use Nesterov momentum.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

Lion

optax.lion(learning_rate, b1=0.9, b2=0.99, mu_dtype=None, weight_decay=0.001, mask=None)

The Lion optimizer.

Lion was discovered by symbolic program search. Unlike most adaptive optimizers such as AdamW, Lion only tracks momentum, making it more memory-efficient. The Lion update is produced through the sign operation, resulting in a larger norm compared to updates produced by other optimizers such as SGD and AdamW. A suitable learning rate for Lion is therefore typically 3-10x smaller than that for AdamW, and the weight decay for Lion should in turn be 3-10x larger than that for AdamW to maintain a similar strength (lr * wd).
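A NumPy sketch of one Lion step, following the update rule of Chen et al, 2023 (the helper name is hypothetical; this is not optax's code). Because of the sign operation, every coordinate moves by exactly lr when weight decay is zero, even one whose gradient is tiny, which is why Lion's updates have a comparatively large norm.

```python
import numpy as np


def lion_step(params, g, m, lr=1e-4, b1=0.9, b2=0.99, weight_decay=1e-3):
    """One Lion step. Only the momentum m is kept as optimizer state."""
    c = b1 * m + (1 - b1) * g                    # interpolate momentum and gradient
    params = params - lr * (np.sign(c) + weight_decay * params)
    m = b2 * m + (1 - b2) * g                    # momentum update uses b2
    return params, m


g = np.array([1e-6, -50.0, 3.0])
x, m = lion_step(np.zeros(3), g, np.zeros(3), weight_decay=0.0)
print(x)  # each coordinate moved by exactly lr, opposite to the sign of g
```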

References

Chen et al, 2023: https://arxiv.org/abs/2302.06675

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – Rate to combine the momentum and the current gradient.

  • b2 (float) – Exponential decay rate to track the momentum of past gradients.

  • mu_dtype (Optional[Any]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • weight_decay (float) – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Lion gradient transformations are applied to all parameters.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

SM3

optax.sm3(learning_rate, momentum=0.9)

The SM3 optimizer.

SM3 (Square-root of Minima of Sums of Maxima of Squared-gradients Method) is a memory-efficient adaptive optimizer designed to decrease memory overhead when training very large models, such as the Transformer for machine translation, BERT for language modeling, and AmoebaNet-D for image classification. SM3: 1) applies to tensors of arbitrary dimensions and any predefined cover of the parameters; 2) adapts the learning rates in an adaptive and data-driven manner (like Adagrad and unlike Adafactor); and 3) comes with rigorous convergence guarantees in stochastic convex optimization settings.

References

Anil et al, 2019: https://arxiv.org/abs/1901.11150

Parameters
  • learning_rate (float) – A fixed global scaling factor.

  • momentum (float) – Decay rate used by the momentum term (when it is set to None, momentum is not used at all).

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

Noisy SGD

optax.noisy_sgd(learning_rate, eta=0.01, gamma=0.55, seed=0)

A variant of SGD with added noise.

It has been found that adding noise to the gradients can improve both the training error and the generalization error in very deep networks.

References

Neelakantan et al, 2015: https://arxiv.org/abs/1511.06807

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • eta (float) – Initial variance for the Gaussian noise added to gradients.

  • gamma (float) – A parameter controlling the annealing of noise over time; the variance at step t decays as eta / (1+t)^gamma.

  • seed (int) – Seed for the pseudo-random generation process.
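The annealed noise can be sketched as follows, assuming (as the gamma description says) that the noise variance at step t is eta / (1 + t)^gamma with t starting at 0. The helper names are hypothetical, and the random generator here is NumPy's rather than the JAX PRNG that optax uses.

```python
import numpy as np


def noise_variance(t, eta=0.01, gamma=0.55):
    """Variance of the Gaussian noise added at step t (t = 0, 1, 2, ...)."""
    return eta / (1 + t) ** gamma


def noisy_gradient(g, t, rng, eta=0.01, gamma=0.55):
    std = np.sqrt(noise_variance(t, eta, gamma))
    return g + rng.normal(0.0, std, size=np.shape(g))


rng = np.random.default_rng(0)
print(noise_variance(0), noise_variance(99))   # the variance shrinks over time
noisy = noisy_gradient(np.zeros(200_000), t=99, rng=rng)
```

A descent step would then use the noisy gradient in place of the clean one, e.g. params - lr * noisy.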

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

Novograd

optax.novograd(learning_rate, b1=0.9, b2=0.25, eps=1e-06, eps_root=0.0, weight_decay=0.0)

The NovoGrad optimizer.

NovoGrad is more robust to the initial learning rate and weight initialization than other methods. For example, NovoGrad works well without learning rate warm-up, while other methods require it. NovoGrad performs exceptionally well for large batch training, e.g. it outperforms other methods for ResNet-50 for all batch sizes up to 32K. In addition, NovoGrad requires half the memory of Adam, since it keeps its second moment per layer rather than per weight. It was introduced together with the Jasper ASR model.

References

Ginsburg et al, 2019: https://arxiv.org/abs/1905.11286

Li et al, 2019: https://arxiv.org/abs/1904.03288

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – An exponential decay rate to track the first moment of past gradients.

  • b2 (float) – An exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • weight_decay (float) – Strength of the weight decay regularization.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

Optimistic GD

optax.optimistic_gradient_descent(learning_rate, alpha=1.0, beta=1.0)

An Optimistic Gradient Descent optimizer.

Optimistic gradient descent is an approximation of extra-gradient methods, which require multiple gradient calls to compute the next update. It has strong formal guarantees for last-iterate convergence in min-max games, for which standard gradient descent can oscillate or even diverge.
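The bilinear game min_x max_y x * y is the textbook case: simultaneous gradient descent/ascent spirals outwards, while the optimistic update spirals in. The plain-Python sketch below assumes the generalized step lr * ((alpha + beta) * g_t - beta * g_{t-1}), which with alpha = beta = 1 is the classic 2 * g_t - g_{t-1} step of Mokhtari et al, 2019, and takes the previous gradient to be zero at the first step. The helper names are hypothetical and this is not optax's code.

```python
import math


def descent_directions(x, y):
    # f(x, y) = x * y: x minimizes f and y maximizes it. Each player
    # subtracts (a multiple of) its entry of this pair from its variable.
    return y, -x


def run_gd(x, y, lr, steps):
    for _ in range(steps):
        gx, gy = descent_directions(x, y)
        x, y = x - lr * gx, y - lr * gy
    return x, y


def run_ogd(x, y, lr, steps, alpha=1.0, beta=1.0):
    prev_gx, prev_gy = 0.0, 0.0
    for _ in range(steps):
        gx, gy = descent_directions(x, y)
        x = x - lr * ((alpha + beta) * gx - beta * prev_gx)
        y = y - lr * ((alpha + beta) * gy - beta * prev_gy)
        prev_gx, prev_gy = gx, gy
    return x, y


start = math.hypot(1.0, 1.0)
print(start,
      math.hypot(*run_gd(1.0, 1.0, 0.1, 500)),
      math.hypot(*run_ogd(1.0, 1.0, 0.1, 500)))
```

The distance from the equilibrium (0, 0) grows under plain gradient descent/ascent and shrinks under the optimistic update.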

References

Mokhtari et al, 2019: https://arxiv.org/abs/1901.08511v2

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • alpha (ScalarOrSchedule) – Coefficient for generalized OGD.

  • beta (ScalarOrSchedule) – Coefficient for generalized OGD negative momentum.

Return type

optax.GradientTransformation

Returns

A GradientTransformation.

Differentially Private SGD

optax.dpsgd(learning_rate, l2_norm_clip, noise_multiplier, seed, momentum=None, nesterov=False)

The DPSGD optimizer.

Differential privacy is a standard for privacy guarantees of algorithms learning from aggregate databases including potentially sensitive information. DPSGD offers protection against a strong adversary with full knowledge of the training mechanism and access to the model’s parameters.

WARNING: This GradientTransformation expects input updates to have a batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap).
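The clip-and-noise aggregation at the heart of DPSGD can be sketched in NumPy as below. It assumes per-example gradients flattened into a (batch, dim) array and Gaussian noise with standard deviation noise_multiplier * l2_norm_clip added to the sum of clipped gradients, as in Abadi et al, 2016. The helper name is hypothetical and this is not optax's code; in JAX the per-example gradients would typically come from jax.vmap applied to jax.grad.

```python
import numpy as np


def dp_aggregate(per_example_grads, l2_norm_clip, noise_multiplier, rng):
    """Clip each example's gradient, sum, add Gaussian noise, then average."""
    norms = np.linalg.norm(per_example_grads, axis=1, keepdims=True)
    # Scale rows whose L2 norm exceeds the clip; leave the others untouched.
    factors = np.minimum(1.0, l2_norm_clip / np.maximum(norms, 1e-12))
    clipped = per_example_grads * factors
    noise = rng.normal(0.0, noise_multiplier * l2_norm_clip,
                       size=per_example_grads.shape[1:])
    return (clipped.sum(axis=0) + noise) / per_example_grads.shape[0]


grads = np.array([[3.0, 4.0],     # norm 5.0 -> clipped to norm 1.0
                  [0.3, 0.4]])    # norm 0.5 -> unchanged
rng = np.random.default_rng(0)
print(dp_aggregate(grads, l2_norm_clip=1.0, noise_multiplier=0.0, rng=rng))
```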

References

Abadi et al, 2016: https://arxiv.org/abs/1607.00133

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • l2_norm_clip (float) – Maximum L2 norm of the per-example gradients.

  • noise_multiplier (float) – Ratio of standard deviation to the clipping norm.

  • seed (int) – Initial seed used for the jax.random.PRNGKey.

  • momentum (Optional[float]) – Decay rate used by the momentum term; when set to None, momentum is not used at all.

  • nesterov (bool) – Whether Nesterov momentum is used.

Return type

optax.GradientTransformation

Returns

A GradientTransformation.

RAdam

optax.radam(learning_rate, b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, threshold=5.0)

The Rectified Adam optimizer.

The adaptive learning rate in Adam has undesirably large variance in early stages of training, due to the limited number of training samples used to estimate the optimizer’s statistics. Rectified Adam addresses this issue by analytically reducing the large variance.
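The rectification can be written in a few lines of plain Python. The sketch below follows the formulas of the RAdam paper (Liu et al, 2020), using threshold in place of the paper's constant 4 to mirror the parameter of the same name; the helper name is hypothetical. When rho_t does not exceed the threshold, the adaptive scaling is skipped for that step.

```python
import math


def radam_rectification(t, b2=0.999, threshold=5.0):
    """Rectification factor r_t for 1-based step t, or None when the
    variance of the adaptive learning rate is considered intractable."""
    rho_inf = 2.0 / (1.0 - b2) - 1.0
    rho_t = rho_inf - 2.0 * t * b2 ** t / (1.0 - b2 ** t)
    if rho_t <= threshold:
        return None  # fall back to an un-adapted (momentum-only) step
    return math.sqrt((rho_t - 4) * (rho_t - 2) * rho_inf
                     / ((rho_inf - 4) * (rho_inf - 2) * rho_t))


for t in (1, 10, 100_000):
    print(t, radam_rectification(t))  # None, then small, then close to 1
```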

References

Liu et al, 2020: https://arxiv.org/abs/1908.03265

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root (float) – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • threshold (float) – Threshold for variance tractability.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

RMSProp

optax.rmsprop(learning_rate, decay=0.9, eps=1e-08, initial_scale=0.0, centered=False, momentum=None, nesterov=False)

A flexible RMSProp optimizer.

RMSProp is an SGD variant with learning rate adaptation. The learning_rate used for each weight is scaled by a suitable estimate of the magnitude of the gradients on previous steps. Several variants of RMSProp can be found in the literature. This alias provides an easy to configure RMSProp optimizer that can be used to switch between several of these variants.

References

Tieleman and Hinton, 2012: http://www.cs.toronto.edu/~hinton/coursera/lecture6/lec6.pdf

Graves, 2013: https://arxiv.org/abs/1308.0850

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • decay (float) – Decay used to track the magnitude of previous gradients.

  • eps (float) – A small numerical constant to avoid dividing by zero when rescaling.

  • initial_scale (float) – Initial value of accumulators tracking the magnitude of previous updates. PyTorch uses 0, TF1 uses 1. When reproducing results from a paper, verify the value used by the authors.

  • centered (bool) – If True, the estimated variance of past gradients is used to rescale the latest gradients; if False, their (uncentered) second moment is used.

  • momentum (Optional[float]) – Decay rate used by the momentum term; when set to None, momentum is not used at all.

  • nesterov (bool) – Whether Nesterov momentum is used.
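A NumPy sketch of the centered and uncentered variants (the helper name is hypothetical; momentum and nesterov are omitted, and placing eps inside the square root is an assumption). With centered=True the running mean of the gradients is subtracted, so the denominator estimates a variance rather than a second moment; for a consistent gradient it is smaller and the step larger. Zero-initialized accumulators correspond to initial_scale=0.0.

```python
import numpy as np


def rmsprop_step(params, g, state, lr=1e-2, decay=0.9, eps=1e-8, centered=False):
    """One RMSProp step; state = (nu, mu), EMAs of g**2 and of g."""
    nu, mu = state
    nu = decay * nu + (1 - decay) * g ** 2
    if centered:
        mu = decay * mu + (1 - decay) * g
        denom = np.sqrt(nu - mu ** 2 + eps)  # estimated variance of g
    else:
        denom = np.sqrt(nu + eps)            # uncentered second moment
    return params - lr * g / denom, (nu, mu)


zero_state = (0.0, 0.0)   # initial_scale = 0.0
x_plain, _ = rmsprop_step(0.0, 1.0, zero_state)
x_centered, _ = rmsprop_step(0.0, 1.0, zero_state, centered=True)
print(x_plain, x_centered)  # the centered step is larger here
```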

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

SGD

optax.sgd(learning_rate, momentum=None, nesterov=False, accumulator_dtype=None)

A canonical Stochastic Gradient Descent optimizer.

This implements stochastic gradient descent. It also includes support for momentum and Nesterov acceleration, as these are standard practice when using stochastic gradient descent to train deep neural networks.
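The momentum and Nesterov variants can be sketched with a single accumulator. This plain-Python sketch uses the convention trace_t = g_t + momentum * trace_{t-1}, with the Nesterov variant stepping along g_t + momentum * trace_t; that convention is an assumption here (Sutskever et al, 2013 write momentum in a slightly different form), and the helper name is hypothetical.

```python
def sgd_updates(grads, lr=0.1, momentum=0.9, nesterov=False):
    """Return the additive update applied at each step for a scalar parameter."""
    trace, updates = 0.0, []
    for g in grads:
        trace = g + momentum * trace                  # heavy-ball accumulator
        direction = g + momentum * trace if nesterov else trace
        updates.append(-lr * direction)
    return updates


print(sgd_updates([1.0, 1.0, 1.0]))
print(sgd_updates([1.0, 1.0, 1.0], nesterov=True))
```

With a constant gradient, the Nesterov variant is one step "ahead" of plain momentum: its update at step t equals the plain-momentum update at step t + 1.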

References

Sutskever et al, 2013: http://proceedings.mlr.press/v28/sutskever13.pdf

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • momentum (Optional[float]) – Decay rate used by the momentum term; when set to None, momentum is not used at all.

  • nesterov (bool) – Whether Nesterov momentum is used.

  • accumulator_dtype (Optional[Any]) – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.

Return type

optax.GradientTransformation

Returns

A GradientTransformation.

Yogi

optax.yogi(learning_rate, b1=0.9, b2=0.999, eps=0.001)

The Yogi optimizer.

Yogi is an adaptive optimizer, which provides control in tuning the effective learning rate to prevent it from increasing. By doing so, it focuses on addressing the issues of convergence and generalization in exponential moving average-based adaptive methods (such as Adam and RMSprop). Yogi is a modification of Adam and uses the same parameters.
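The difference from Adam lies in the second-moment update. Written in additive form, following Zaheer et al, 2018, Adam changes v by (1 - b2) * (g^2 - v), while Yogi changes it by (1 - b2) * sign(g^2 - v) * g^2. When gradients become small, Yogi's v therefore decays, and the effective learning rate grows, much more slowly. A plain-Python sketch with hypothetical helper names:

```python
def _sign(z):
    return (z > 0) - (z < 0)


def adam_second_moment(v, g, b2=0.999):
    return v - (1 - b2) * (v - g ** 2)      # i.e. b2 * v + (1 - b2) * g**2


def yogi_second_moment(v, g, b2=0.999):
    # The size of the change depends only on g**2, not on the current v.
    return v - (1 - b2) * _sign(v - g ** 2) * g ** 2


v = 1.0
print(adam_second_moment(v, 0.01), yogi_second_moment(v, 0.01))
```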

References

Zaheer et al, 2018: https://proceedings.neurips.cc/paper/2018/file/90365351ccc7437a1309dc64e4db32a3-Paper.pdf

Parameters
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – A small constant applied to the denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.
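To illustrate how Yogi limits growth of the effective learning rate, here is a NumPy sketch of the second-moment rule from Zaheer et al (2018) next to Adam's. The helper names adam_nu and yogi_nu are hypothetical and not part of Optax; optax.scale_by_yogi additionally uses eps_root and initial_accumulator_value, which are omitted here.

```python
import numpy as np


def adam_nu(nu, g, b2=0.999):
    """Adam's second-moment EMA, written as nu - (1 - b2) * (nu - g**2)."""
    return nu - (1.0 - b2) * (nu - g**2)


def yogi_nu(nu, g, b2=0.999):
    """Yogi's additive second-moment update: the step depends on sign(nu - g**2) * g**2."""
    g2 = g**2
    return nu - (1.0 - b2) * np.sign(nu - g2) * g2
```

With nu=1.0 and a vanishing gradient g=0.0, adam_nu shrinks nu to 0.999, so the effective step learning_rate / sqrt(nu) grows, whereas yogi_nu leaves nu at exactly 1.0. When the gradient is large relative to nu (for example nu=0.0, g=2.0), both rules give 0.004.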

Optax Transformations#

adaptive_grad_clip(clipping[, eps])

Clips updates to be at most clipping * parameter_norm, unit-wise.

add_decayed_weights([weight_decay, mask])

Add parameters scaled by weight_decay.

add_noise(eta, gamma, seed)

Add gradient noise.

AddDecayedWeightsState

alias of optax._src.base.EmptyState

AddNoiseState(count, rng_key)

State for adding gradient noise.

apply_every([k])

Accumulate gradients and apply them every k steps.

ApplyEvery(count, grad_acc)

Contains a counter and a gradient accumulator.

bias_correction(moment, decay, count)

Performs bias correction.

centralize()

Centralize gradients.

clip(max_delta)

Clips updates element-wise, to be in [-max_delta, +max_delta].

clip_by_block_rms(threshold)

Clips updates to a max rms for the gradient of each param vector or matrix.

clip_by_global_norm(max_norm)

Clips updates using their global norm.

ClipByGlobalNormState

alias of optax._src.base.EmptyState

ClipState

alias of optax._src.base.EmptyState

ema(decay[, debias, accumulator_dtype])

Compute an exponential moving average of past updates.

EmaState(count, ema)

Holds an exponential moving average of past updates.

EmptyState()

An empty state for the simplest stateless transformations.

FactoredState(count, v_row, v_col, v)

Overall state of the gradient transformation.

global_norm(updates)

Compute the global norm across a nested structure of tensors.

GradientTransformation(init, update)

A pair of pure functions implementing a gradient transformation.

GradientTransformationExtraArgs(init, update)

A specialization of GradientTransformation that supports extra args.

identity()

Stateless identity transformation that leaves input gradients untouched.

keep_params_nonnegative()

Modifies the updates to keep parameters non-negative, i.e. >= 0.

NonNegativeParamsState

alias of optax._src.base.EmptyState

OptState

alias of Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

Params

alias of Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

scale(step_size)

Scale updates by some fixed scalar step_size.

scale_by_adam([b1, b2, eps, eps_root, mu_dtype])

Rescale updates according to the Adam algorithm.

scale_by_adamax([b1, b2, eps])

Rescale updates according to the Adamax algorithm.

scale_by_amsgrad([b1, b2, eps, eps_root, ...])

Rescale updates according to the AMSGrad algorithm.

scale_by_belief([b1, b2, eps, eps_root])

Rescale updates according to the AdaBelief algorithm.

scale_by_factored_rms([factored, ...])

Scaling by a factored estimate of the gradient rms (as in Adafactor).

scale_by_lion([b1, b2, mu_dtype])

Rescale updates according to the Lion algorithm.

scale_by_novograd([b1, b2, eps, eps_root, ...])

Computes NovoGrad updates.

scale_by_optimistic_gradient([alpha, beta])

Compute generalized optimistic gradients.

scale_by_param_block_norm([min_scale])

Scale updates for each param block by the norm of that block's parameters.

scale_by_param_block_rms([min_scale])

Scale updates by rms of the gradient for each param vector or matrix.

scale_by_radam([b1, b2, eps, eps_root, ...])

Rescale updates according to the Rectified Adam algorithm.

scale_by_rms([decay, eps, initial_scale])

Rescale updates by the root of the exponential moving average of the square.

scale_by_rss([initial_accumulator_value, eps])

Rescale updates by the root of the sum of all squared gradients to date.

scale_by_schedule(step_size_fn)

Scale updates using a custom schedule for the step_size.

scale_by_sm3([b1, b2, eps])

Scale updates by SM3.

scale_by_stddev([decay, eps, initial_scale])

Rescale updates by the root of the centered exponential moving average of squares.

scale_by_trust_ratio([min_norm, ...])

Scale updates by trust ratio.

scale_by_yogi([b1, b2, eps, eps_root, ...])

Rescale updates according to the Yogi algorithm.

ScaleByAdamState(count, mu, nu)

State for the Adam algorithm.

ScaleByAmsgradState(count, mu, nu, nu_max)

State for the AMSGrad algorithm.

ScaleByLionState(count, mu)

State for the Lion algorithm.

ScaleByNovogradState(count, mu, nu)

State for Novograd.

ScaleByRmsState(nu)

State for exponential root mean-squared (RMS)-normalized updates.

ScaleByRssState(sum_of_squares)

State holding the sum of gradient squares to date.

ScaleByRStdDevState(mu, nu)

State for centered exponential moving average of squares of updates.

ScaleByScheduleState(count)

Maintains count for scale scheduling.

ScaleByTrustRatioState()

The scale and decay trust ratio transformation is stateless.

ScaleBySM3State(mu, nu)

State for the SM3 algorithm.

ScaleState

alias of optax._src.base.EmptyState

stateless(f)

Creates a stateless transformation from an update-like function.

stateless_with_tree_map(f)

Creates a stateless transformation from an update-like function for arrays.

set_to_zero()

Stateless transformation that maps input gradients to zero.

trace(decay[, nesterov, accumulator_dtype])

Compute a trace of past updates.

tree_map_params(initable, f, state, /, *rest)

Apply a callable over all params in the given optimizer state.

TraceState(trace)

Holds an aggregation of past updates.

TransformInitFn(*args, **kwargs)

A callable type for the init step of a GradientTransformation.

TransformUpdateFn(*args, **kwargs)

A callable type for the update step of a GradientTransformation.

update_infinity_moment(updates, moments, ...)

Compute the exponential moving average of the infinity norm.

update_moment(updates, moments, decay, order)

Compute the exponential moving average of the order-th moment.

update_moment_per_elem_norm(updates, ...)

Compute the EMA of the order-th moment of the element-wise norm.

Updates

alias of Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

zero_nans()

A transformation which replaces NaNs with 0.

ZeroNansState(found_nan)

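Contains a tree (found_nan) with the same structure as the parameters; each leaf records whether a NaN was detected in the corresponding array at the last call to update.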

with_extra_args_support(tx)

Wraps a gradient transformation, so that it ignores extra args.

Optax Types#

class optax.GradientTransformation(init: TransformInitFn, update: TransformUpdateFn)[source]#

A pair of pure functions implementing a gradient transformation.

Optax optimizers are all implemented as _gradient transformations_. A gradient transformation is defined to be a pair of pure functions, which are combined together in a NamedTuple so that they can be referred to by name.

Note that an extended API is provided for users wishing to build optimizers that take additional arguments during the update step. For more details, see GradientTransformationExtraArgs.

Since gradient transformations do not contain any internal state, all stateful optimizer properties (such as the current step count when using optimizer schedules, or momentum values) are passed through optax gradient transformations by using the optimizer _state_ pytree. Each time a gradient transformation is applied, a new state is computed and returned, ready to be passed to the next call to the gradient transformation.

Since gradient transformations are pure, idempotent functions, the only way to change the behaviour of a gradient transformation between steps is to change the values in the optimizer state. To see an example of mutating the optimizer state in order to control the behaviour of an optax gradient transformation, see the meta-learning example in the optax documentation.

init#

A pure function which, when called with an example instance of the parameters whose gradients will be transformed, returns a pytree containing the initial value for the optimizer state.

Type

TransformInitFn

update#

A pure function which takes as input a pytree of updates (with the same tree structure as the original params pytree passed to init), the previous optimizer state (which may have been initialized using the init function), and optionally the current params. The update function then returns the computed gradient updates, and a new optimizer state.

Type

TransformUpdateFn

init: TransformInitFn#

Alias for field number 0

update: TransformUpdateFn#

Alias for field number 1

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.TransformInitFn(*args, **kwargs)[source]#

A callable type for the init step of a GradientTransformation.

The init step takes a tree of params and uses these to construct an arbitrary structured initial state for the gradient transformation. This may hold statistics of the past updates or any other non-static information.

__call__(params)[source]#

The init function.

Parameters

params (Params) – The initial value of the parameters.

Return type

OptState

Returns

The initial state of the gradient transformation.

__init__(*args, **kwargs)[source]#
__subclasshook__()[source]#

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).

class optax.TransformUpdateFn(*args, **kwargs)[source]#

A callable type for the update step of a GradientTransformation.

The update step takes a tree of candidate parameter updates (e.g. their gradient with respect to some loss), an arbitrary structured state, and the current params of the model being optimised. The params argument is optional; however, it must be provided when using transformations that require access to the current values of the parameters.

For the case where additional arguments are required, an alternative interface may be used, see TransformUpdateExtraArgsFn for details.

__call__(updates, state, params=None)[source]#

The update function.

Parameters
  • updates (Updates) – A tree of candidate updates.

  • state (OptState) – The state of the gradient transformation.

  • params (Optional[Params]) – (Optionally) the current value of the parameters.

Return type

Tuple[Updates, OptState]

Returns

The transformed updates, and the updated state.

__init__(*args, **kwargs)[source]#
__subclasshook__()[source]#

Abstract classes can override this to customize issubclass().

This is invoked early on by abc.ABCMeta.__subclasscheck__(). It should return True, False or NotImplemented. If it returns NotImplemented, the normal algorithm is used. Otherwise, it overrides the normal algorithm (and the outcome is cached).

optax.OptState#

alias of Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

optax.Params#

alias of Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

optax.Updates#

alias of Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, Iterable[ArrayTree], Mapping[Any, ArrayTree]]

Optax Transforms and States#

optax.adaptive_grad_clip(clipping, eps=0.001)[source]#

Clips updates to be at most clipping * parameter_norm, unit-wise.

References

[Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image Recognition Without Normalization. (https://arxiv.org/abs/2102.06171)

Parameters
  • clipping (float) – The maximum allowed ratio of update norm to parameter norm.

  • eps (float) – An epsilon term to prevent clipping of zero-initialized params.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.AdaptiveGradClipState[source]#

alias of optax._src.base.EmptyState

optax.add_decayed_weights(weight_decay=0.0, mask=None)[source]#

Add parameters scaled by weight_decay.

Parameters
  • weight_decay (Union[float, jax.Array]) – A scalar weight decay rate.

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.add_noise(eta, gamma, seed)[source]#

Add gradient noise.

References

[Neelakantan et al, 2015](https://arxiv.org/abs/1511.06807)

Parameters
  • eta (float) – Base variance of the Gaussian noise added to the gradient.

  • gamma (float) – Decay exponent for annealing of the variance.

  • seed (int) – Seed for random number generation.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.AddDecayedWeightsState[source]#

alias of optax._src.base.EmptyState

class optax.AddNoiseState(count: chex.Array, rng_key: chex.PRNGKey)[source]#

State for adding gradient noise. Contains a count for annealing.

count: chex.Array#

Alias for field number 0

rng_key: chex.PRNGKey#

Alias for field number 1

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

optax.apply_every(k=1)[source]#

Accumulate gradients and apply them every k steps.

Note that if this transformation is part of a chain, the states of the other transformations will still be updated at every step. In particular, using apply_every with a batch size of N/2 and k=2 is not necessarily equivalent to not using apply_every with a batch size of N. If this equivalence is important for you, consider using optax.MultiSteps.

Parameters

k (int) – Emit non-zero gradients every k steps, otherwise accumulate them.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

class optax.ApplyEvery(count: chex.Array, grad_acc: base.Updates)[source]#

Contains a counter and a gradient accumulator.

count: chex.Array#

Alias for field number 0

grad_acc: base.Updates#

Alias for field number 1

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

optax.centralize()[source]#

Centralize gradients.

References

[Yong et al, 2020](https://arxiv.org/abs/2004.01461)

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.clip(max_delta)[source]#

Clips updates element-wise, to be in [-max_delta, +max_delta].

Parameters

max_delta (Union[Array, ndarray, bool_, number, float, int]) – The maximum absolute value for each element in the update.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.clip_by_block_rms(threshold)[source]#

Clips updates to a max rms for the gradient of each param vector or matrix.

Here, a block is a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.

Parameters

threshold (float) – The maximum rms for the gradient of each param vector or matrix.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.clip_by_global_norm(max_norm)[source]#

Clips updates using their global norm.

References

[Pascanu et al, 2012](https://arxiv.org/abs/1211.5063)

Parameters

max_norm (float) – The maximum global norm for an update.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.ClipByGlobalNormState[source]#

alias of optax._src.base.EmptyState

optax.ClipState[source]#

alias of optax._src.base.EmptyState

optax.ema(decay, debias=True, accumulator_dtype=None)[source]#

Compute an exponential moving average of past updates.

Note: trace and ema have very similar but distinct updates; ema = decay * ema + (1-decay) * t, while trace = decay * trace + t. Both are frequently found in the optimization literature.

Parameters
  • decay (float) – Decay rate for the exponential moving average.

  • debias (bool) – Whether to debias the transformed gradient.

  • accumulator_dtype (Optional[Any, None]) – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

class optax.EmaState(count: chex.Array, ema: base.Params)[source]#

Holds an exponential moving average of past updates.

count: chex.Array#

Alias for field number 0

ema: base.Params#

Alias for field number 1

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.EmptyState[source]#

An empty state for the simplest stateless transformations.

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

static __new__(_cls)#

Create new instance of EmptyState()

class optax.FactoredState(count: chex.Array, v_row: chex.ArrayTree, v_col: chex.ArrayTree, v: chex.ArrayTree)[source]#

Overall state of the gradient transformation.

count: chex.Array#

Alias for field number 0

v_row: chex.ArrayTree#

Alias for field number 1

v_col: chex.ArrayTree#

Alias for field number 2

v: chex.ArrayTree#

Alias for field number 3

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

optax.global_norm(updates)[source]#

Compute the global norm across a nested structure of tensors.

Return type

Union[Array, ndarray, bool_, number]

optax.identity()[source]#

Stateless identity transformation that leaves input gradients untouched.

This function passes through the gradient updates unchanged.

Note that this should not be confused with set_to_zero, which maps the input updates to zero; that is the transformation required for the model parameters to remain unchanged when the updates are applied to them.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.keep_params_nonnegative()[source]#

Modifies the updates to keep parameters non-negative, i.e. >= 0.

This transformation ensures that parameters after the update will be larger than or equal to zero. In a chain of transformations, this should be the last one.

WARNING: the transformation expects input params to be non-negative. When params are negative, the transformed update will move them to 0.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.NonNegativeParamsState[source]#

alias of optax._src.base.EmptyState

optax.scale(step_size)[source]#

Scale updates by some fixed scalar step_size.

Parameters

step_size (float) – A scalar corresponding to a fixed scaling factor for updates.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)[source]#

Rescale updates according to the Adam algorithm.

References

[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

Parameters
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • mu_dtype (Optional[_ScalarMeta, None]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_adamax(b1=0.9, b2=0.999, eps=1e-08)[source]#

Rescale updates according to the Adamax algorithm.

References

[Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

Parameters
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted maximum of grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_amsgrad(b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, mu_dtype=None)[source]#

Rescale updates according to the AMSGrad algorithm.

References

[Reddi et al, 2018](https://openreview.net/forum?id=ryQu7f-RZ)

Parameters
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • mu_dtype (Optional[_ScalarMeta, None]) – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_belief(b1=0.9, b2=0.999, eps=1e-16, eps_root=1e-16)[source]#

Rescale updates according to the AdaBelief algorithm.

References

[Zhuang et al, 2020](https://arxiv.org/abs/2010.07468)

Parameters
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of variance of grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_factored_rms(factored=True, decay_rate=0.8, step_offset=0, min_dim_size_to_factor=128, epsilon=1e-30, decay_rate_fn=<function _decay_rate_pow>)[source]#

Scaling by a factored estimate of the gradient rms (as in Adafactor).

This is a so-called “1+epsilon” scaling algorithm that is extremely memory-efficient compared to RMSProp/Adam, and it has had wide success when applied to large-scale training of attention-based models.

References

[Shazeer et al, 2018](https://arxiv.org/abs/1804.04235)

Parameters
  • factored (bool) – Whether to use factored second-moment estimates.

  • decay_rate (float) – Controls the second-moment exponential decay schedule.

  • step_offset (int) – For fine-tuning, one may set this to the starting step number of the fine-tuning phase.

  • min_dim_size_to_factor (int) – Only factor the accumulator if two array dimensions are at least this size.

  • epsilon (float) – Regularization constant for squared gradient.

  • decay_rate_fn (Callable[[int, float], Union[Array, ndarray, bool_, number]]) – A function that accepts the current step and the decay rate parameter, and controls the schedule for the second momentum. Defaults to the original Adafactor power decay schedule. One potential shortcoming of the original schedule is that the second momentum converges to 1, which effectively freezes it. To prevent this, the user can opt for a custom schedule that sets an upper bound for the second momentum, as in [Zhai et al., 2021](https://arxiv.org/abs/2106.04560).

Returns

The corresponding GradientTransformation.

optax.scale_by_lion(b1=0.9, b2=0.99, mu_dtype=None)[source]#

Rescale updates according to the Lion algorithm.

References

[Chen et al, 2023](https://arxiv.org/abs/2302.06675)

Parameters
  • b1 (float) – Rate for combining the momentum and the current grad.

  • b2 (float) – Decay rate for the exponentially weighted average of grads.

  • mu_dtype (Optional[_ScalarMeta, None]) – Optional dtype to be used for the momentum; if None then the dtype is inferred from params and updates.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_novograd(b1=0.9, b2=0.25, eps=1e-08, eps_root=0.0, weight_decay=0.0, mu_dtype=None)[source]#

Computes NovoGrad updates.

References

[Ginsburg et al, 2019](https://arxiv.org/abs/1905.11286)

Parameters
  • b1 (float) – A decay rate for the exponentially weighted average of grads.

  • b2 (float) – A decay rate for the exponentially weighted average of squared grads.

  • eps (float) – A term added to the denominator to improve numerical stability.

  • eps_root (float) – A term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • weight_decay (float) – A scalar weight decay rate.

  • mu_dtype (Optional[_ScalarMeta, None]) – An optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Return type

optax.GradientTransformation

Returns

The corresponding GradientTransformation.

optax.scale_by_param_block_norm(min_scale=0.001)[source]#

Scale updates for each param block by the norm of that block’s parameters.

Here, a block is a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.

Parameters

min_scale (float) – Minimum scaling factor.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_param_block_rms(min_scale=0.001)[source]#

Scale updates by rms of the gradient for each param vector or matrix.

Here, a block is a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree.

Parameters

min_scale (float) – Minimum scaling factor.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_radam(b1=0.9, b2=0.999, eps=1e-08, eps_root=0.0, threshold=5.0)[source]#

Rescale updates according to the Rectified Adam algorithm.

References

[Liu et al, 2020](https://arxiv.org/abs/1908.03265)

Parameters
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • threshold (float) – Threshold for variance tractability.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_rms(decay=0.9, eps=1e-08, initial_scale=0.0)[source]#

Rescale updates by the root of the exponential moving average of the square.

References

[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)

Parameters
  • decay (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • initial_scale (float) – Initial value for second moment.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_rss(initial_accumulator_value=0.1, eps=1e-07)[source]#

Rescale updates by the root of the sum of all squared gradients to date.

References

[Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) [McMahan et al., 2010](https://arxiv.org/abs/1002.4908)

Parameters
  • initial_accumulator_value (float) – Starting value for accumulators, must be >= 0.

  • eps (float) – A small floating point value to avoid zero denominator.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_schedule(step_size_fn)[source]#

Scale updates using a custom schedule for the step_size.

Parameters

step_size_fn (Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]) – A function that takes an update count as input and proposes the step_size to multiply the updates by.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_sm3(b1=0.9, b2=1.0, eps=1e-08)[source]#

Scale updates by SM3.

References

[Anil et al, 2019](https://arxiv.org/abs/1901.11150)

Parameters
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_stddev(decay=0.9, eps=1e-08, initial_scale=0.0)[source]#

Rescale updates by the root of the centered exp. moving average of squares.

References

[Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)

Parameters
  • decay (float) – Decay rate for the exponentially weighted average of squared grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • initial_scale (float) – Initial value for second moment.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_trust_ratio(min_norm=0.0, trust_coefficient=1.0, eps=0.0)[source]#

Scale updates by trust ratio.

References

[You et al, 2020](https://arxiv.org/abs/1904.00962)

Parameters
  • min_norm (float) – Minimum norm for the params and gradient norms; defaults to zero.

  • trust_coefficient (float) – A multiplier for the trust ratio.

  • eps (float) – Additive constant added to the denominator for numerical stability.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.scale_by_yogi(b1=0.9, b2=0.999, eps=0.001, eps_root=0.0, initial_accumulator_value=1e-06)[source]#

Rescale updates according to the Yogi algorithm.

Supports complex numbers; see https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29

References

[Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html)

Parameters
  • b1 (float) – Decay rate for the exponentially weighted average of grads.

  • b2 (float) – Decay rate for the exponentially weighted average of variance of grads.

  • eps (float) – Term added to the denominator to improve numerical stability.

  • eps_root (float) – Term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling.

  • initial_accumulator_value (float) – The starting value for accumulators. Only positive values are allowed.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

class optax.ScaleByAdamState(count: chex.Array, mu: base.Updates, nu: base.Updates)[source]#

State for the Adam algorithm.

count: chex.Array#

Alias for field number 0

mu: base.Updates#

Alias for field number 1

nu: base.Updates#

Alias for field number 2

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.ScaleByAmsgradState(count: chex.Array, mu: base.Updates, nu: base.Updates, nu_max: base.Updates)[source]#

State for the AMSGrad algorithm.

count: chex.Array#

Alias for field number 0

mu: base.Updates#

Alias for field number 1

nu: base.Updates#

Alias for field number 2

nu_max: base.Updates#

Alias for field number 3

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.ScaleByLionState(count: chex.Array, mu: base.Updates)[source]#

State for the Lion algorithm.

count: chex.Array#

Alias for field number 0

mu: base.Updates#

Alias for field number 1

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.ScaleByNovogradState(count: chex.Array, mu: base.Updates, nu: base.Updates)[source]#

State for Novograd.

count: chex.Array#

Alias for field number 0

mu: base.Updates#

Alias for field number 1

nu: base.Updates#

Alias for field number 2

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.ScaleByRmsState(nu: base.Updates)[source]#

State for exponential root mean-squared (RMS)-normalized updates.

nu: base.Updates#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.ScaleByRssState(sum_of_squares: base.Updates)[source]#

State holding the sum of gradient squares to date.

sum_of_squares: base.Updates#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.ScaleByRStdDevState(mu: base.Updates, nu: base.Updates)[source]#

State for centered exponential moving average of squares of updates.

mu: base.Updates#

Alias for field number 0

nu: base.Updates#

Alias for field number 1

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.ScaleByScheduleState(count: chex.Array)[source]#

Maintains count for scale scheduling.

count: chex.Array#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.ScaleBySM3State(mu: base.Updates, nu: base.Updates)[source]#

State for the SM3 algorithm.

mu: base.Updates#

Alias for field number 0

nu: base.Updates#

Alias for field number 1

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.ScaleByTrustRatioState[source]#

The scale and decay trust ratio transformation is stateless.

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

static __new__(_cls)#

Create new instance of ScaleByTrustRatioState()

optax.ScaleState[source]#

alias of optax._src.base.EmptyState

optax.set_to_zero()[source]#

Stateless transformation that maps input gradients to zero.

The resulting update function, when called, will return a tree of zeros matching the shape of the input gradients. This means that when the updates returned from this transformation are applied to the model parameters, the model parameters will remain unchanged.

This can be used in combination with multi_transform or masked to freeze (i.e. keep fixed) some parts of the tree of model parameters while applying gradient updates to other parts of the tree.

When updates are set to zero inside the same jit-compiled function as the calculation of gradients, optax transformations, and application of updates to parameters, unnecessary computations will in general be dropped.
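A framework-free sketch of the freezing pattern described above. It uses nested dicts of floats as stand-ins for parameter pytrees. The helpers `tree_map`, `set_to_zero_update`, `sgd_update` and `add` are hypothetical illustrations, not optax functions. In optax you would instead combine `optax.set_to_zero()` with `multi_transform` or `masked`.

```python
def tree_map(fn, tree):
    """Apply fn to every leaf of a nested dict (stand-in for jax.tree_util.tree_map)."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)


def set_to_zero_update(grads):
    # Mirrors the documented behaviour: every update leaf becomes zero.
    return tree_map(lambda g: 0.0 * g, grads)


def sgd_update(grads, lr=0.1):
    return tree_map(lambda g: -lr * g, grads)


def add(params, updates):
    """Leaf-wise params + updates for two nested dicts with the same structure."""
    if isinstance(params, dict):
        return {k: add(params[k], updates[k]) for k in params}
    return params + updates


params = {"encoder": {"w": 1.0}, "head": {"w": 2.0}}
grads = {"encoder": {"w": 0.5}, "head": {"w": 0.5}}

# Freeze the encoder and train the head, as a multi_transform with a
# set_to_zero branch for the encoder would.
updates = {"encoder": set_to_zero_update(grads["encoder"]),
           "head": sgd_update(grads["head"])}
new_params = add(params, updates)
```

The frozen subtree receives all-zero updates, so adding them leaves those parameters unchanged.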

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

optax.stateless(f)[source]#

Creates a stateless transformation from an update-like function.

This wrapper eliminates the boilerplate needed to create a transformation that does not require saved state between iterations.

Parameters

f (Callable[[Updates, Optional[Params]], Updates]) – Update function that takes in updates (e.g. gradients) and parameters and returns updates. The parameters may be None.

Return type

GradientTransformation

Returns

An optax.GradientTransformation.

optax.stateless_with_tree_map(f)[source]#

Creates a stateless transformation from an update-like function for arrays.

This wrapper eliminates the boilerplate needed to create a transformation that does not require saved state between iterations, just like optax.stateless. In addition, it applies f leaf by leaf, using tree_map over the updates and params for you.

Parameters

f (Callable[[Union[Array, ndarray, bool_, number], Union[Array, ndarray, bool_, number, None]], Union[Array, ndarray, bool_, number]]) – Update function that takes in an update array (e.g. gradients) and parameter array and returns an update array. The parameter array may be None.

Return type

optax.GradientTransformation

Returns

An optax.GradientTransformation.
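The following is a minimal plain-Python sketch of what "stateless, applied per leaf" means. It is an illustration under stated assumptions, not the optax implementation. `Transformation`, `tree_map2` and `stateless_with_tree_map_sketch` are hypothetical names, and nested dicts of floats stand in for pytrees of arrays.

```python
from typing import Callable, NamedTuple


class Transformation(NamedTuple):
    """Hypothetical stand-in for an (init, update) gradient transformation."""
    init: Callable
    update: Callable


def tree_map2(fn, updates, params):
    """Apply fn(update_leaf, param_leaf) over nested dicts; params may be None."""
    if isinstance(updates, dict):
        return {k: tree_map2(fn, updates[k], None if params is None else params[k])
                for k in updates}
    return fn(updates, params)


def stateless_with_tree_map_sketch(f):
    def init(params):
        del params
        return ()  # no state is kept between iterations

    def update(updates, state, params=None):
        return tree_map2(f, updates, params), state

    return Transformation(init, update)


# Hypothetical per-leaf rule: add weight decay of 1e-2 * p to each gradient.
decay = stateless_with_tree_map_sketch(lambda g, p: g + 1e-2 * p)

params = {"w": 2.0, "bias": {"b": 1.0}}
grads = {"w": 0.5, "bias": {"b": 0.0}}
state = decay.init(params)
updates, state = decay.update(grads, state, params)
```

Because f receives None for the parameter leaf when params are omitted, a rule like the one above only works when params are passed to update.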

optax.trace(decay, nesterov=False, accumulator_dtype=None)[source]#

Compute a trace of past updates.

Note: trace and ema have very similar but distinct updates; trace = decay * trace + t, while ema = decay * ema + (1-decay) * t. Both are frequently found in the optimization literature.
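To make the difference between the two recurrences concrete, here is a scalar plain-Python sketch of exactly the two formulas in the note. It ignores nesterov and accumulator_dtype.

```python
def run(decay, inputs):
    """Return the final trace and EMA for a stream of scalar updates t."""
    trace, ema = 0.0, 0.0
    for t in inputs:
        trace = decay * trace + t            # trace: no (1 - decay) factor
        ema = decay * ema + (1 - decay) * t  # ema: convex combination
    return trace, ema


trace, ema = run(0.9, [1.0] * 200)
# With a constant input of 1, the trace approaches 1 / (1 - decay) = 10,
# while the EMA approaches the input value itself, 1.
```

The trace therefore rescales the effective step size by roughly 1 / (1 - decay), which is why momentum-style optimizers built on it often use a smaller learning rate than EMA-based ones.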

Parameters
  • decay (float) – Decay rate for the trace of past updates.

  • nesterov (bool) – Whether to use Nesterov momentum.

  • accumulator_dtype (Optional[Any, None]) – Optional dtype to be used for the accumulator; if None then the dtype is inferred from params and updates.

Return type

optax.GradientTransformation

Returns

A GradientTransformation object.

class optax.TraceState(trace: base.Params)[source]#

Holds an aggregation of past updates.

trace: base.Params#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

optax.zero_nans()[source]#

A transformation which replaces NaNs with 0.

Zeroing values in gradients is guaranteed to produce a direction of non-increasing loss.

The state of the transformation has the same tree structure as that of the parameters. Each leaf is a single boolean which contains True iff a NaN was detected in the corresponding parameter array at the last call to update. This state is not used by the transformation internally, but lets users be aware when NaNs have been zeroed out.
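A framework-free sketch of the behaviour described above, using a flat dict of lists in place of a pytree of arrays. `zero_nans_sketch` is a hypothetical name. It returns the cleaned updates together with a per-leaf boolean tree, analogous to ZeroNansState.found_nan.

```python
import math


def zero_nans_sketch(updates):
    """Replace NaNs by 0 and flag, per leaf, whether any NaN was found."""
    new_updates, found_nan = {}, {}
    for name, leaf in updates.items():
        found_nan[name] = any(math.isnan(x) for x in leaf)
        new_updates[name] = [0.0 if math.isnan(x) else x for x in leaf]
    return new_updates, found_nan


grads = {"w": [1.0, float("nan"), -2.0], "b": [0.5]}
clean, found = zero_nans_sketch(grads)
```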

Return type

optax.GradientTransformation

Returns

A GradientTransformation.

class optax.ZeroNansState(found_nan: Any)[source]#

Contains a tree of NaN indicators.

The entry found_nan has the same tree structure as that of the parameters. Each leaf is a single boolean which contains True iff a NaN was detected in the corresponding parameter array at the last call to update.

found_nan: Any#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Apply Updates#

apply_updates(params, updates)

Applies an update to the corresponding parameters.

incremental_update(new_tensors, old_tensors, ...)

Incrementally update parameters via polyak averaging.

periodic_update(new_tensors, old_tensors, ...)

Periodically update all parameters with new values.

apply_updates#

optax.apply_updates(params, updates)[source]#

Applies an update to the corresponding parameters.

This is a utility function that applies an update to a set of parameters and returns the updated parameters to the caller. As an example, the update may be a gradient transformed by a sequence of `GradientTransformations`. This function is exposed for convenience, but it just adds updates and parameters; you may also apply updates to parameters manually, using tree_map (e.g. if you want to manipulate updates in custom ways before applying them).

Parameters
  • params (optax.Params) – a tree of parameters.

  • updates (optax.Updates) – a tree of updates; the tree structure and the shape of the leaf nodes must match that of params.

Return type

optax.Params

Returns

Updated parameters, with same structure, shape and type as params.
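Since the function "just adds updates and parameters", the manual alternative mentioned above can be sketched in plain Python. Nested dicts of floats stand in for pytrees, and `apply_updates_sketch` is a hypothetical name. The sketch does not model any dtype handling the real function may perform.

```python
def apply_updates_sketch(params, updates):
    """Leaf-wise params + updates for nested dicts whose leaves are floats."""
    if isinstance(params, dict):
        if params.keys() != updates.keys():
            raise ValueError("params and updates must have the same tree structure")
        return {k: apply_updates_sketch(params[k], updates[k]) for k in params}
    return params + updates


params = {"linear": {"w": 1.0, "b": 0.0}}
updates = {"linear": {"w": -0.1, "b": 0.2}}
new_params = apply_updates_sketch(params, updates)
```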

incremental_update#

optax.incremental_update(new_tensors, old_tensors, step_size)[source]#

Incrementally update parameters via polyak averaging.

Polyak averaging tracks an (exponential moving) average of the past parameters of a model, for use at test/evaluation time.

References

[Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046)

Parameters
  • new_tensors (optax.Params) – the latest value of the tensors.

  • old_tensors (optax.Params) – a moving average of the values of the tensors.

  • step_size (chex.Numeric) – the step_size used to update the polyak average on each step.

Return type

optax.Params

Returns

an updated moving average step_size*new+(1-step_size)*old of the params.
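The returned value is the documented formula step_size*new + (1-step_size)*old. A plain-Python sketch over a flat dict follows; `incremental_update_sketch` is a hypothetical name. Calling it once per training step makes the target track an exponential moving average of the online parameters.

```python
def incremental_update_sketch(new, old, step_size):
    """Polyak/EMA update: step_size * new + (1 - step_size) * old, leaf-wise."""
    return {k: step_size * new[k] + (1.0 - step_size) * old[k] for k in new}


online = {"w": 1.0}
target = {"w": 0.0}
for _ in range(3):
    target = incremental_update_sketch(online, target, step_size=0.5)
# target["w"] goes 0.5 -> 0.75 -> 0.875, approaching online["w"] = 1.0.
```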

periodic_update#

optax.periodic_update(new_tensors, old_tensors, steps, update_period)[source]#

Periodically update all parameters with new values.

A slow copy of a model’s parameters, updated every K actual updates, can be used to implement forms of self-supervision (in supervised learning), or to stabilise temporal difference learning updates (in reinforcement learning).

References

[Grill et al., 2020](https://arxiv.org/abs/2006.07733) [Mnih et al., 2015](https://arxiv.org/abs/1312.5602)

Parameters
  • new_tensors (optax.Params) – the latest value of the tensors.

  • old_tensors (optax.Params) – a slow copy of the model’s parameters.

  • steps (chex.Array) – number of update steps on the “online” network.

  • update_period (int) – how often, in steps, to update the “target” network.

Return type

optax.Params

Returns

a slow copy of the model’s parameters, updated every update_period steps.
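A plain-Python sketch of the periodic copy follows. It assumes the copy is taken whenever steps is a multiple of update_period; the exact counting convention is an assumption, not stated above. `periodic_update_sketch` is a hypothetical name.

```python
def periodic_update_sketch(new, old, steps, update_period):
    """Return a copy of new when steps is a multiple of update_period, else old."""
    return dict(new) if steps % update_period == 0 else old


online = {"w": 0.0}
target = {"w": 0.0}
history = []
for step in range(1, 7):
    online = {"w": online["w"] + 1.0}  # pretend one optimizer step on the online net
    target = periodic_update_sketch(online, target, step, update_period=3)
    history.append(target["w"])
# The target only jumps to the online value at steps 3 and 6.
```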

Combining Optimizers#

chain(*args)

Applies a list of chainable update transformations.

multi_transform(transforms, param_labels)

Partitions params and applies a different transformation to each subset.

chain#

optax.chain(*args)[source]#

Applies a list of chainable update transformations.

Given a sequence of chainable transforms, chain returns an init_fn that constructs a state by concatenating the states of the individual transforms, and returns an update_fn which chains the update transformations feeding the appropriate state to each.

Parameters

*args – a sequence of chainable (init_fn, update_fn) tuples.

Return type

optax.GradientTransformationExtraArgs

Returns

A GradientTransformationExtraArgs, created by chaining the input transformations. Note that independent of the argument types, the resulting transformation always supports extra args. Any extra arguments passed to the returned transformation will be passed only to those transformations in the chain that support extra args.
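The state-concatenation and update-threading logic described above can be sketched in plain Python. `Transform`, `chain_sketch`, `clip` and `scale` are hypothetical stand-ins, not optax objects, and the params argument of update is omitted for brevity.

```python
from typing import Any, Callable, NamedTuple, Tuple


class Transform(NamedTuple):
    """Hypothetical (init, update) pair; update omits the params argument."""
    init: Callable[[Any], Any]
    update: Callable[[Any, Any], Tuple[Any, Any]]


def chain_sketch(*transforms):
    def init(params):
        # The combined state is the tuple of the individual states.
        return tuple(t.init(params) for t in transforms)

    def update(updates, state):
        new_states = []
        for t, s in zip(transforms, state):
            updates, s = t.update(updates, s)  # each transform gets its own state
            new_states.append(s)
        return updates, tuple(new_states)

    return Transform(init, update)


def clip(max_abs):
    return Transform(lambda params: (),
                     lambda u, s: ([max(-max_abs, min(max_abs, x)) for x in u], s))


def scale(factor):
    return Transform(lambda params: (), lambda u, s: ([factor * x for x in u], s))


tx = chain_sketch(clip(1.0), scale(-0.1))
state = tx.init([0.0, 0.0])
updates, state = tx.update([5.0, -0.5], state)
```

Order matters: here the updates are clipped first and then scaled, so [5.0, -0.5] becomes [1.0, -0.5] and then [-0.1, 0.05].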

Multi Transform#

optax.multi_transform(transforms, param_labels)[source]#

Partitions params and applies a different transformation to each subset.

Below is an example where we apply Adam to the weights and SGD to the biases of a 2-layer neural network:

import optax
import jax
import jax.numpy as jnp

def map_nested_fn(fn):
  '''Recursively apply `fn` to the key-value pairs of a nested dict'''
  def map_fn(nested_dict):
    return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v))
            for k, v in nested_dict.items()}
  return map_fn

params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)},
          'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}}
gradients = jax.tree_util.tree_map(jnp.ones_like, params)  # dummy gradients

label_fn = map_nested_fn(lambda k, _: k)
tx = optax.multi_transform({'w': optax.adam(1.0), 'b': optax.sgd(1.0)},
                           label_fn)
state = tx.init(params)
updates, new_state = tx.update(gradients, state, params)
new_params = optax.apply_updates(params, updates)

Instead of providing a label_fn, you may provide a PyTree of labels directly. Also, this PyTree may be a prefix of the parameters PyTree. This is demonstrated in the GAN pseudocode below:

generator_params = ...
discriminator_params = ...
all_params = (generator_params, discriminator_params)
param_labels = ('generator', 'discriminator')

tx = optax.multi_transform(
    {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)},
    param_labels)

If you would like to leave some parameters unoptimized, you may wrap optax.multi_transform with optax.masked().

Parameters
  • transforms (Mapping[Hashable, optax.GradientTransformation]) – A mapping from labels to transformations. Each transformation will only be applied to parameters with the same label.

  • param_labels (Union[Any, Callable[[Any], Any]]) – A PyTree that is the same shape or a prefix of the parameters/updates (or a function that returns one given the parameters as input). The leaves of this PyTree correspond to the keys of the transforms (therefore the values at the leaves must be a subset of the keys).

Return type

optax.GradientTransformationExtraArgs

Returns

An optax.GradientTransformation.

class optax.MultiTransformState(inner_states)[source]#
inner_states: Mapping[Hashable, base.OptState]#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Optimizer Wrappers#

apply_if_finite(inner, max_consecutive_errors)

A function that wraps an optimizer to make it robust to a few NaNs or Infs.

ApplyIfFiniteState(notfinite_count, ...)

State of the GradientTransformation returned by apply_if_finite.

flatten(inner)

Flattens parameters and gradients for init and update of inner transform.

lookahead(fast_optimizer, sync_period, ...)

Lookahead optimizer.

LookaheadParams(fast, slow)

Holds a pair of slow and fast parameters for the lookahead optimizer.

LookaheadState(fast_state, steps_since_sync)

State of the GradientTransformation returned by lookahead.

masked(inner, mask)

Mask updates so only some are transformed, the rest are passed through.

MaskedState(inner_state)

Maintains inner transform state for masked transformations.

maybe_update(inner, should_update_fn)

Calls the inner update function only at certain steps.

MaybeUpdateState(inner_state, step)

Maintains inner transform state and adds a step counter.

MultiSteps(opt, every_k_schedule[, ...])

An optimizer wrapper to accumulate gradients over multiple steps.

MultiStepsState(mini_step, gradient_step, ...)

State of the GradientTransformation returned by MultiSteps.

ShouldSkipUpdateFunction(*args, **kwargs)

skip_large_updates(updates, gradient_step, ...)

Returns True if the global norm square of updates is too large, i.e. the update should be skipped.

skip_not_finite(updates, gradient_step, params)

Returns True iff any of the updates contains an inf or a NaN.

Apply if Finite#

optax.apply_if_finite(inner, max_consecutive_errors)[source]#

A function that wraps an optimizer to make it robust to a few NaNs or Infs.

The purpose of this function is to prevent any optimization from happening if the gradients contain NaNs or Infs. That is, when a NaN or Inf is detected in the gradients, the wrapped optimizer ignores that gradient update. If the NaNs or Infs persist after a given number of updates, the wrapped optimizer gives up and accepts the update.

Parameters
  • inner (optax.GradientTransformation) – Inner transformation to be wrapped.

  • max_consecutive_errors (int) – Maximum number of consecutive gradient updates containing NaNs or Infs that the wrapped optimizer will ignore. After that many ignored updates, the optimizer will give up and accept.

Return type

optax.GradientTransformation

Returns

New GradientTransformationExtraArgs.
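A plain-Python sketch of the accept/reject bookkeeping follows. It assumes that a non-finite update is accepted once the consecutive count exceeds max_consecutive_errors, and that rejected updates are replaced by zeros. The inner optimizer and its state are not modeled. `FiniteState` and `apply_if_finite_step` are hypothetical names.

```python
import math
from typing import NamedTuple


class FiniteState(NamedTuple):
    """Hypothetical mirror of the counters in ApplyIfFiniteState."""
    notfinite_count: int
    last_finite: bool
    total_notfinite: int


def apply_if_finite_step(grads, state, max_consecutive_errors):
    """Return (updates, new_state); rejected updates are replaced by zeros."""
    finite = all(math.isfinite(g) for g in grads)
    count = 0 if finite else state.notfinite_count + 1
    new_state = FiniteState(count, finite, state.total_notfinite + (0 if finite else 1))
    # Accept finite updates, or give up once the bad streak exceeds the limit.
    if finite or count > max_consecutive_errors:
        return grads, new_state
    return [0.0 for _ in grads], new_state


state = FiniteState(0, True, 0)
accepted = []
for g in ([float("nan")], [float("inf")], [float("nan")], [1.0]):
    updates, state = apply_if_finite_step(g, state, max_consecutive_errors=2)
    accepted.append(updates is g)
```

With max_consecutive_errors=2, the first two bad updates are ignored, the third is accepted, and the counter resets on the next finite update while total_notfinite keeps the running total.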

class optax.ApplyIfFiniteState(notfinite_count: Any, last_finite: Any, total_notfinite: Any, inner_state: Any)[source]#

State of the GradientTransformation returned by apply_if_finite.

Fields:

notfinite_count: Number of consecutive gradient updates containing an Inf or a NaN. This number is reset to 0 whenever a gradient update without an Inf or a NaN is done.

last_finite: Whether the last gradient update was finite, i.e. contained no Inf or NaN.

total_notfinite: Total number of gradient updates containing an Inf or a NaN since this optimizer was initialised. This number is never reset.

inner_state: The state of the inner GradientTransformation.

notfinite_count: Any#

Alias for field number 0

last_finite: Any#

Alias for field number 1

total_notfinite: Any#

Alias for field number 2

inner_state: Any#

Alias for field number 3

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

flatten#

optax.flatten(inner)[source]#

Flattens parameters and gradients for init and update of inner transform.

This can reduce the overhead of performing many calculations on lots of small variables, at the cost of slightly increased memory usage.
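The bookkeeping behind flattening can be sketched in plain Python: concatenate all leaves into one vector, run the inner rule once, then split the result back into the original structure. This illustrates only the ravel/unravel round trip. The speed benefit on accelerators comes from replacing many small array operations with a few large ones, which plain Python does not reproduce. `ravel`, `unravel` and `flattened` are hypothetical names.

```python
def ravel(tree):
    """Concatenate a dict of lists into one list, remembering how to split it."""
    keys = sorted(tree)
    sizes = [len(tree[k]) for k in keys]
    flat = [x for k in keys for x in tree[k]]
    return flat, (keys, sizes)


def unravel(flat, spec):
    keys, sizes = spec
    out, start = {}, 0
    for k, n in zip(keys, sizes):
        out[k] = flat[start:start + n]
        start += n
    return out


def flattened(inner_update):
    """Wrap a rule acting on one flat vector so it accepts a dict of lists."""
    def update(grads):
        flat, spec = ravel(grads)
        return unravel(inner_update(flat), spec)
    return update


sgd = flattened(lambda g: [-0.1 * x for x in g])  # hypothetical inner rule
updates = sgd({"w": [2.0, 3.0], "b": [1.0]})
```

The temporary flat copy is where the "slightly increased memory usage" comes from.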

Parameters

inner (optax.GradientTransformation) – Inner transformation to flatten inputs for.

Return type

optax.GradientTransformationExtraArgs

Returns

New GradientTransformationExtraArgs

Lookahead#

optax.lookahead(fast_optimizer, sync_period, slow_step_size, reset_state=False)[source]#

Lookahead optimizer.

Performs steps with a fast optimizer and periodically updates a set of slow parameters. Optionally resets the fast optimizer state after synchronization by calling the init function of the fast optimizer.

Updates returned by the lookahead optimizer should not be modified before they are applied, otherwise fast and slow parameters are not synchronized correctly.

References

[Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

Parameters
  • fast_optimizer (optax.GradientTransformation) – The optimizer to use in the inner loop of lookahead.

  • sync_period (int) – Number of fast optimizer steps to take before synchronizing parameters. Must be >= 1.

  • slow_step_size (float) – Step size of the slow parameter updates.

  • reset_state (bool) – Whether to reset the optimizer state of the fast optimizer after each synchronization.

Return type

optax.GradientTransformation

Returns

A GradientTransformation with init and update functions. The updates passed to the update function should be calculated using the fast lookahead parameters only.
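The fast/slow interplay can be sketched in plain Python around SGD on a scalar parameter, following the scheme of Zhang et al. (2019). The sketch assumes that after each synchronization the fast weights restart from the new slow weights. reset_state is not modeled because plain SGD has no state. `lookahead_sketch` is a hypothetical name, not the optax API, which instead works on a LookaheadParams pair.

```python
def lookahead_sketch(grad_fn, x0, lr, sync_period, slow_step_size, num_steps):
    """Lookahead around plain SGD on a scalar parameter."""
    fast = slow = x0
    for step in range(1, num_steps + 1):
        fast = fast - lr * grad_fn(fast)  # one fast-optimizer step
        if step % sync_period == 0:
            # Move the slow weights part of the way towards the fast weights,
            # then restart the fast weights from the new slow weights.
            slow = slow + slow_step_size * (fast - slow)
            fast = slow
    return fast, slow


# Minimise f(x) = (x - 3)^2, whose gradient is 2 * (x - 3).
fast, slow = lookahead_sketch(lambda x: 2.0 * (x - 3.0), x0=0.0, lr=0.1,
                              sync_period=5, slow_step_size=0.5, num_steps=200)
```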

class optax.LookaheadParams(fast: base.Params, slow: base.Params)[source]#

Holds a pair of slow and fast parameters for the lookahead optimizer.

Gradients should always be calculated with the fast parameters. The slow parameters should be used for testing and inference as they generalize better. See the reference for a detailed discussion.

References

[Zhang et al, 2019](https://arxiv.org/pdf/1907.08610v1.pdf)

fast#

Fast parameters.

Type

optax.Params

slow#

Slow parameters.

Type

optax.Params

fast: base.Params#

Alias for field number 0

slow: base.Params#

Alias for field number 1

classmethod init_synced(params)[source]#

Initialize a pair of synchronized lookahead parameters.

Return type

‘LookaheadParams’

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

class optax.LookaheadState(fast_state: base.OptState, steps_since_sync: jnp.ndarray)[source]#

State of the GradientTransformation returned by lookahead.

fast_state#

Optimizer state of the fast optimizer.

Type

optax.OptState

steps_since_sync#

Number of fast optimizer steps taken since slow and fast parameters were synchronized.

Type

jnp.ndarray

fast_state: base.OptState#

Alias for field number 0

steps_since_sync: jnp.ndarray#

Alias for field number 1

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Masked Update#

optax.masked(inner, mask)[source]#

Mask updates so only some are transformed, the rest are passed through.

For example, it is common to skip weight decay for BatchNorm scale and all bias parameters. In many networks, these are the only parameters with only one dimension. So, you may create a mask function to mask these out as follows:

import jax
import optax
mask_fn = lambda p: jax.tree_util.tree_map(lambda x: x.ndim != 1, p)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask_fn)

You may alternatively create the mask pytree upfront:

mask = jax.tree_util.tree_map(lambda x: x.ndim != 1, params)
weight_decay = optax.masked(optax.add_decayed_weights(0.001), mask)

For the inner transform, state will only be stored for the parameters that have a mask value of True.

Note that, when using tree_map_params, it may be required to pass the argument is_leaf=lambda v: isinstance(v, optax.MaskedNode), if the tree map needs to take additional arguments with the same shape as the original input tree.
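The masking rule itself (transform leaves marked True, pass the rest through unchanged) can be sketched in plain Python. `masked_sketch` and `add_decayed_weights_rule` are hypothetical stand-ins. The mask is written by hand here rather than derived from ndim, and the sketch does not model the per-leaf inner state.

```python
def masked_sketch(inner, mask):
    """Apply inner only to leaves whose mask value is True; pass the rest through."""
    def update(updates, params):
        return {k: inner(updates[k], params[k]) if mask[k] else updates[k]
                for k in updates}
    return update


def add_decayed_weights_rule(weight_decay):
    """Hypothetical leaf rule on flat lists: g + weight_decay * p."""
    return lambda g, p: [gi + weight_decay * pi for gi, pi in zip(g, p)]


params = {"w": [1.0, 2.0, 3.0, 4.0], "b": [1.0, 1.0]}
grads = {"w": [0.0, 0.0, 0.0, 0.0], "b": [0.0, 0.0]}
mask = {"w": True, "b": False}  # True: apply weight decay; False: skip (e.g. biases)
updates = masked_sketch(add_decayed_weights_rule(0.1), mask)(grads, params)
```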

Parameters
  • inner (optax.GradientTransformation) – Inner transformation to mask.

  • mask (Union[base.PyTree, Callable[[optax.Params], base.PyTree]]) – a PyTree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip. The mask must be static for the gradient transformation to be jit-compilable.

Return type

optax.GradientTransformationExtraArgs

Returns

New GradientTransformationExtraArgs wrapping inner.

class optax.MaskedState(inner_state: Any)[source]#

Maintains inner transform state for masked transformations.

inner_state: Any#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Maybe Update#

optax.maybe_update(inner, should_update_fn)[source]#

Calls the inner update function only at certain steps.

Creates a transformation wrapper which counts the number of times the update function has been called. This counter is passed to the should_update_fn to decide when to call the inner update function.

When not calling the inner update function, the updates and the inner state are left untouched and just passed through. The step counter is increased regardless.
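A plain-Python sketch of this wrapper follows. `MaybeState`, `maybe_update_sketch` and the `double` inner rule are hypothetical. The inner state is just a counter of how often the inner update ran, which makes the pass-through behaviour easy to observe.

```python
from typing import NamedTuple


class MaybeState(NamedTuple):
    inner_state: int  # here: number of times the inner update actually ran
    step: int


def maybe_update_sketch(inner_update, should_update_fn):
    def update(updates, state):
        if should_update_fn(state.step):
            updates, inner_state = inner_update(updates, state.inner_state)
        else:
            inner_state = state.inner_state  # updates and inner state pass through
        return updates, MaybeState(inner_state, state.step + 1)  # counter always advances
    return update


double = lambda u, s: (2.0 * u, s + 1)  # hypothetical inner transformation
every_other = maybe_update_sketch(double, lambda step: step % 2 == 0)

state = MaybeState(inner_state=0, step=0)
outputs = []
for _ in range(4):
    u, state = every_other(1.0, state)
    outputs.append(u)
```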

Parameters
  • inner (optax.GradientTransformation) – the inner transformation.

  • should_update_fn (Callable[[Array], Array]) – this function takes in a step counter (array of shape [] and dtype int32), and returns a boolean array of shape [].

Return type

optax.GradientTransformationExtraArgs

Returns

A new GradientTransformationExtraArgs.

class optax.MaybeUpdateState(inner_state: Any, step: Array)[source]#

Maintains inner transform state and adds a step counter.

inner_state: Any#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

step: Array#

Alias for field number 1

Multi-step Update#

class optax.MultiSteps(opt, every_k_schedule, use_grad_mean=True, should_skip_update_fn=None)[source]#

An optimizer wrapper to accumulate gradients over multiple steps.

This wrapper collects together the updates passed to its update function over consecutive steps until a given number of scheduled steps is reached. In each of these intermediate steps, the returned value from the optimizer is a tree of zeros with the same shape as the updates passed as input.

Once the scheduled number of intermediate ‘mini-steps’ has been reached, the gradients accumulated up to the current time are passed to the wrapped optimizer’s update function (with the inner optimizer’s state being updated appropriately) and then returned to the caller. The wrapper’s accumulated gradients are then set back to zero and the process starts again.

The number of mini-steps per gradient update is controlled by a function, and it can vary over training. This offers a means of varying batch size over training.
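The accumulate-then-apply cycle can be sketched in plain Python. The sketch assumes scalar gradients and a constant every_k. It does not model schedules, should_skip_update_fn, or the functional MultiStepsState; it uses closure variables instead. `multi_steps_sketch` and the `sgd` rule are hypothetical names.

```python
def multi_steps_sketch(inner_update, every_k, use_grad_mean=True):
    """Accumulate every_k mini-step gradients before calling inner_update."""
    acc, mini_step = 0.0, 0

    def update(grad):
        nonlocal acc, mini_step
        acc += grad
        mini_step += 1
        if mini_step < every_k:
            return 0.0  # intermediate mini-step: zero update
        g = acc / every_k if use_grad_mean else acc
        acc, mini_step = 0.0, 0  # reset the accumulator
        return inner_update(g)

    return update


sgd = lambda g: -0.1 * g  # hypothetical inner optimizer
update = multi_steps_sketch(sgd, every_k=3)
outputs = [update(g) for g in [1.0, 2.0, 3.0, 4.0, 4.0, 4.0]]
```

Averaging (use_grad_mean=True) makes three mini-batches of size B behave like one batch of size 3B. Summing instead changes the effective learning rate by a factor of every_k.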

__init__(opt, every_k_schedule, use_grad_mean=True, should_skip_update_fn=None)[source]#

Initialiser.

Parameters
  • opt (optax.GradientTransformation) – the wrapped optimizer.

  • every_k_schedule (Union[int, Callable[[Array], Array]]) –

    an int or a function.

    • As a function, it returns how many mini-steps should be accumulated in a single gradient step. Its only argument is the current gradient step count. By varying the returned value, users can vary the overall training batch size.

    • If an int, this is the constant number of mini-steps per gradient update.

  • use_grad_mean (bool) – if True (the default), gradients accumulated over multiple mini-steps are averaged. Otherwise, they are summed.

  • should_skip_update_fn (Optional[ShouldSkipUpdateFunction, None]) –

    if provided, this function is used to decide when to accept or reject the updates from a mini-step. When a mini-step is rejected, the inner state of MultiSteps is not updated. In other words, it is as if this mini-step never happened. For example:

    • to ignore updates containing inf or NaN, pass should_skip_update_fn=skip_not_finite;

    • to ignore updates with a norm square larger than 42, pass should_skip_update_fn=functools.partial(skip_large_updates, max_norm_sq=42.).

    Note that the optimizer’s state MultiStepsState contains a field skip_state in which debugging and monitoring information returned by should_skip_update_fn is written.

init(params)[source]#

Builds and returns initial MultiStepsState.

Return type

MultiStepsState

update(updates, state, params=None, **extra_args)[source]#

Accumulates gradients and proposes non-zero updates every k_steps.

Return type

Tuple[optax.Updates, MultiStepsState]

class optax.MultiStepsState(mini_step: Array, gradient_step: Array, inner_opt_state: Any, acc_grads: Any, skip_state: chex.ArrayTree = ())[source]#

State of the GradientTransformation returned by MultiSteps.

Fields:

mini_step: current mini-step counter. At an update, this either increases by 1 or is reset to 0.

gradient_step: gradient step counter. This only increases after enough mini-steps have been accumulated.

inner_opt_state: the state of the wrapped optimiser.

acc_grads: accumulated gradients over multiple mini-steps.

skip_state: an arbitrarily nested tree of arrays. This is only relevant when passing a should_skip_update_fn to MultiSteps. This structure will then contain values for debugging and/or monitoring. The actual structure will vary depending on the choice of ShouldSkipUpdateFunction.

mini_step: Array#

Alias for field number 0

gradient_step: Array#

Alias for field number 1

inner_opt_state: Any#

Alias for field number 2

acc_grads: Any#

Alias for field number 3

skip_state: chex.ArrayTree#

Alias for field number 4

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Common Losses#

convex_kl_divergence(log_predictions, targets)

Computes a convex version of the Kullback-Leibler divergence loss.

cosine_distance(predictions, targets[, epsilon])

Computes the cosine distance between targets and predictions.

cosine_similarity(predictions, targets[, ...])

Computes the cosine similarity between targets and predictions.

ctc_loss(logits, logit_paddings, labels, ...)

Computes CTC loss.

ctc_loss_with_forward_probs(logits, ...[, ...])

Computes CTC loss and CTC forward-probabilities.

hinge_loss(predictor_outputs, targets)

Computes the hinge loss for binary classification.

huber_loss(predictions[, targets, delta])

Huber loss, similar to L2 loss close to zero, L1 loss away from zero.

kl_divergence(log_predictions, targets)

Computes the Kullback-Leibler divergence (relative entropy) loss.

l2_loss(predictions[, targets])

Calculates the L2 loss for a set of predictions.

log_cosh(predictions[, targets])

Calculates the log-cosh loss for a set of predictions.

sigmoid_binary_cross_entropy(logits, labels)

Computes element-wise sigmoid cross entropy given logits and labels.

smooth_labels(labels, alpha)

Apply label smoothing.

softmax_cross_entropy(logits, labels)

Computes the softmax cross entropy between sets of logits and labels.

softmax_cross_entropy_with_integer_labels(...)

Computes softmax cross entropy between sets of logits and integer labels.

squared_error(predictions[, targets])

Calculates the squared error for a set of predictions.

Losses#

optax.convex_kl_divergence(log_predictions, targets)[source]#

Computes a convex version of the Kullback-Leibler divergence loss.

Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution. This version is jointly convex in p (targets) and q (log_predictions).

References

[Kullback, Leibler, 1951](https://www.jstor.org/stable/2236703)

Parameters
  • log_predictions (Union[Array, ndarray, bool_, number]) – Probabilities of predicted distribution with shape […, dim]. Expected to be in the log-space to avoid underflow.

  • targets (Union[Array, ndarray, bool_, number]) – Probabilities of target distribution with shape […, dim]. Expected to be strictly positive.

Return type

Union[Array, ndarray, bool_, number]

Returns

Kullback-Leibler divergence of predicted distribution from target distribution with shape […].
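A plain-Python sketch for a single distribution (a shape [dim] input, no batching) follows. It assumes the convex version is the generalized KL divergence, KL(p||q) + Σ(q_i − p_i), a common convex extension. Treat that form as an assumption; the text above does not confirm it. The two functions are hypothetical re-implementations, not optax code.

```python
import math


def kl_divergence(log_predictions, targets):
    """sum_i p_i * (log p_i - log q_i), treating 0 * log 0 as 0."""
    return sum(p * (math.log(p) - log_q) if p > 0 else 0.0
               for log_q, p in zip(log_predictions, targets))


def convex_kl_divergence(log_predictions, targets):
    """KL plus sum_i (q_i - p_i); the extra term is 0 for normalized q and p."""
    return kl_divergence(log_predictions, targets) + sum(
        math.exp(log_q) - p for log_q, p in zip(log_predictions, targets))


p = [0.5, 0.5]
log_q = [math.log(0.25), math.log(0.75)]
kl = kl_divergence(log_q, p)
ckl = convex_kl_divergence(log_q, p)
```

For normalized distributions both values coincide. The extra term matters when q is unnormalized: plain KL can then become negative, while the convex version stays non-negative.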

optax.cosine_distance(predictions, targets, epsilon=0.0)[source]#

Computes the cosine distance between targets and predictions.

The cosine distance, implemented here, measures the dissimilarity of two vectors as the opposite of cosine similarity: 1 - cos(theta).

References

[Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

Parameters
  • predictions (Union[Array, ndarray, bool_, number]) – The predicted vectors, with shape […, dim].

  • targets (Union[Array, ndarray, bool_, number]) – Ground truth target vectors, with shape […, dim].

  • epsilon (float) – minimum norm for terms in the denominator of the cosine similarity.

Return type

Union[Array, ndarray, bool_, number]

Returns

cosine distances, with shape […].

optax.cosine_similarity(predictions, targets, epsilon=0.0)[source]#

Computes the cosine similarity between targets and predictions.

The cosine similarity is a measure of similarity between vectors defined as the cosine of the angle between them, which is also the inner product of those vectors normalized to have unit norm.
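A plain-Python sketch of cosine similarity and cosine distance (1 − cos θ) for single vectors follows. It assumes epsilon acts as a lower bound on each vector's norm, in line with the "minimum norm" description of that parameter. Both functions are hypothetical re-implementations for illustration.

```python
import math


def cosine_similarity(predictions, targets, epsilon=0.0):
    """Dot product of the two vectors after scaling each to unit norm.

    Each norm is clipped from below at epsilon (an assumption based on the
    'minimum norm' description of that parameter).
    """
    norm_p = max(math.sqrt(sum(x * x for x in predictions)), epsilon)
    norm_t = max(math.sqrt(sum(x * x for x in targets)), epsilon)
    return sum(a * b for a, b in zip(predictions, targets)) / (norm_p * norm_t)


def cosine_distance(predictions, targets, epsilon=0.0):
    return 1.0 - cosine_similarity(predictions, targets, epsilon)
```

Orthogonal vectors give a similarity of 0 and a distance of 1. Opposite vectors give −1 and 2. With the default epsilon of 0.0, a zero vector divides by zero, which is one reason to pass a small positive epsilon.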

References

[Wikipedia, 2021](https://en.wikipedia.org/wiki/Cosine_similarity)

Parameters
  • predictions (Union[Array, ndarray, bool_, number]) – The predicted vectors, with shape […, dim].

  • targets (Union[Array, ndarray, bool_, number]) – Ground truth target vectors, with shape […, dim].

  • epsilon (float) – minimum norm for terms in the denominator of the cosine similarity.

Return type

Union[Array, ndarray, bool_, number]

Returns

cosine similarity measures, with shape […].

optax.ctc_loss(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)[source]#

Computes CTC loss.

See docstring for ctc_loss_with_forward_probs for details.
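To illustrate what the loss measures, here is a plain-Python sketch for a single unpadded sequence. It runs the standard CTC forward (alpha) recursion of Graves et al. (2006) over the blank-inserted label sequence and checks it against brute-force enumeration of all alignments. It works in probability space for readability. A batched, padded, numerically robust implementation would work in log space, which is what the log_epsilon parameter suggests. All names are hypothetical.

```python
import itertools
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def ctc_loss_single(logits, labels, blank_id=0):
    """CTC negative log-likelihood for one unpadded sequence (T x K logits)."""
    probs = [softmax(row) for row in logits]
    ext = [blank_id]  # blank-inserted labels: _ y1 _ y2 ... _
    for y in labels:
        ext += [y, blank_id]
    S = len(ext)
    alpha = [0.0] * S
    alpha[0] = probs[0][ext[0]]  # start with a blank ...
    if S > 1:
        alpha[1] = probs[0][ext[1]]  # ... or with the first label
    for t in range(1, len(probs)):
        new = [0.0] * S
        for s in range(S):
            a = alpha[s] + (alpha[s - 1] if s >= 1 else 0.0)
            # Skipping a blank is only allowed between two different labels.
            if s >= 2 and ext[s] != blank_id and ext[s] != ext[s - 2]:
                a += alpha[s - 2]
            new[s] = a * probs[t][ext[s]]
        alpha = new
    # Valid alignments end on the last label or on the trailing blank.
    total = alpha[-1] + (alpha[-2] if S > 1 else 0.0)
    return -math.log(total)


def ctc_loss_brute_force(logits, labels, blank_id=0):
    """Same quantity by enumerating every path and collapsing it."""
    probs = [softmax(row) for row in logits]
    total = 0.0
    for path in itertools.product(range(len(probs[0])), repeat=len(probs)):
        collapsed, prev = [], None
        for k in path:  # merge repeats, then drop blanks
            if k != prev and k != blank_id:
                collapsed.append(k)
            prev = k
        if collapsed == list(labels):
            total += math.prod(probs[t][k] for t, k in enumerate(path))
    return -math.log(total)


logits = [[0.1, 0.5, -0.2], [0.3, -0.1, 0.4], [0.0, 0.2, 0.1], [-0.3, 0.6, 0.2]]
loss = ctc_loss_single(logits, [1, 2])
```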

Parameters
  • logits (Union[Array, ndarray, bool_, number]) – (B, T, K)-array containing logits of each class where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.

  • logit_paddings (Union[Array, ndarray, bool_, number]) – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.

  • labels (Union[Array, ndarray, bool_, number]) – (B, N)-array containing reference integer labels where N denotes the max time frames in the label sequence.

  • label_paddings (Union[Array, ndarray, bool_, number]) – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.

  • blank_id (int) – Id for blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.

  • log_epsilon (float) – Numerically-stable approximation of log(+0).

Return type

Union[Array, ndarray, bool_, number]

Returns

(B,)-array containing loss values for each sequence in the batch.

optax.ctc_loss_with_forward_probs(logits, logit_paddings, labels, label_paddings, blank_id=0, log_epsilon=-100000.0)[source]#

Computes CTC loss and CTC forward-probabilities.

The CTC loss is a loss function based on log-likelihoods of the model that introduces a special blank symbol \(\phi\) to represent variable-length output sequences.

Forward probabilities returned by this function, as auxiliary results, are grouped into two parts: the blank alpha-probability and the non-blank alpha-probability. They are defined as follows:

\[\alpha_{\mathrm{BLANK}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = \phi | \pi_{1:t-1}, y_{1:n-1}, \cdots), \\ \alpha_{\mathrm{LABEL}}(t, n) = \sum_{\pi_{1:t-1}} p(\pi_t = y_n | \pi_{1:t-1}, y_{1:n-1}, \cdots). \]

Here, \(\pi\) denotes the alignment sequence in the reference [Graves et al, 2006], i.e. a blank-inserted representation of the labels. The return values are the logarithms of the above probabilities.

References

[Graves et al, 2006](https://dl.acm.org/doi/abs/10.1145/1143844.1143891)

Parameters
  • logits (Union[Array, ndarray, bool_, number]) – (B, T, K)-array containing logits of each class where B denotes the batch size, T denotes the max time frames in logits, and K denotes the number of classes including a class for blanks.

  • logit_paddings (Union[Array, ndarray, bool_, number]) – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.

  • labels (Union[Array, ndarray, bool_, number]) – (B, N)-array containing reference integer labels, where N denotes the maximum length of the label sequence.

  • label_paddings (Union[Array, ndarray, bool_, number]) – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.

  • blank_id (int) – ID of the blank token. logits[b, :, blank_id] are used as probabilities of blank symbols.

  • log_epsilon (float) – Numerically-stable approximation of log(+0).

Return type

Tuple[Union[Array, ndarray, bool_, number], Union[Array, ndarray, bool_, number], Union[Array, ndarray, bool_, number]]

Returns

A tuple (loss_value, logalpha_blank, logalpha_nonblank). Here, loss_value is a (B,)-array containing the loss values for each sequence in the batch, and logalpha_blank and logalpha_nonblank are (T, B, N+1)-arrays whose (t, b, n)-th elements are \(\log \alpha_{\mathrm{BLANK}}(t, n)\) and \(\log \alpha_{\mathrm{LABEL}}(t, n)\), respectively, for the b-th sequence in the batch.

optax.hinge_loss(predictor_outputs, targets)[source]#

Computes the hinge loss for binary classification.

Parameters
  • predictor_outputs (Union[Array, ndarray, bool_, number]) – Outputs of the decision function.

  • targets (Union[Array, ndarray, bool_, number]) – Target values. Target values should be strictly in the set {-1, 1}.

Return type

Union[Array, ndarray, bool_, number]

Returns

Binary Hinge Loss.
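The hinge loss is max(0, 1 - y * f(x)) per element. The sketch below writes that formula out in jax.numpy; the helper name is hypothetical, and we assume optax.hinge_loss computes the same elementwise quantity.

```python
import jax.numpy as jnp

def hinge_loss_sketch(predictor_outputs, targets):
    # Elementwise max(0, 1 - y * f(x)) for targets y in {-1, 1}.
    return jnp.maximum(0.0, 1.0 - predictor_outputs * targets)

scores = jnp.array([2.0, 0.5, -1.0])
targets = jnp.array([1.0, 1.0, 1.0])
losses = hinge_loss_sketch(scores, targets)  # [0.0, 0.5, 2.0]
```

Only the confidently correct prediction (margin of at least 1) incurs zero loss.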

optax.huber_loss(predictions, targets=None, delta=1.0)[source]#

Huber loss, similar to L2 loss close to zero, L1 loss away from zero.

If gradient descent is applied to the Huber loss, it is equivalent to clipping the gradients of an l2_loss to [-delta, delta] in the backward pass.

References

[Huber, 1964](www.projecteuclid.org/download/pdf_1/euclid.aoms/1177703732)

Parameters
  • predictions (Union[Array, ndarray, bool_, number]) – a vector of arbitrary shape […].

  • targets (Union[Array, ndarray, bool_, number, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

  • delta (float) – the bounds for the Huber loss transformation; defaults to 1.0.

Return type

Union[Array, ndarray, bool_, number]

Returns

elementwise Huber losses, with the same shape as predictions.
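To make the gradient-clipping remark concrete, here is a small jax.numpy sketch of the Huber loss (quadratic up to delta, linear beyond it). The helper name is hypothetical; it is written to follow the description above rather than copied from optax. Differentiating it gives x clipped to [-delta, delta], which is the gradient of 0.5 * x**2 after clipping.

```python
import jax
import jax.numpy as jnp

def huber_sketch(x, delta=1.0):
    # 0.5 * x**2 for |x| <= delta, then linear with slope delta.
    abs_x = jnp.abs(x)
    quadratic = jnp.minimum(abs_x, delta)
    linear = abs_x - quadratic
    return 0.5 * quadratic**2 + delta * linear

errors = jnp.array([-3.0, -0.5, 0.0, 0.5, 3.0])
values = huber_sketch(errors)  # [2.5, 0.125, 0.0, 0.125, 2.5]

# Gradient of the Huber loss vs. the clipped gradient of 0.5 * x**2 (which is x).
grads = jax.vmap(jax.grad(huber_sketch))(errors)
clipped = jnp.clip(errors, -1.0, 1.0)
```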

optax.kl_divergence(log_predictions, targets)[source]#

Computes the Kullback-Leibler divergence (relative entropy) loss.

Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution.

References

[Kullback, Leibler, 1951](https://www.jstor.org/stable/2236703)

Parameters
  • log_predictions (Union[Array, ndarray, bool_, number]) – Log-probabilities of the predicted distribution, with shape […, dim]. Working in log-space avoids underflow.

  • targets (Union[Array, ndarray, bool_, number]) – Probabilities of target distribution with shape […, dim]. Expected to be strictly positive.

Return type

Union[Array, ndarray, bool_, number]

Returns

Kullback-Leibler divergence of predicted distribution from target distribution with shape […].
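Written out, the loss is KL(p || q) = sum_i p_i (log p_i - log q_i) over the last axis, with p the targets and log q the log_predictions. The jax.numpy sketch below assumes strictly positive targets, as the parameter description requires; the helper name is hypothetical.

```python
import jax.numpy as jnp

def kl_divergence_sketch(log_predictions, targets):
    # sum_i p_i * (log p_i - log q_i) over the last axis.
    return jnp.sum(targets * (jnp.log(targets) - log_predictions), axis=-1)

targets = jnp.array([0.5, 0.5])
log_predictions = jnp.log(jnp.array([0.9, 0.1]))
kl = kl_divergence_sketch(log_predictions, targets)   # 0.5 * log(0.25 / 0.09)
same = kl_divergence_sketch(jnp.log(targets), targets)  # 0.0 for identical distributions
```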

optax.l2_loss(predictions, targets=None)[source]#

Calculates the L2 loss for a set of predictions.

Note: the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not in “The Elements of Statistical Learning” by Hastie, Tibshirani and Friedman.

References

[Chris Bishop, 2006](https://bit.ly/3eeP0ga)

Parameters
  • predictions (Union[Array, ndarray, bool_, number]) – a vector of arbitrary shape […].

  • targets (Union[Array, ndarray, bool_, number, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

Return type

Union[Array, ndarray, bool_, number]

Returns

half the elementwise squared differences (0.5 * squared_error), with the same shape as predictions.

optax.log_cosh(predictions, targets=None)[source]#

Calculates the log-cosh loss for a set of predictions.

log(cosh(x)) is approximately (x**2) / 2 for small x and abs(x) - log(2) for large x. It is a twice differentiable alternative to the Huber loss.

References

[Chen et al, 2019](https://openreview.net/pdf?id=rkglvsC9Ym)

Parameters
  • predictions (Union[Array, ndarray, bool_, number]) – a vector of arbitrary shape […].

  • targets (Union[Array, ndarray, bool_, number, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

Return type

Union[Array, ndarray, bool_, number]

Returns

the log-cosh loss, with same shape as predictions.
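Evaluating log(cosh(x)) directly overflows for large |x|. A numerically stable form is log(cosh(x)) = logaddexp(x, -x) - log(2). The sketch below uses that identity (a choice we make here; the helper name is hypothetical) and checks both asymptotic regimes described above.

```python
import jax.numpy as jnp

def log_cosh_sketch(predictions, targets=None):
    errors = predictions if targets is None else predictions - targets
    # log((e^x + e^-x) / 2) without forming cosh(x), which overflows for large |x|.
    return jnp.logaddexp(errors, -errors) - jnp.log(2.0)

x = jnp.array([0.1, 50.0, 1000.0])
values = log_cosh_sketch(x)
# Small x: about x**2 / 2. Large x: about |x| - log(2).
# The naive jnp.log(jnp.cosh(1000.0)) is inf in float32.
```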

optax.sigmoid_binary_cross_entropy(logits, labels)[source]#

Computes element-wise sigmoid cross entropy given logits and labels.

This function can be used for binary or multiclass classification (where each class is an independent binary prediction and different classes are not mutually exclusive, e.g. predicting that an image contains both a cat and a dog).

Because this function is overloaded, please ensure your logits and labels are compatible with each other. If you’re passing in binary labels (values in {0, 1}), ensure your logits correspond to class 1 only. If you’re passing in per-class target probabilities or one-hot labels, please ensure your logits are also multiclass. Be particularly careful if you’re relying on implicit broadcasting to reshape logits or labels.

References

[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

Parameters
  • logits – Each element is the unnormalized log probability of a binary prediction. See note about compatibility with labels above.

  • labels – Binary labels whose values are in {0, 1}, or multi-class target probabilities. See note about compatibility with logits above.

Returns

cross entropy for each binary prediction, same shape as logits.
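For a logit x and label z, the loss is -[z log sigmoid(x) + (1 - z) log(1 - sigmoid(x))]. The sketch below computes it with jax.nn.log_sigmoid, using log(1 - sigmoid(x)) = log_sigmoid(-x) for numerical stability; the helper name is hypothetical. It uses the multi-label setting from the description: each class has its own logit.

```python
import jax
import jax.numpy as jnp

def sigmoid_bce_sketch(logits, labels):
    log_p = jax.nn.log_sigmoid(logits)       # log sigmoid(x)
    log_not_p = jax.nn.log_sigmoid(-logits)  # log(1 - sigmoid(x))
    return -labels * log_p - (1.0 - labels) * log_not_p

# Independent per-class predictions: cat (present), dog (present), car (absent).
logits = jnp.array([2.0, 0.0, -3.0])
labels = jnp.array([1.0, 1.0, 0.0])
losses = sigmoid_bce_sketch(logits, labels)  # same shape as logits
```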

optax.smooth_labels(labels, alpha)[source]#

Apply label smoothing.

Label smoothing is often used in combination with a cross-entropy loss. Smoothed labels favour small logit gaps, and it has been shown that this can provide better model calibration by preventing overconfident predictions.

References

[Müller et al, 2019](https://arxiv.org/pdf/1906.02629.pdf)

Parameters
  • labels (Union[Array, ndarray, bool_, number]) – One-hot labels to be smoothed.

  • alpha (float) – The smoothing factor.

Return type

Array

Returns

a smoothed version of the one-hot input labels.
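A common form of label smoothing mixes the one-hot target with a uniform distribution: (1 - alpha) * labels + alpha / num_classes. The sketch below assumes optax.smooth_labels uses this standard form; the helper name is hypothetical.

```python
import jax.numpy as jnp

def smooth_labels_sketch(labels, alpha):
    # Move a fraction alpha of the probability mass to a uniform distribution.
    num_classes = labels.shape[-1]
    return (1.0 - alpha) * labels + alpha / num_classes

one_hot = jnp.array([0.0, 0.0, 1.0])
smoothed = smooth_labels_sketch(one_hot, alpha=0.3)  # [0.1, 0.1, 0.8]
```

The result still sums to 1, so it can be passed directly as labels to softmax_cross_entropy.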

optax.softmax_cross_entropy(logits, labels)[source]#

Computes the softmax cross entropy between sets of logits and labels.

Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.

References

[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

Parameters
  • logits (Union[Array, ndarray, bool_, number]) – Unnormalized log probabilities, with shape […, num_classes].

  • labels (Union[Array, ndarray, bool_, number]) – Valid probability distributions (non-negative, sum to 1), e.g. a one-hot encoding specifying the correct class for each input; must have a shape broadcastable to […, num_classes].

Return type

Union[Array, ndarray, bool_, number]

Returns

cross entropy between each prediction and the corresponding target distributions, with shape […].
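The softmax cross entropy is -sum_k labels_k * log softmax(logits)_k over the class axis. A short jax.numpy sketch follows (the helper name is hypothetical). Uniform logits over 4 classes give a loss of log 4 for any one-hot label.

```python
import jax
import jax.numpy as jnp

def softmax_cross_entropy_sketch(logits, labels):
    return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)

logits = jnp.zeros((2, 4))                      # uniform predictions
labels = jax.nn.one_hot(jnp.array([1, 3]), 4)   # one-hot targets
loss = softmax_cross_entropy_sketch(logits, labels)  # [log 4, log 4]
```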

optax.softmax_cross_entropy_with_integer_labels(logits, labels)[source]#

Computes softmax cross entropy between sets of logits and integer labels.

Measures the probability error in discrete classification tasks in which the classes are mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is labeled with one and only one label: an image can be a dog or a truck, but not both.

References

[Goodfellow et al, 2016](http://www.deeplearningbook.org/contents/prob.html)

Parameters
  • logits (Union[Array, ndarray, bool_, number]) – Unnormalized log probabilities, with shape […, num_classes].

  • labels (Union[Array, ndarray, bool_, number]) – Integers specifying the correct class for each input, with shape […].

Return type

Union[Array, ndarray, bool_, number]

Returns

Cross entropy between each prediction and the corresponding target distributions, with shape […].
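With integer labels there is no need to build a one-hot matrix: the loss is just -log softmax(logits)[label]. The sketch below (hypothetical helper name) gathers that entry with jnp.take_along_axis and checks it against the dense one-hot formulation.

```python
import jax
import jax.numpy as jnp

def softmax_cross_entropy_int_sketch(logits, labels):
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Select log p[label] for each example along the class axis.
    picked = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)
    return -picked[..., 0]

logits = jnp.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = jnp.array([0, 1])
loss_int = softmax_cross_entropy_int_sketch(logits, labels)

# Same values as the dense formulation with one-hot labels.
loss_dense = -jnp.sum(jax.nn.one_hot(labels, 3) * jax.nn.log_softmax(logits), axis=-1)
```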

optax.squared_error(predictions, targets=None)[source]#

Calculates the squared error for a set of predictions.

Mean Squared Error can be computed as squared_error(a, b).mean().

Note: l2_loss = 0.5 * squared_error, where the 0.5 term is standard in “Pattern Recognition and Machine Learning” by Bishop, but not in “The Elements of Statistical Learning” by Hastie, Tibshirani and Friedman.

References

[Chris Bishop, 2006](https://bit.ly/3eeP0ga)

Parameters
  • predictions (Union[Array, ndarray, bool_, number]) – a vector of arbitrary shape […].

  • targets (Union[Array, ndarray, bool_, number, None]) – a vector with shape broadcastable to that of predictions; if not provided then it is assumed to be a vector of zeros.

Return type

Union[Array, ndarray, bool_, number]

Returns

elementwise squared differences, with same shape as predictions.
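The relationships stated above (mean squared error as the mean of squared_error, and l2_loss as half of it) can be written out in plain jax.numpy. This mirrors, and does not call, optax.squared_error and optax.l2_loss.

```python
import jax.numpy as jnp

predictions = jnp.array([1.0, 2.0, 4.0])
targets = jnp.array([1.0, 0.0, 1.0])

squared_error = (predictions - targets) ** 2  # elementwise: [0.0, 4.0, 9.0]
mse = squared_error.mean()                    # mean squared error: 13 / 3
l2 = 0.5 * squared_error                      # l2_loss convention: [0.0, 2.0, 4.5]
```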

Linear Algebra Operators#

matrix_inverse_pth_root(matrix, p[, ...])

Computes matrix^(-1/p), where p is a positive integer.

multi_normal(loc, log_scale)

Returns a MultiNormalDiagFromLogScale distribution built from loc and log_scale.

power_iteration(matrix[, num_iters, ...])

Power iteration algorithm.

multi_normal#

optax.multi_normal(loc, log_scale)[source]#
Return type

MultiNormalDiagFromLogScale

matrix_inverse_pth_root#

optax.matrix_inverse_pth_root(matrix, p, num_iters=100, ridge_epsilon=1e-06, error_tolerance=1e-06, precision=<Precision.HIGHEST: 2>)[source]#

Computes matrix^(-1/p), where p is a positive integer.

This function uses the coupled Newton iterations algorithm to compute a matrix’s inverse pth root.

References

[Functions of Matrices, Theory and Computation, Nicholas J Higham, Pg 184, Eq 7.18](https://epubs.siam.org/doi/book/10.1137/1.9780898717778)

Parameters
  • matrix (Union[Array, ndarray, bool_, number]) – the symmetric PSD matrix whose power is to be computed.

  • p (int) – exponent, for p a positive integer.

  • num_iters (int) – Maximum number of iterations.

  • ridge_epsilon (float) – Ridge epsilon added to make the matrix positive definite.

  • error_tolerance (float) – Error tolerance, used for early termination.

  • precision (Precision) – XLA-related precision flag. The available options are: a) lax.Precision.DEFAULT (better step time, but less precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest).

Returns

matrix^(-1/p)
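The coupled Newton iteration can be sketched in a few lines of jax.numpy. This is a simplified sketch under stated assumptions: the input is symmetric positive definite, the scaling z = (1 + p) / (2 ||A||_F) is used so the iteration converges, a fixed number of iterations replaces the error_tolerance check, and no ridge_epsilon is added. The helper names are hypothetical; an eigendecomposition reference is included for comparison.

```python
import jax.numpy as jnp

def inverse_pth_root_sketch(matrix, p, num_iters=100):
    """Coupled Newton iteration for matrix^(-1/p); assumes `matrix` is SPD."""
    identity = jnp.eye(matrix.shape[0], dtype=matrix.dtype)
    # Scale so every eigenvalue of M_0 lies in (0, (1 + p) / 2].
    z = (1 + p) / (2 * jnp.linalg.norm(matrix))
    mat_m = z * matrix                    # M_k converges to the identity
    mat_h = z ** (1.0 / p) * identity     # H_k converges to matrix^(-1/p)
    for _ in range(num_iters):
        step = ((p + 1) * identity - mat_m) / p
        mat_h = mat_h @ step
        mat_m = jnp.linalg.matrix_power(step, p) @ mat_m
    return mat_h

def inverse_pth_root_eigh(matrix, p):
    # Reference via eigendecomposition: V diag(lambda^(-1/p)) V^T.
    eigvals, eigvecs = jnp.linalg.eigh(matrix)
    return (eigvecs * eigvals ** (-1.0 / p)) @ eigvecs.T

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
newton = inverse_pth_root_sketch(a, p=2)
reference = inverse_pth_root_eigh(a, p=2)
```

The invariant M_k = H_k^p A holds at every step, so M_k approaching the identity forces H_k toward A^(-1/p).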

Utilities for numerical stability#

safe_int32_increment(count)

Increments int32 counter by one.

safe_norm(x, min_norm[, ord, axis, keepdims])

Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.

safe_root_mean_squares(x, min_rms)

Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct grads.

Numerics#

optax.safe_int32_increment(count)[source]#

Increments int32 counter by one.

Normally max_int + 1 would overflow to min_int. This function ensures that when max_int is reached the counter stays at max_int.

Parameters

count (Union[Array, ndarray, bool_, number, float, int]) – a counter to be incremented.

Return type

Union[Array, ndarray, bool_, number, float, int]

Returns

A counter incremented by 1, or max_int if the counter has already reached max_int.
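A sketch of the saturating increment in jax.numpy. The helper name is hypothetical, and we assume the library behaves as described above: it increments while below the largest int32 and otherwise returns that maximum.

```python
import jax.numpy as jnp

def safe_int32_increment_sketch(count):
    max_int32 = jnp.array(jnp.iinfo(jnp.int32).max, dtype=jnp.int32)
    # count + 1 would wrap to min_int at the maximum; select max_int32 instead.
    return jnp.where(count < max_int32, count + 1, max_int32)

near_max = jnp.array(jnp.iinfo(jnp.int32).max - 1, dtype=jnp.int32)
once = safe_int32_increment_sketch(near_max)  # reaches max_int
twice = safe_int32_increment_sketch(once)     # stays at max_int
```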

optax.safe_norm(x, min_norm, ord=None, axis=None, keepdims=False)[source]#

Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients.

The gradient of jnp.maximum(jnp.linalg.norm(x), min_norm) at 0.0 is NaN, because JAX evaluates both branches of jnp.maximum and the gradient of the norm itself is NaN at 0.0. This function instead returns the correct gradient of 0.0 in that setting.

Parameters
  • x (Union[Array, ndarray, bool_, number]) – jax array.

  • min_norm (Union[Array, ndarray, bool_, number, float, int]) – lower bound for the returned norm.

  • ord (Union[int, float, str, None]) – {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional. Order of the norm. inf means numpy’s inf object. The default is None.

  • axis (Union[None, Tuple[int, …], int]) – {None, int, 2-tuple of ints}, optional. If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The default is None.

  • keepdims (bool) – bool, optional. If this is set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original x.

Return type

Union[Array, ndarray, bool_, number]

Returns

The safe norm of the input vector, accounting for correct gradient.
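The usual trick is to replace x by a harmless nonzero input wherever the norm is below min_norm, so the norm that is differentiated never sees 0.0. The sketch below covers only the full-vector case (axis=None) and uses a hypothetical helper name; it is written from the description above, not copied from optax.

```python
import jax
import jax.numpy as jnp

def safe_norm_sketch(x, min_norm):
    """max(||x||, min_norm), with a zero gradient where ||x|| <= min_norm."""
    is_small = jnp.linalg.norm(x) <= min_norm  # boolean, carries no gradient
    # Differentiate through the norm of ones instead of the norm of zeros.
    x_safe = jnp.where(is_small, jnp.ones_like(x), x)
    return jnp.where(is_small, min_norm, jnp.linalg.norm(x_safe))

zeros = jnp.zeros(3)
naive_grad = jax.grad(lambda x: jnp.maximum(jnp.linalg.norm(x), 1e-6))(zeros)
safe_grad = jax.grad(lambda x: safe_norm_sketch(x, 1e-6))(zeros)  # all zeros
value = safe_norm_sketch(jnp.array([3.0, 4.0]), 1e-6)             # 5.0
```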

optax.safe_root_mean_squares(x, min_rms)[source]#

Returns maximum(sqrt(mean(abs_sq(x))), min_rms) with correct grads.

The gradient of maximum(sqrt(mean(abs_sq(x))), min_rms) at 0.0 is NaN, because JAX evaluates both branches of jnp.maximum. This function instead returns the correct gradient of 0.0 in that setting.

Parameters
  • x (Union[Array, ndarray, bool_, number]) – jax array.

  • min_rms (Union[Array, ndarray, bool_, number, float, int]) – lower bound for the returned norm.

Return type

Union[Array, ndarray, bool_, number]

Returns

The safe RMS of the input vector, accounting for correct gradient.

power_iteration#

optax.power_iteration(matrix, num_iters=100, error_tolerance=1e-06, precision=<Precision.HIGHEST: 2>)[source]#

Power iteration algorithm.

The power iteration algorithm takes a symmetric PSD matrix A and produces a scalar lambda, which is the greatest (in absolute value) eigenvalue of A, and a vector v, which is the corresponding eigenvector of A.

References

[Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)

Parameters
  • matrix (Union[Array, ndarray, bool_, number]) – the symmetric PSD matrix.

  • num_iters (int) – Number of iterations.

  • error_tolerance (float) – Tolerance used as the iterative exit condition.

  • precision (Precision) – XLA-related precision flag. The available options are: a) lax.Precision.DEFAULT (better step time, but less precise); b) lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST (best possible precision, slowest).

Returns

A pair (eigenvector, eigenvalue).
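Here is a minimal eager-mode sketch of power iteration in jax.numpy, returning the pair in the documented (eigenvector, eigenvalue) order. The helper name, the all-ones start vector, and the Rayleigh-quotient convergence test are our own choices. They are not necessarily what optax.power_iteration does internally.

```python
import jax.numpy as jnp

def power_iteration_sketch(matrix, num_iters=100, error_tolerance=1e-6):
    """Returns (eigenvector, eigenvalue) for the largest-magnitude eigenvalue."""
    n = matrix.shape[0]
    # Deterministic start; assumes it is not orthogonal to the target eigenvector.
    v = jnp.ones((n,), dtype=matrix.dtype) / jnp.sqrt(n)
    eigval = jnp.zeros((), dtype=matrix.dtype)
    for _ in range(num_iters):
        w = matrix @ v
        v = w / jnp.linalg.norm(w)
        new_eigval = v @ matrix @ v  # Rayleigh quotient of the unit vector
        converged = jnp.abs(new_eigval - eigval) < error_tolerance
        eigval = new_eigval
        if converged:  # eager early exit (not jit-compatible)
            break
    return v, eigval

a = jnp.array([[3.0, 0.0], [0.0, 1.0]])
vec, val = power_iteration_sketch(a)  # val close to 3.0, vec close to [1, 0]
```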

Optimizer Schedules#

constant_schedule(value)

Constructs a constant schedule.

cosine_decay_schedule(init_value, decay_steps)

Returns a function which implements cosine learning rate decay.

cosine_onecycle_schedule(transition_steps, ...)

Returns a function which implements the onecycle learning rate schedule.

exponential_decay(init_value, ...[, ...])

Constructs a schedule with either continuous or discrete exponential decay.

join_schedules(schedules, boundaries)

Sequentially apply multiple schedules.

linear_onecycle_schedule(transition_steps, ...)

Returns a function which implements the onecycle learning rate schedule.

linear_schedule(init_value, end_value, ...)

Constructs a schedule with linear transition from init to end value.

piecewise_constant_schedule(init_value[, ...])

Returns a function which implements a piecewise constant schedule.

piecewise_interpolate_schedule(...[, ...])

Returns a function which implements a piecewise interpolated schedule.

polynomial_schedule(init_value, end_value, ...)

Constructs a schedule with polynomial transition from init to end value.

sgdr_schedule(cosine_kwargs)

SGD with warm restarts, from Loshchilov & Hutter (arXiv:1608.03983).

warmup_cosine_decay_schedule(init_value, ...)

Linear warmup followed by cosine decay.

warmup_exponential_decay_schedule(...[, ...])

Linear warmup followed by exponential decay.

Schedule

alias of Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]], Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]]

InjectHyperparamsState(count, hyperparams, ...)

Maintains inner transform state, hyperparameters, and step count.

inject_hyperparams(inner_factory[, ...])

Wrapper that injects hyperparameters into the inner GradientTransformation.

Schedules#

optax.constant_schedule(value)[source]#

Constructs a constant schedule.

Parameters

value (Union[float, int]) – value to be held constant throughout.

Returns

A function that maps step counts to values.

Return type

schedule

optax.cosine_decay_schedule(init_value, decay_steps, alpha=0.0, exponent=1.0)[source]#

Returns a function which implements cosine learning rate decay.

The schedule does not restart when decay_steps has been reached. Instead, the learning rate remains constant afterwards. For a cosine schedule with restarts, optax.join_schedules() can be used to join several cosine decay schedules.

For more details see: https://arxiv.org/abs/1608.03983.

Parameters
  • init_value (float) – An initial value init_v.

  • decay_steps (int) – Positive integer, the number of steps over which to apply the decay.

  • alpha (float) – Float. The minimum value of the multiplier used to adjust the learning rate.

  • exponent (float) – Float. The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is the decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.

Returns

A function that maps step counts to values.

Return type

schedule

optax.cosine_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, div_factor=25.0, final_div_factor=10000.0)[source]#

Returns a function which implements the onecycle learning rate schedule.

This function uses a cosine annealing strategy. For more details see: https://arxiv.org/abs/1708.07120

Parameters
  • transition_steps (int) – Number of steps over which annealing takes place.

  • peak_value (float) – Maximum value attained by schedule at pct_start percent of the cycle (in number of steps).

  • pct_start (float) – The percentage of the cycle (in number of steps) spent increasing the learning rate.

  • div_factor (float) – Determines the initial value via init_value = peak_value / div_factor

  • final_div_factor (float) – Determines the final value via final_value = init_value / final_div_factor

Returns

A function that maps step counts to values.

Return type

schedule

optax.exponential_decay(init_value, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None)[source]#

Constructs a schedule with either continuous or discrete exponential decay.

This function applies an exponential decay function to a provided initial value. The function returns the decayed value as follows:

decayed_value = init_value * decay_rate ^ (count / transition_steps)

If the argument staircase is True, then count / transition_steps is an integer division and the decayed value follows a staircase function.

Parameters
  • init_value (float) – the initial learning rate.

  • transition_steps (int) – must be positive. See the decay computation above.

  • decay_rate (float) – must not be zero. The decay rate.

  • transition_begin (int) – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).

  • staircase (bool) – if True, decay the values at discrete intervals.

  • end_value (Optional[float, None]) – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.

Returns

A function that maps step counts to values.

Return type

schedule

optax.join_schedules(schedules, boundaries)[source]#

Sequentially apply multiple schedules.

Parameters
  • schedules (Sequence[Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A list of callables (expected to be optax schedules). Each schedule will receive a step count indicating the number of steps since the previous boundary transition.

  • boundaries (Sequence[int]) – A list of integers (of length one less than schedules) that indicate when to transition between schedules.

Returns

A function that maps step counts to values.

Return type

schedule

optax.linear_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, pct_final=0.85, div_factor=25.0, final_div_factor=10000.0)[source]#

Returns a function which implements the onecycle learning rate schedule.

This function uses a linear annealing strategy. For more details see: https://arxiv.org/abs/1708.07120

Parameters
  • transition_steps (int) – Number of steps over which annealing takes place.

  • peak_value (float) – Maximum value attained by schedule at pct_start percent of the cycle (in number of steps).

  • pct_start (float) – The percentage of the cycle (in number of steps) spent increasing the learning rate.

  • pct_final (float) – The percentage of the cycle (in number of steps) spent increasing to peak_value then decreasing back to init_value.

  • div_factor (float) – Determines the initial value via init_value = peak_value / div_factor

  • final_div_factor (float) – Determines the final value via final_value = init_value / final_div_factor

Returns

A function that maps step counts to values.

Return type

schedule

optax.linear_schedule(init_value, end_value, transition_steps, transition_begin=0)[source]#
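
Constructs a schedule with linear transition from init_value to end_value; it behaves like polynomial_schedule with power=1 (see polynomial_schedule for the meaning of transition_steps and transition_begin).

Return type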

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

optax.piecewise_constant_schedule(init_value, boundaries_and_scales=None)[source]#

Returns a function which implements a piecewise constant schedule.

Parameters
  • init_value (float) – An initial value init_v.

  • boundaries_and_scales (Optional[Dict[int, float], None]) – A map from boundaries b_i to non-negative scaling factors f_i. For any step count s, the schedule returns init_v scaled by the product of all factors f_i such that b_i <= s.

Returns

A function that maps step counts to values.

Return type

schedule

optax.piecewise_interpolate_schedule(interpolate_type, init_value, boundaries_and_scales=None)[source]#

Returns a function which implements a piecewise interpolated schedule.

Parameters
  • interpolate_type (str) – ‘linear’ or ‘cosine’, specifying the interpolation strategy.

  • init_value (float) – An initial value init_v.

  • boundaries_and_scales (Optional[Dict[int, float], None]) – A map from boundaries b_i to non-negative scaling factors f_i. At boundary step b_i, the schedule returns init_v scaled by the product of all factors f_j such that b_j <= b_i. The values between boundaries are interpolated according to interpolate_type.

Returns

A function that maps step counts to values.

Return type

schedule

optax.polynomial_schedule(init_value, end_value, power, transition_steps, transition_begin=0)[source]#

Constructs a schedule with polynomial transition from init to end value.

Parameters
  • init_value (Union[float, int]) – initial value for the scalar to be annealed.

  • end_value (Union[float, int]) – end value of the scalar to be annealed.

  • power (Union[float, int]) – the power of the polynomial used to transition from init to end.

  • transition_steps (int) – number of steps over which annealing takes place, the scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.

  • transition_begin (int) – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).

Returns

A function that maps step counts to values.

Return type

schedule

optax.sgdr_schedule(cosine_kwargs)[source]#

SGD with warm restarts, from Loshchilov & Hutter (arXiv:1608.03983).

This learning rate schedule applies multiple joined cosine decay cycles. For more details see: https://arxiv.org/abs/1608.03983

Parameters

cosine_kwargs (Iterable[Dict[str, Union[Array, ndarray, bool_, number, float, int]]]) – An Iterable of dicts, where each element specifies the arguments to pass to each cosine decay cycle. The decay_steps kwarg will specify how long each cycle lasts for, and therefore when to transition to the next cycle.

Returns

A function that maps step counts to values.

Return type

schedule

optax.warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value=0.0, exponent=1.0)[source]#

Linear warmup followed by cosine decay.

Parameters
  • init_value (float) – Initial value for the scalar to be annealed.

  • peak_value (float) – Peak value of the scalar, reached at the end of warmup.

  • warmup_steps (int) – Positive integer, the length of the linear warmup.

  • decay_steps (int) – Positive integer, the total length of the schedule. Note that this includes the warmup time, so the number of steps during which cosine annealing is applied is decay_steps - warmup_steps.

  • end_value (float) – End value of the scalar to be annealed.

  • exponent (float) – Float. The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the number of steps since the end of warmup and T is decay_steps - warmup_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.

Returns

A function that maps step counts to values.

Return type

schedule

optax.warmup_exponential_decay_schedule(init_value, peak_value, warmup_steps, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None)[source]#

Linear warmup followed by exponential decay.

Parameters
  • init_value (float) – Initial value for the scalar to be annealed.

  • peak_value (float) – Peak value of the scalar, reached at the end of warmup.

  • warmup_steps (int) – Positive integer, the length of the linear warmup.

  • transition_steps (int) – must be positive. See exponential_decay for more details.

  • decay_rate (float) – must not be zero. The decay rate.

  • transition_begin (int) – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at peak_value).

  • staircase (bool) – if True, decay the values at discrete intervals.

  • end_value (Optional[float, None]) – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.

Returns

A function that maps step counts to values.

Return type

schedule

optax.inject_hyperparams(inner_factory, static_args=(), hyperparam_dtype=None)[source]#

Wrapper that injects hyperparameters into the inner GradientTransformation.

This wrapper allows you to pass schedules (i.e. a function that returns a numeric value given a step count) instead of constants for hyperparameters. You may only schedule numeric hyperparameters (i.e. boolean flags cannot be scheduled).

For example, to use scale_by_adam with a piecewise linear schedule for beta_1 and constant for beta_2:

scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)(
    b1=optax.piecewise_interpolate_schedule('linear', ...),
    b2=0.99)

You may manually change numeric hyperparameters that were not scheduled through the hyperparams dict in the InjectHyperparamsState:

state = scheduled_adam.init(params)
updates, state = scheduled_adam.update(grads, state)
state.hyperparams['b2'] = 0.95
updates, state = scheduled_adam.update(grads, state)  # uses b2 = 0.95

Manually overriding scheduled hyperparameters will have no effect (e.g. in the code sample above, you cannot manually adjust b1).

Parameters
  • inner_factory (Callable[…, optax.GradientTransformation]) – a function that returns the inner optax.GradientTransformation given the hyperparameters.

  • static_args (Union[str, Iterable[str]]) – a string or iterable of strings specifying which callable parameters are not schedules. inject_hyperparams treats all callables as schedules by default, so if a hyperparameter is a non-schedule callable, you must specify that using this argument.

  • hyperparam_dtype (Optional[dtype, None]) – Optional datatype override. If specified, all float hyperparameters will be cast to this type.

Return type

Callable[…, optax.GradientTransformation]

Returns

A callable that returns an optax.GradientTransformation. This callable accepts the same arguments as inner_factory, except you may provide schedules in place of the constant arguments.

optax.Schedule#

alias of Callable[[Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]], Union[jax.Array, numpy.ndarray, numpy.bool_, numpy.number, float, int]]

class optax.InjectHyperparamsState(count: jnp.ndarray, hyperparams: Dict[str, chex.Numeric], inner_state: base.OptState)[source]#

Maintains inner transform state, hyperparameters, and step count.

count: jnp.ndarray#

Alias for field number 0

hyperparams: Dict[str, chex.Numeric]#

Alias for field number 1

inner_state: base.OptState#

Alias for field number 2

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Second Order Optimization Utilities#

fisher_diag(negative_log_likelihood, params, ...)

Computes the diagonal of the (observed) Fisher information matrix.

hessian_diag(loss, params, inputs, targets)

Computes the diagonal hessian of loss at (inputs, targets).

hvp(loss, v, params, inputs, targets)

Performs an efficient vector-Hessian (of loss) product.

fisher_diag#

optax.fisher_diag(negative_log_likelihood, params, inputs, targets)[source]#

Computes the diagonal of the (observed) Fisher information matrix.

Parameters
  • negative_log_likelihood (Callable[[Any, Array, Array], Array]) – the negative log likelihood function.

  • params (Any) – model parameters.

  • inputs (Array) – inputs at which negative_log_likelihood is evaluated.

  • targets (Array) – targets at which negative_log_likelihood is evaluated.

Return type

Array

Returns

An Array containing the diagonal of the observed Fisher information matrix, i.e. the elementwise square of the flattened gradient of negative_log_likelihood evaluated at (params, inputs, targets).
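The observed Fisher diagonal can be sketched as the elementwise square of the flattened gradient of the negative log-likelihood. We believe this matches what optax.fisher_diag computes. The model below (a linear predictor with a Gaussian NLL, up to a constant) and its parameter values are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def nll(params, inputs, targets):
    # Hypothetical model: Gaussian NLL (up to a constant) of a linear predictor.
    preds = inputs @ params["w"] + params["b"]
    return 0.5 * jnp.sum((preds - targets) ** 2)

params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.5)}
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
targets = jnp.array([0.0, 1.0])

# ravel_pytree flattens dict leaves in sorted key order: "b" first, then "w".
flat_grad, _ = ravel_pytree(jax.grad(nll)(params, inputs, targets))
fisher_diagonal = jnp.square(flat_grad)  # [4.0, 25.0, 49.0]
```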

hessian_diag#

optax.hessian_diag(loss, params, inputs, targets)[source]#

Computes the diagonal hessian of loss at (inputs, targets).

Parameters
  • loss (Callable[[Any, Array, Array], Array]) – the loss function.

  • params (Any) – model parameters.

  • inputs (Array) – inputs at which loss is evaluated.

  • targets (Array) – targets at which loss is evaluated.

Return type

Array

Returns

An Array containing the diagonal of the Hessian of loss evaluated at (params, inputs, targets).

hvp#

optax.hvp(loss, v, params, inputs, targets)[source]#

Performs an efficient vector-Hessian (of loss) product.

Parameters
  • loss (Callable[[Any, Array, Array], Array]) – the loss function.

  • v (Array) – a vector with the same size as ravel(params).

  • params (Any) – model parameters.

  • inputs (Array) – inputs at which loss is evaluated.

  • targets (Array) – targets at which loss is evaluated.

Return type

Array

Returns

An Array corresponding to the product of v and the Hessian of loss evaluated at (params, inputs, targets).
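A Hessian-vector product can be computed without materializing the Hessian by taking a forward-mode derivative (jax.jvp) of the reverse-mode gradient. The Hessian diagonal then follows from one such product per basis vector. This is a sketch of that standard technique on a hypothetical least-squares loss whose Hessian, inputs.T @ inputs, is known in closed form. It is not necessarily optax's exact implementation.

```python
import jax
import jax.numpy as jnp

def loss(params, inputs, targets):
    # Hypothetical least-squares loss; its Hessian in params is inputs.T @ inputs.
    return 0.5 * jnp.sum((inputs @ params - targets) ** 2)

params = jnp.array([1.0, -1.0])
inputs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
targets = jnp.array([0.0, 1.0])

def hvp_sketch(v):
    # Forward-over-reverse: the JVP of the gradient in direction v is H @ v.
    grad_fn = lambda p: jax.grad(loss)(p, inputs, targets)
    return jax.jvp(grad_fn, (params,), (v,))[1]

v = jnp.array([1.0, 0.5])
hv = hvp_sketch(v)  # H = [[10, 14], [14, 20]], so H @ v = [17, 24]

# Hessian diagonal: component i of H @ e_i for each basis vector e_i.
basis = jnp.eye(params.size)
hess_diag = jax.vmap(lambda e: jnp.vdot(e, hvp_sketch(e)))(basis)  # [10, 20]
```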

Control Variates#

control_delta_method(function)

The control delta covariate method.

control_variates_jacobians(function, ...[, ...])

Obtain jacobians using control variates.

moving_avg_baseline(function[, decay, ...])

A moving average baseline.

control_delta_method#

optax.control_delta_method(function)[source]#

The control delta covariate method.

Control variate obtained by performing a second-order Taylor expansion of the cost function f around the mean of the input distribution.

Only implemented for Gaussian random variables.

For details, see: https://icml.cc/2012/papers/687.pdf

Parameters

function (Callable[[chex.Array], float]) – The function for which to compute the control variate. The function takes in one argument (a sample from the distribution) and returns a floating point value.

Return type

ControlVariate

Returns

A tuple of three functions, to compute the control variate, the expected value of the control variate, and to update the control variate state.
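Consider a Gaussian with mean mu and diagonal covariance diag(sigma^2). The expansion h(x) = f(mu) + grad f(mu)^T (x - mu) + 0.5 (x - mu)^T H(mu) (x - mu) then has the closed-form expectation E[h(x)] = f(mu) + 0.5 sum_i H_ii(mu) sigma_i^2. This holds because x - mu has zero mean and covariance diag(sigma^2). The sketch below builds both pieces in plain JAX with jax.grad and jax.hessian. f, the parameter values and the helper names are hypothetical, and the sketch does not reproduce the three-function ControlVariate tuple that optax returns.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical cost function; it is quadratic, so its Taylor expansion is exact.
    return jnp.sum(x ** 2) + 3.0 * jnp.sum(x)


def delta_control_variate(fn, mean):
    value = fn(mean)
    grad = jax.grad(fn)(mean)
    hess = jax.hessian(fn)(mean)

    def h(x):
        # Second-order Taylor expansion of fn around the mean.
        d = x - mean
        return value + grad @ d + 0.5 * d @ hess @ d

    def expected_h(std):
        # E[d] = 0 and E[d d^T] = diag(std**2) for a diagonal Gaussian.
        return value + 0.5 * jnp.sum(jnp.diag(hess) * std ** 2)

    return h, expected_h


mean = jnp.array([1.0, -2.0])
std = jnp.array([0.5, 0.1])
h, expected_h = delta_control_variate(f, mean)
print(expected_h(std))
```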

control_variates_jacobians#

optax.control_variates_jacobians(function, control_variate_from_function, grad_estimator, params, dist_builder, rng, num_samples, control_variate_state=None, estimate_cv_coeffs=False, estimate_cv_coeffs_num_samples=20)[source]#

Obtain jacobians using control variates.

Writing h(x; theta) for the control variate, the gradient is split into three terms, each computed separately:

  • The first term uses stochastic gradient estimation (grad_estimator).
  • The second term uses Monte Carlo estimation and automatic differentiation to compute nabla_{theta} h(x; theta).
  • The third term, the gradient of the expected value of the control variate, uses automatic differentiation. This works because only control variates that compute this expectation in closed form are supported.

This function updates the state of the control variate once, before computing the control variate coefficients.

Parameters
  • function (Callable[[chex.Array], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.

  • control_variate_from_function (Callable[[Callable[[chex.Array], float]], ControlVariate]) – The control variate to use to reduce variance. See control_delta_method and moving_avg_baseline for examples.

  • grad_estimator (Callable[..., jnp.ndarray]) – The gradient estimator to be used to compute the gradients. Note that not all control variates will reduce variance for all estimators. For example, the moving_avg_baseline will make no difference to the measure valued or pathwise estimators.

  • params (optax.Params) – A tuple of jnp arrays. The parameters for which to construct the distribution and for which we want to compute the jacobians.

  • dist_builder (Callable[..., Any]) – a constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.

  • rng (chex.PRNGKey) – a PRNGKey key.

  • num_samples (int) – The number of samples used to compute the gradients.

  • control_variate_state (CvState) – The control variate state. This is used for control variates which keep states (such as the moving average baselines).

  • estimate_cv_coeffs (bool) – Whether to estimate the optimal control variate coefficient via estimate_control_variate_coefficients.

  • estimate_cv_coeffs_num_samples (int) – The number of samples to use to estimate the optimal coefficient. These need to be new samples to ensure that the objective is unbiased.

Returns

  • A tuple of the same size as params. Each element is a jacobian of shape num_samples x param.shape containing the gradient estimate obtained for each sample.

    The mean of this jacobian over the sample axis is the gradient estimate with respect to the parameters, which can be used for learning. The full jacobian can be used to assess estimator variance.

  • The updated CV state.

Return type

A tuple of size two
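The simplest control variate is a constant baseline b. Its expectation is b and its gradient with respect to theta is zero, so only the stochastic-gradient term remains. Subtracting b keeps the score-function estimator unbiased, because the score has zero mean. The sketch below works in plain JAX rather than through optax. It compares per-sample jacobians with and without such a baseline for the mean of a one-dimensional Gaussian. The cost function, the parameter values and the choice b = E[f(x)] are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical cost function.
    return x ** 2


def score_jacobians(key, mean, std, num_samples, baseline=0.0):
    x = mean + std * jax.random.normal(key, (num_samples,))
    # Score of a Gaussian with respect to its mean: (x - mean) / std**2.
    score = (x - mean) / std ** 2
    # One jacobian entry per sample; subtracting a constant baseline keeps the
    # estimator unbiased because E[score] = 0.
    return (f(x) - baseline) * score


key = jax.random.PRNGKey(0)
mean, std = 1.5, 0.5
plain = score_jacobians(key, mean, std, 100_000)
with_cv = score_jacobians(key, mean, std, 100_000, baseline=mean ** 2 + std ** 2)
print(plain.mean(), with_cv.mean())  # both estimate d/dmean E[x^2] = 2 * mean
print(plain.var(), with_cv.var())    # the baseline lowers the variance here
```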

moving_avg_baseline#

optax.moving_avg_baseline(function, decay=0.99, zero_debias=True, use_decay_early_training_heuristic=True)[source]#

A moving average baseline.

It has no effect on the pathwise or measure-valued estimators.

Parameters
  • function (Callable[[chex.Array], float]) – The function for which to compute the control variate. The function takes in one argument (a sample from the distribution) and returns a floating point value.

  • decay (float) – The decay rate for the moving average.

  • zero_debias (bool) – Whether or not to use zero debiasing for the moving average.

  • use_decay_early_training_heuristic (bool) – Whether to use a heuristic that overrides the decay value early in training with min(decay, (1.0 + i) / (10.0 + i)), where i is the iteration count. This stabilises training and was adapted from the TensorFlow codebase.

Return type

ControlVariate

Returns

A tuple of three functions, to compute the control variate, the expected value of the control variate, and to update the control variate state.
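The sketch below shows the kind of update a moving-average baseline performs. It keeps an exponential moving average of the mean function value, uses the early-training decay min(decay, (1.0 + i) / (10.0 + i)), and applies an optional zero-debiasing step. It is written as a standalone helper, not as optax's ControlVariate tuple. The helper name is hypothetical. The debiasing formula is also an assumption and may differ from optax's implementation: it divides by one minus the product of the decays used so far, which is exact for a zero-initialised average.

```python
import jax.numpy as jnp


def moving_avg_step(state, sample_values, decay=0.99, zero_debias=True,
                    use_decay_early_training_heuristic=True):
    # state = (raw moving average, product of decays so far, iteration count i).
    avg, decay_prod, i = state
    if use_decay_early_training_heuristic:
        step_decay = min(decay, (1.0 + i) / (10.0 + i))
    else:
        step_decay = decay
    avg = step_decay * avg + (1.0 - step_decay) * float(jnp.mean(sample_values))
    decay_prod *= step_decay
    # A zero-initialised average is biased towards 0; dividing by the total
    # weight placed on real observations removes that bias (assumed scheme).
    baseline = avg / (1.0 - decay_prod) if zero_debias else avg
    return (avg, decay_prod, i + 1), baseline


state = (0.0, 1.0, 0)
for _ in range(5):
    state, baseline = moving_avg_step(state, jnp.array([4.0, 6.0]))
    print(round(baseline, 6))  # prints 5.0 each step: the constant mean is recovered
```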

Stochastic Gradient Estimators#

measure_valued_jacobians(function, params, ...)

Measure valued gradient estimation.

pathwise_jacobians(function, params, ...)

Pathwise gradient estimation.

score_function_jacobians(function, params, ...)

Score function gradient estimation.

measure_valued_jacobians#

optax.measure_valued_jacobians(function, params, dist_builder, rng, num_samples, coupling=True)[source]#

Measure valued gradient estimation.

Approximates:

nabla_{theta} E_{p(x; theta)} f(x)

With:

(1 / c) (E_{p1(x; theta)} f(x) - E_{p2(x; theta)} f(x)), where p1 and p2 are measures that depend on p, and c is a constant that also depends on p.

Currently only supports computing gradients of expectations of Gaussian RVs.

Parameters
  • function (Callable[[chex.Array], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.

  • params (optax.Params) – A tuple of jnp arrays. The parameters for which to construct the distribution.

  • dist_builder (Callable[..., Any]) – a constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.

  • rng (chex.PRNGKey) – a PRNGKey key.

  • num_samples (int) – The number of samples used to compute the gradients.

  • coupling (bool) – Whether to use coupling for the positive and negative samples. Recommended: True, as this reduces variance.

Return type

Sequence[chex.Array]

Returns

A tuple of the same size as params. Each element is a jacobian of shape num_samples x param.shape containing the gradient estimate obtained for each sample.

The mean of this jacobian over the sample axis is the gradient estimate with respect to the parameters, which can be used for learning. The full jacobian can be used to assess estimator variance.
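Take the mean mu of a univariate Gaussian N(mu, sigma^2). One valid decomposition uses c = sigma * sqrt(2 * pi), with p1 the distribution of mu + sigma * R and p2 the distribution of mu - sigma * R. Here R follows a standard Rayleigh distribution, and coupling reuses the same R for both measures. The sketch below covers only this special case in plain JAX, sampling R as sqrt(2 * E) with E ~ Exponential(1). It illustrates the technique and is not optax's implementation; f and the parameter values are hypothetical.

```python
import math

import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical cost function.
    return x ** 2


def measure_valued_mean_jacobians(key, mean, std, num_samples, coupling=True):
    key_pos, key_neg = jax.random.split(key)
    # Standard Rayleigh samples: R = sqrt(2 * E) with E ~ Exponential(1).
    r_pos = jnp.sqrt(2.0 * jax.random.exponential(key_pos, (num_samples,)))
    r_neg = r_pos if coupling else jnp.sqrt(
        2.0 * jax.random.exponential(key_neg, (num_samples,)))
    c = std * math.sqrt(2.0 * math.pi)
    # One estimate per sample of (1 / c) * (E_p1 f(x) - E_p2 f(x)).
    return (f(mean + std * r_pos) - f(mean - std * r_neg)) / c


key = jax.random.PRNGKey(0)
jac = measure_valued_mean_jacobians(key, mean=1.5, std=0.5, num_samples=100_000)
print(jac.mean())  # close to d/dmean E[x^2] = 2 * mean = 3.0
```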

pathwise_jacobians#

optax.pathwise_jacobians(function, params, dist_builder, rng, num_samples)[source]#

Pathwise gradient estimation.

Approximates:

nabla_{theta} E_{p(x; theta)} f(x)

With:

E_{p(epsilon)} nabla_{theta} f(g(epsilon, theta))

where x = g(epsilon, theta). g depends on the distribution p.

Requires: p to be reparametrizable and the reparametrization to be implemented in tensorflow_probability. Applicable to continuous random variables. f needs to be differentiable.

Parameters
  • function (Callable[[chex.Array], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.

  • params (optax.Params) – A tuple of jnp arrays. The parameters for which to construct the distribution.

  • dist_builder (Callable[..., Any]) – a constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.

  • rng (chex.PRNGKey) – a PRNGKey key.

  • num_samples (int) – The number of samples used to compute the gradients.

Return type

Sequence[chex.Array]

Returns

A tuple of the same size as params. Each element is a jacobian of shape num_samples x param.shape containing the gradient estimate obtained for each sample.

The mean of this jacobian over the sample axis is the gradient estimate with respect to the parameters, which can be used for learning. The full jacobian can be used to assess estimator variance.
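For a Gaussian, the reparametrization is x = g(epsilon, theta) = mu + sigma * epsilon with epsilon ~ N(0, 1). Each sample's gradient can therefore be obtained by differentiating f(g(epsilon, theta)) directly. The sketch below computes per-sample jacobians with jax.vmap over jax.grad for theta = (mu, log_sigma). The log-sigma parametrization, the function f and the values are illustrative assumptions. The sketch uses neither tensorflow_probability nor optax.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical differentiable cost function.
    return x ** 2


def reparam_objective(params, eps):
    mean, log_std = params
    x = mean + jnp.exp(log_std) * eps  # x = g(epsilon, theta)
    return f(x)


def pathwise_jacobians_sketch(key, params, num_samples):
    eps = jax.random.normal(key, (num_samples,))
    # Gradient of f(g(eps, theta)) for each sample; returns a tuple with one
    # array of shape (num_samples,) per parameter.
    return jax.vmap(jax.grad(reparam_objective), in_axes=(None, 0))(params, eps)


params = (jnp.array(1.5), jnp.log(jnp.array(0.5)))
jac_mean, jac_log_std = pathwise_jacobians_sketch(jax.random.PRNGKey(0), params, 100_000)
print(jac_mean.mean(), jac_log_std.mean())  # approx. 2 * mean = 3.0 and 2 * std**2 = 0.5
```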

score_function_jacobians#

optax.score_function_jacobians(function, params, dist_builder, rng, num_samples)[source]#

Score function gradient estimation.

Approximates:

nabla_{theta} E_{p(x; theta)} f(x)

With:

E_{p(x; theta)} f(x) nabla_{theta} log p(x; theta)

Requires: p to be differentiable with respect to theta. Applicable to both continuous and discrete random variables. No requirements on f.

Parameters
  • function (Callable[[chex.Array], float]) – Function f(x) for which to estimate grads_{params} E_dist f(x). The function takes in one argument (a sample from the distribution) and returns a floating point value.

  • params (optax.Params) – A tuple of jnp arrays. The parameters for which to construct the distribution.

  • dist_builder (Callable[..., Any]) – a constructor which builds a distribution given the input parameters specified by params. dist_builder(params) should return a valid distribution.

  • rng (chex.PRNGKey) – a PRNGKey key.

  • num_samples (int) – The number of samples used to compute the gradients.

Return type

Sequence[chex.Array]

Returns

A tuple of the same size as params. Each element is a jacobian of shape num_samples x param.shape containing the gradient estimate obtained for each sample.

The mean of this jacobian over the sample axis is the gradient estimate with respect to the parameters, which can be used for learning. The full jacobian can be used to assess estimator variance.
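Because only log p needs to be differentiated, the estimator also works for discrete variables. The sketch below applies it to a Bernoulli variable with logit theta, for which E f(x) = sigmoid(theta) f(1) + (1 - sigmoid(theta)) f(0) has an exact derivative to compare against. It is written in plain JAX. The cost function, the logit value and the helper names are hypothetical, and it is not optax's implementation.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical cost; it does not need to be differentiable.
    return (x - 0.3) ** 2


def bernoulli_log_prob(logit, x):
    # log p(x; theta) for x in {0, 1} with p(x = 1) = sigmoid(logit).
    return x * jax.nn.log_sigmoid(logit) + (1.0 - x) * jax.nn.log_sigmoid(-logit)


def score_function_jacobians_sketch(key, logit, num_samples):
    x = jax.random.bernoulli(key, jax.nn.sigmoid(logit), (num_samples,)).astype(jnp.float32)
    # nabla_theta log p(x; theta), evaluated separately for every sample.
    score = jax.vmap(jax.grad(bernoulli_log_prob), in_axes=(None, 0))(logit, x)
    return f(x) * score


logit = jnp.array(0.5)
jac = score_function_jacobians_sketch(jax.random.PRNGKey(0), logit, 100_000)
p = jax.nn.sigmoid(logit)
exact = p * (1.0 - p) * (f(1.0) - f(0.0))  # derivative of p*f(1) + (1-p)*f(0)
print(jac.mean(), exact)
```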

Privacy-Sensitive Optax Methods#

DifferentiallyPrivateAggregateState(rng_key)

State containing PRNGKey for differentially_private_aggregate.

differentially_private_aggregate(...)

Aggregates gradients based on the DPSGD algorithm.

differentially_private_aggregate#

optax.differentially_private_aggregate(l2_norm_clip, noise_multiplier, seed)[source]#

Aggregates gradients based on the DPSGD algorithm.

WARNING: Unlike other transforms, differentially_private_aggregate expects the input updates to have a batch dimension in the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap). It can still be composed with other transformations as long as it is the first in the chain.

References

[Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

Parameters
  • l2_norm_clip (float) – maximum L2 norm of the per-example gradients.

  • noise_multiplier (float) – ratio of the standard deviation of the added Gaussian noise to the clipping norm.

  • seed (int) – initial seed used for the jax.random.PRNGKey

Return type

optax.GradientTransformation

Returns

A GradientTransformation.
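The core of the aggregation step can be sketched for a single parameter array in four steps:

  • Clip each per-example gradient to an L2 norm of at most l2_norm_clip.
  • Sum the clipped gradients.
  • Add Gaussian noise with standard deviation noise_multiplier * l2_norm_clip.
  • Divide by the batch size.

This is a simplified sketch of DP-SGD aggregation following Abadi et al., not optax's code. Optax works on pytrees and keeps its PRNGKey in DifferentiallyPrivateAggregateState. The loss, data and helper name are hypothetical. Per-example gradients come from jax.vmap, with the batch dimension on axis 0 as the warning above requires.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Hypothetical per-example squared-error loss for a linear model.
    return 0.5 * (jnp.dot(x, w) - y) ** 2


def dp_aggregate_sketch(per_example_grads, l2_norm_clip, noise_multiplier, key):
    batch_size = per_example_grads.shape[0]
    # Clip every example's gradient to an L2 norm of at most l2_norm_clip.
    norms = jnp.linalg.norm(per_example_grads, axis=1, keepdims=True)
    clipped = per_example_grads / jnp.maximum(1.0, norms / l2_norm_clip)
    # Sum, add Gaussian noise scaled to the clipping norm, then average.
    noise = noise_multiplier * l2_norm_clip * jax.random.normal(
        key, per_example_grads.shape[1:])
    return (clipped.sum(axis=0) + noise) / batch_size


w = jnp.array([0.5, -0.5])
xs = jnp.array([[1.0, 2.0], [10.0, 0.0], [0.0, -3.0]])
ys = jnp.array([1.0, 0.0, 2.0])
# Batch dimension on axis 0: one gradient per example.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)
update = dp_aggregate_sketch(per_example_grads, 1.0, 1.1, jax.random.PRNGKey(0))
print(update)
```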

class optax.DifferentiallyPrivateAggregateState(rng_key: Any)[source]#

State containing PRNGKey for differentially_private_aggregate.

rng_key: Any#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

General Utilities#

multi_normal(loc, log_scale)

Builds a multivariate normal distribution with diagonal covariance from loc and log_scale.

scale_gradient(inputs, scale)

Scales gradients for the backwards pass.

multi_normal#

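optax.multi_normal(loc, log_scale)[source]#

Builds a multivariate normal distribution with diagonal covariance, parameterized by its mean loc and the logarithm of its per-dimension standard deviation log_scale.

Return type

MultiNormalDiagFromLogScale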
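A distribution of this kind is what the dist_builder arguments of the gradient estimators above are expected to return. The sketch below is a hypothetical stand-in for a diagonal Gaussian parameterized by loc and log_scale, with standard deviation exp(log_scale), and provides sample and log_prob methods. The class name and method signatures are assumptions. They are not the actual interface of MultiNormalDiagFromLogScale.

```python
import math

import jax
import jax.numpy as jnp


class DiagNormalFromLogScale:
    """Hypothetical diagonal Gaussian parameterized by loc and log_scale."""

    def __init__(self, loc, log_scale):
        self.loc = loc
        self.log_scale = log_scale

    def sample(self, shape, key):
        # Reparametrized sampling: loc + exp(log_scale) * standard normal noise.
        eps = jax.random.normal(key, tuple(shape) + self.loc.shape)
        return self.loc + jnp.exp(self.log_scale) * eps

    def log_prob(self, x):
        # Sum of independent univariate Gaussian log densities over the last axis.
        z = (x - self.loc) * jnp.exp(-self.log_scale)
        return jnp.sum(-0.5 * z ** 2 - self.log_scale - 0.5 * math.log(2.0 * math.pi), axis=-1)


dist = DiagNormalFromLogScale(jnp.array([0.0, 1.0]), jnp.log(jnp.array([1.0, 2.0])))
samples = dist.sample((4,), jax.random.PRNGKey(0))
print(samples.shape, dist.log_prob(samples).shape)  # (4, 2) (4,)
```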

scale_gradient#

optax.scale_gradient(inputs, scale)[source]#

Scales gradients for the backwards pass.

Parameters
  • inputs (chex.ArrayTree) – A nested array.

  • scale (float) – The scale factor for the gradient on the backwards pass.

Return type

chex.ArrayTree

Returns

A nested array with the same structure as inputs. Its forward value is unchanged, while gradients flowing back through it are multiplied by scale.
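The behaviour described above, identity on the forward pass and scaled gradients on the backward pass, can be sketched with jax.custom_vjp. The function name scale_gradient_sketch and the encoder/head example are hypothetical. This is one way to get the effect, not necessarily how optax implements it.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def scale_gradient_sketch(inputs, scale):
    # Forward pass: identity.
    return inputs


def _scale_gradient_fwd(inputs, scale):
    return inputs, None  # no residuals are needed for the backward pass


def _scale_gradient_bwd(scale, _residuals, g):
    # Backward pass: multiply every incoming cotangent leaf by scale.
    return (jax.tree_util.tree_map(lambda t: t * scale, g),)


scale_gradient_sketch.defvjp(_scale_gradient_fwd, _scale_gradient_bwd)

params = {"encoder": jnp.array([1.0, -2.0]), "head": jnp.array([3.0])}


def loss(p):
    # Gradients into the encoder are damped by 10x; the forward value is unchanged.
    enc = scale_gradient_sketch(p["encoder"], 0.1)
    return jnp.sum(enc ** 2) + jnp.sum(p["head"] ** 2)


grads = jax.grad(loss)(params)
print(grads)  # encoder: 0.1 * 2 * [1, -2]; head: 2 * [3]
```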

🔧 Contrib#

mechanize(base_optimizer[, weight_decay, ...])

Mechanic - a black box learning rate tuner/optimizer.

MechanicState(base_optimizer_state, count, ...)

State of the GradientTransformation returned by mechanize.

🚧 Experimental#

split_real_and_imaginary(inner)

Splits the real and imaginary components of complex updates into two.

SplitRealAndImaginaryState(inner_state)

Maintains the inner transformation state for split_real_and_imaginary.

Complex-Valued Optimization#

optax.experimental.split_real_and_imaginary(inner)[source]#

Splits the real and imaginary components of complex updates into two.

The inner transformation processes real parameters and updates, and the pairs of transformed real updates are merged into complex updates.

Parameters and updates that are already real are not split; they are passed to the inner transformation as they are.

Parameters

inner (optax.GradientTransformation) – The inner transformation.

Return type

optax.GradientTransformation

Returns

An optax.GradientTransformation.
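The splitting idea can be illustrated with a plain function in place of a GradientTransformation. A real-valued rule is applied to the real and imaginary parts separately, and the results are merged back into complex numbers. Real inputs are not split but still go through the rule. The helper name and the use of jnp.sign as the inner rule are hypothetical simplifications. The sketch manages no optimizer state.

```python
import jax.numpy as jnp


def split_real_and_imaginary_sketch(inner_fn, updates):
    if not jnp.iscomplexobj(updates):
        # Real-valued updates are not split, but still go through the inner rule.
        return inner_fn(updates)
    # Apply the real-valued rule to each component, then merge back.
    real = inner_fn(jnp.real(updates))
    imag = inner_fn(jnp.imag(updates))
    return real + 1j * imag


# Hypothetical inner rule: sign-based updates, defined here on real inputs.
updates = jnp.array([3.0 - 4.0j, -0.5 + 2.0j])
print(split_real_and_imaginary_sketch(jnp.sign, updates))  # elements: 1-1j and -1+1j
```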

class optax.experimental.SplitRealAndImaginaryState(inner_state: base.OptState)[source]#

Maintains the inner transformation state for split_real_and_imaginary.

inner_state: base.OptState#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.