🔧 Contrib#
Algorithms or wrappers that do not (yet) meet the Inclusion Criteria or are not supported by the main library.
|
The ACProp optimizer. |
|
AdEMAMix. |
|
ADOPT (Adaptive Optimization with Provable Theoretical guarantees). |
|
Simplified AdEMAMix. |
|
Rescale updates according to the COntinuous COin Betting algorithm. |
|
State for COntinuous COin Betting. |
|
Learning rate free AdamW by D-Adaptation. |
|
State of the GradientTransformation returned by dadapt_adamw. |
|
Aggregates gradients based on the DPSGD algorithm. |
|
State containing PRNGKey for differentially_private_aggregate. |
|
Distance over Gradients (DoG) optimizer. |
|
State for DoG optimizer. |
|
Distance over weighted Gradients optimizer. |
|
State for DoWG optimizer. |
|
The DPSGD optimizer. |
|
GaLore: Memory-efficient training via gradient lowrank projection. |
|
The MADGRAD optimizer. |
|
State for the MADGRAD optimizer. |
|
Mechanic - a black box learning rate tuner/optimizer. |
|
State of the GradientTransformation returned by mechanize. |
|
Adaptive Learning Rates for SGD with momentum. |
|
State of the GradientTransformation returned by momo. |
|
Adaptive Learning Rates for Adam(W). |
|
State of the |
|
Muon: Momentum Orthogonalized by Newton-schulz. |
|
State for the Muon algorithm. |
|
Learning rate free AdamW with Prodigy. |
|
State of the GradientTransformation returned by prodigy. |
|
Implementation of SAM (Sharpness Aware Minimization). |
|
State of GradientTransformation returned by sam. |
|
Turn base_optimizer schedule_free. |
|
Schedule-Free wrapper for AdamW. |
|
Params for evaluation of |
|
Schedule-Free wrapper for SGD. |
|
State for schedule_free. |
|
Sophia optimizer. |
|
State for Sophia Optimizer. |
|
Splits the real and imaginary components of complex updates into two. |
|
Maintains the inner transformation state for split_real_and_imaginary. |
AdEMAMix#
- optax.contrib.ademamix(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, b3: base.ScalarOrSchedule = 0.9999, alpha: base.ScalarOrSchedule = 5.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: Any | None = None, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) base.GradientTransformation[source]#
AdEMAMix.
AdEMAMix (Adaptive EMA Mixture) is AdamW with a mixture of two momentum terms to better take advantage of historical gradients.
Both SGD with momentum (SGD+M) and Adam incorporate momentum using Exponential Moving Averages (EMAs) of past gradients.
Let \(\eta\) represent the learning rate and \(\beta_1, \beta_2, \beta_3, \alpha, \varepsilon, \bar{\varepsilon}\) represent the arguments b1, b2, b3, alpha, eps and eps_root respectively. Let \(\lambda\) be the weight decay and \(\theta^{(t)}\) the parameter vector at time \(t\).
The init function of this optimizer initializes an internal state \(S^{(0)} := (m_1^{(0)}, m_2^{(0)}, \nu^{(0)}) = (0, 0, 0)\), representing initial estimates for the fast and slow EMAs of the first moment along with the second moment estimate. In practice, these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g^{(t)}\), the optimizer state \(S^{(t-1)}\) and the parameters \(\theta^{(t-1)}\). It then computes the new parameters \(\theta^{(t)}\) and the new state \(S^{(t)}\). Thus, for \(t > 0\), we have,
\[\begin{align*}
m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + (1-\beta_1) \cdot g^{(t)} \\
m_2^{(t)} &\leftarrow \beta_3 \cdot m_2^{(t-1)} + (1-\beta_3) \cdot g^{(t)} \\
\nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot {g^{(t)}}^2 \\
\hat{m}_1^{(t)} &\leftarrow m_1^{(t)} / {(1-\beta_1^t)} \\
\hat{\nu}^{(t)} &\leftarrow \nu^{(t)} / {(1-\beta_2^t)} \\
\theta^{(t)} &\leftarrow \theta^{(t-1)} - \eta \cdot \left( \frac{\hat{m}_1^{(t)} + \alpha m_2^{(t)}}{\sqrt{\hat{\nu}^{(t)} + \bar{\varepsilon}} + \varepsilon} + \lambda \theta^{(t-1)} \right) \\
S^{(t)} &\leftarrow (m_1^{(t)}, m_2^{(t)}, \nu^{(t)}).
\end{align*}\]
Note
AdEMAMix leverages very old gradients. Therefore, the method is best suited to settings where the number of iterations is large. The paper reports on this effect in Appendix C.1.5, showing how smaller values of b3 (e.g. b3 = 0.999) can be better for low-iteration scenarios. Moreover, retaining gradient information over many thousands of steps can pose a problem in domains requiring fast adaptation to a sudden distribution shift, or in general cases where the distribution is non-stationary.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(jnp.square(x))  # simple quadratic function
>>> solver = optax.contrib.ademamix(learning_rate=0.01)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     grad = jax.grad(f)(params)
...     updates, opt_state = solver.update(grad, opt_state, params)
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.36E+01
Objective function: 1.35E+01
Objective function: 1.34E+01
References
Pagliardini et al, The AdEMAMix Optimizer: Better, Faster, Older, 2024
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 – Exponential decay rate to track the fast EMA.
b2 – Exponential decay rate to track the second moment of past gradients.
b3 – Exponential decay rate to track the slow EMA.
alpha – Mixing coefficient in the linear combination of the fast and slow EMAs.
eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.
- Returns:
The corresponding GradientTransformation.
See also
See the related functions optax.adam(), optax.nadamw(), as well as the example Recreate AdeMAMix Rosenbrock Plot from Paper for a use case.
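The update equations above can be traced by hand with a scalar, pure-Python sketch. This is not the optax implementation: the helper name `ademamix_step` and its tuple state are assumptions made for illustration, and `b3` and `alpha` are treated as constants rather than schedules.

```python
import math


def ademamix_step(theta, grad, state, t, lr=0.01, b1=0.9, b2=0.999,
                  b3=0.9999, alpha=5.0, eps=1e-8, eps_root=0.0,
                  weight_decay=0.0):
    """One scalar AdEMAMix step (hypothetical helper); t is the 1-based step."""
    m1, m2, nu = state
    m1 = b1 * m1 + (1.0 - b1) * grad        # fast EMA of the gradient
    m2 = b3 * m2 + (1.0 - b3) * grad        # slow EMA, used without bias correction
    nu = b2 * nu + (1.0 - b2) * grad ** 2   # second-moment estimate
    m1_hat = m1 / (1.0 - b1 ** t)           # bias-corrected fast EMA
    nu_hat = nu / (1.0 - b2 ** t)           # bias-corrected second moment
    direction = (m1_hat + alpha * m2) / (math.sqrt(nu_hat + eps_root) + eps)
    # Weight decay is scaled by the learning rate, as the weight_decay note says.
    theta = theta - lr * (direction + weight_decay * theta)
    return theta, (m1, m2, nu)


# Minimize f(x) = x**2 starting from x = 1.0; the gradient is 2 * x.
theta, state = 1.0, (0.0, 0.0, 0.0)
history = [theta]
for t in range(1, 6):
    theta, state = ademamix_step(theta, 2.0 * theta, state, t)
    history.append(theta)
print(history)
```

Each step moves the parameter by roughly `lr`, since the bias-corrected ratio is close to 1 early on and the slow EMA term is still small.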
- optax.contrib.scale_by_ademamix(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, b3: base.ScalarOrSchedule = 0.9999, alpha: base.ScalarOrSchedule = 6.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: jax.typing.DTypeLike | None = None) base.GradientTransformation[source]#
Scale updates according to the AdEMAMix algorithm.
See optax.contrib.ademamix() for a full description of the algorithm.
References
Pagliardini et al, The AdEMAMix Optimizer: Better, Faster, Older, 2024
- Parameters:
b1 – Exponential decay rate to track the fast EMA.
b2 – Exponential decay rate to track the second moment of past gradients.
b3 – Exponential decay rate to track the slow EMA.
alpha – Mixing coefficient in the linear combination of the fast and slow EMAs.
eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
- Returns:
The corresponding GradientTransformation.
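To get a feel for what the decay rate b3 means in steps, the sketch below computes the half-life of an EMA, i.e. the number of steps after which a past gradient's weight has dropped by half. In an EMA with decay \(\beta\), a gradient that is \(k\) steps old carries weight \((1-\beta)\beta^k\), so the half-life is \(\ln 0.5 / \ln \beta\). The helper name `ema_half_life` is an assumption for illustration, not part of optax.

```python
import math


def ema_half_life(decay):
    """Steps after which a gradient's weight in the EMA has halved."""
    return math.log(0.5) / math.log(decay)


for b3 in (0.999, 0.9999):
    print(f"b3={b3}: half-life of about {ema_half_life(b3):.0f} steps")
```

The default b3 = 0.9999 keeps gradients relevant for several thousand steps, roughly ten times longer than b3 = 0.999.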
- class optax.contrib.ScaleByAdemamixState(count: jax.typing.ArrayLike, count_m2: jax.typing.ArrayLike, m1: optax.Updates, m2: optax.Updates, nu: optax.Updates)[source]#
State for the AdEMAMix algorithm.
- count#
iteration of the algorithm used to update the fast EMA and second moment.
- Type:
jax.typing.ArrayLike
- count_m2#
iteration of the algorithm used to update the slow EMA and alpha.
- Type:
jax.typing.ArrayLike
- m1#
fast EMA of the first moment
- Type:
base.Updates
- m2#
slow EMA of the first moment
- Type:
base.Updates
- nu#
estimate of the second moment
- Type:
base.Updates
Simplified AdEMAMix#
- optax.contrib.simplified_ademamix(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.99, b2: jax.typing.ArrayLike = 0.95, alpha: base.ScalarOrSchedule = 0.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) base.GradientTransformation[source]#
Simplified AdEMAMix.
Simplified AdEMAMix (Adaptive EMA Mixture) is a simplified version of AdEMAMix that eliminates the need for maintaining two separate momentum buffers and removes the requirement for scheduling the mixing parameter \(\alpha\). Setting \(\alpha = 0\) recovers the standard Adam optimizer, subject to appropriate transformations of \(\eta\) and \(\beta_1\).
Let \(\eta\) represent the learning rate and \(\beta_1, \beta_2, \alpha, \varepsilon, \bar{\varepsilon}\) represent the arguments b1, b2, alpha, eps and eps_root respectively. Let \(\lambda\) be the weight decay and \(\theta^{(t)}\) the parameter vector at time \(t\).
The init function of this optimizer initializes an internal state \(S^{(0)} := (m_1^{(0)}, \nu^{(0)}) = (0, 0)\), representing initial estimates for the EMA of the first moment along with the second moment estimate. In practice, these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g^{(t)}\), the optimizer state \(S^{(t-1)}\) and the parameters \(\theta^{(t-1)}\). It then computes the new parameters \(\theta^{(t)}\) and the new state \(S^{(t)}\). Thus, for \(t > 0\), we have,
\[\begin{align*}
m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + g^{(t)} \\
\nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot {g^{(t)}}^2 \\
\hat{\nu}^{(t)} &\leftarrow \nu^{(t)} / {(1-\beta_2^t)} \\
\theta^{(t)} &\leftarrow \theta^{(t-1)} - \eta \cdot \left( \frac{m_1^{(t)} + \alpha g^{(t)}}{\sqrt{\hat{\nu}^{(t)} + \bar{\varepsilon}} + \varepsilon} + \lambda \theta^{(t-1)} \right) \\
S^{(t)} &\leftarrow (m_1^{(t)}, \nu^{(t)}).
\end{align*}\]
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(jnp.square(x))  # simple quadratic function
>>> solver = optax.contrib.simplified_ademamix(learning_rate=0.01)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     grad = jax.grad(f)(params)
...     updates, opt_state = solver.update(grad, opt_state, params)
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.36E+01
Objective function: 1.33E+01
Objective function: 1.28E+01
Objective function: 1.23E+01
References
Morwani et al, Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants, 2025
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 – Exponential decay rate to track the EMA.
b2 – Exponential decay rate to track the second moment of past gradients.
alpha – Mixing coefficient for the current gradient and EMA.
eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.
- Returns:
The corresponding GradientTransformation.
See also
See the related functions optax.adam(), optax.nadamw(), as well as the example Recreate AdeMAMix Rosenbrock Plot from Paper.
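As a rough illustration of the Simplified AdEMAMix update described above, here is a scalar, pure-Python sketch. The helper name `simplified_ademamix_step` and its tuple state are assumptions for illustration; the sketch follows the documented equations and is not the optax source.

```python
import math


def simplified_ademamix_step(theta, grad, state, t, lr=0.01, b1=0.99,
                             b2=0.95, alpha=0.0, eps=1e-8, eps_root=0.0,
                             weight_decay=0.0):
    """One scalar step (hypothetical helper); t is the 1-based step index."""
    m, nu = state
    m = b1 * m + grad                       # accumulator without a (1 - b1) factor
    nu = b2 * nu + (1.0 - b2) * grad ** 2   # second-moment estimate
    nu_hat = nu / (1.0 - b2 ** t)           # bias-corrected second moment
    direction = (m + alpha * grad) / (math.sqrt(nu_hat + eps_root) + eps)
    theta = theta - lr * (direction + weight_decay * theta)
    return theta, (m, nu)


# First step on f(x) = x**2 from x = 1.0 (gradient 2.0).
theta, state = 1.0, (0.0, 0.0)
theta, state = simplified_ademamix_step(theta, 2.0 * theta, state, t=1)
print(theta, state)
```

Because the accumulator omits the \((1-\beta_1)\) factor, it grows toward \(g / (1-\beta_1)\) under a constant gradient \(g\). This is one way to see why recovering Adam at \(\alpha = 0\) requires rescaling \(\eta\).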
- optax.contrib.scale_by_simplified_ademamix(b1: jax.typing.ArrayLike = 0.99, b2: jax.typing.ArrayLike = 0.95, alpha: base.ScalarOrSchedule = 0.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0) base.GradientTransformation[source]#
Scale updates according to the Simplified AdEMAMix optimizer.
See optax.contrib.simplified_ademamix() for a full description.
References
Morwani et al, Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants, 2025
- Parameters:
b1 – Exponential decay rate to track the EMA.
b2 – Exponential decay rate to track the second moment of past gradients.
alpha – Mixing coefficient for the current gradient and EMA.
eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.
- Returns:
The corresponding GradientTransformation.
- class optax.contrib.ScaleBySimplifiedAdEMAMixState(t: jax.typing.ArrayLike, m: optax.Updates, n: optax.Updates)[source]#
State for the Simplified AdEMAMix optimizer.
- t#
iteration count
- Type:
jax.typing.ArrayLike
- m#
EMA
- Type:
base.Updates
- n#
second moment estimate
- Type:
base.Updates
ADOPT#
- optax.contrib.adopt(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.9999, eps: jax.typing.ArrayLike = 1e-06, mu_dtype: Optional[Any] = None, *, nesterov: bool = False, use_clipping: bool = True, clip_value_fn: Callable[[jnp.ndarray], jnp.ndarray] = <function <lambda>>) base.GradientTransformationExtraArgs[source]#
ADOPT (Adaptive Optimization with Provable Theoretical guarantees).
ADOPT is a modified version of Adam that may improve the robustness of Adam with respect to the choice of beta2. This implementation includes an optional clipping operation to improve stability, especially in early training stages.
The key difference from Adam is that ADOPT modifies the update rule to avoid potential instability issues, particularly when some gradient elements are near zero at initialization. Clipping is enabled by default.
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler.
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – A small constant applied to denominator outside of the square root to avoid dividing by zero when rescaling.
mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
nesterov – Whether to use Nesterov momentum.
use_clipping – Whether to apply clipping to improve stability. Recommended to keep as True, especially for training from scratch.
clip_value_fn – A function that takes a step index and returns a clipping value. Default is \(x^{0.25}\).
- Returns:
The corresponding optax.GradientTransformationExtraArgs.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.contrib.adopt(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     grad = jax.grad(f)(params)
...     updates, opt_state = solver.update(grad, opt_state, params)
...     params = optax.apply_updates(params, updates)
...     print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
References
Taniguchi et al, ADOPT: Modified Adam Can Converge with Any beta2 with the Optimal Rate, NeurIPS 2024
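The documented default for clip_value_fn, \(x^{0.25}\) of the step index, can be tabulated with a short sketch. The names `default_clip_value` and `clip` are hypothetical helpers; which quantity optax clips internally is not stated in this section, so `clip` is shown on a generic scalar only.

```python
def default_clip_value(step):
    """Clipping bound step ** 0.25, mirroring the documented default."""
    return step ** 0.25


def clip(value, bound):
    """Clamp a scalar to the interval [-bound, bound]."""
    return max(-bound, min(bound, value))


for step in (1, 16, 10_000):
    print(step, default_clip_value(step))
```

The bound grows slowly with the step index, so clipping is tightest in the first steps of training and loosens later on.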
- optax.contrib.scale_by_adopt(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.9999, eps: jax.typing.ArrayLike = 1e-06, mu_dtype: str | type[~typing.Any] | ~numpy.dtype | ~jax._src.typing.SupportsDType | None = None, *, nesterov: bool = False, use_clipping: bool = True, clip_value_fn: ~typing.Callable[[~jax.jaxlib._jax.Array], ~jax.jaxlib._jax.Array] = <function <lambda>>) optax.GradientTransformation[source]#
Rescale updates according to the ADOPT algorithm.
- Parameters:
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted average of squared grads.
eps – Term added to the denominator to improve numerical stability.
mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.
nesterov – Whether to use Nesterov momentum.
use_clipping – Whether to use gradient clipping to improve stability. When enabled, the clipping value is determined by the clip_value_fn.
clip_value_fn – A function that takes a step index and returns a clipping value. Default is \(x^{0.25}\).
- Returns:
An optax.GradientTransformation object.
See also
Asynchronous-centering-Prop#
- optax.contrib.acprop(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-16, eps_root: jax.typing.ArrayLike = 1e-16, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) base.GradientTransformation[source]#
The ACProp optimizer.
Follows the implementation from the original repo in PyTorch: juntang-zhuang/ACProp-Optimizer.
ACProp is an adaptive optimizer that combines centering of second momentum and asynchronous update. For the update at step t, the denominator uses information up to step t-1, while the numerator uses the gradient at step t.
Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.
The init function of this optimizer initializes an internal state \(S_0 := (m_0, s_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_{t-1}\) and computes updates \(u_t\) and new state \(S_t\). Thus, for \(t > 0\), we have,
\[\begin{align*}
m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
s_t &\leftarrow \beta_2 \cdot s_{t-1} + (1-\beta_2) \cdot (g_t - m_t)^2 + \bar{\varepsilon} \\
\hat{s}_t &\leftarrow s_t / {(1-\beta_2^t)} \\
u_t &\leftarrow -\alpha_t \cdot g_t / \left(\sqrt{\hat{s}_{t-1}} + \varepsilon \right) \\
S_t &\leftarrow (m_t, s_t).
\end{align*}\]
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.
weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.
- Returns:
The corresponding GradientTransformation.
References
Zhuang et al, Momentum Centering and Asynchronous Update for Adaptive Gradient Methods, 2021
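The asynchronous structure of the update (numerator from step \(t\), denominator from step \(t-1\)) is easier to see in a scalar sketch. The helper `acprop_step` is hypothetical, and the handling of the very first step, where no earlier \(\hat{s}\) exists, is an assumption of this sketch (a denominator of 1); the library may treat that case differently.

```python
import math


def acprop_step(grad, state, t, lr=1e-3, b1=0.9, b2=0.999,
                eps=1e-16, eps_root=1e-16):
    """One scalar ACProp step (hypothetical helper); t is the 1-based step."""
    m, s, s_hat_prev = state
    m = b1 * m + (1.0 - b1) * grad
    s = b2 * s + (1.0 - b2) * (grad - m) ** 2 + eps_root  # centered second moment
    s_hat = s / (1.0 - b2 ** t)                           # bias correction
    # Asynchronous update: the denominator uses s_hat from step t - 1.
    # Assumption: with no earlier estimate on the first step, use 1.0.
    denom = 1.0 if s_hat_prev is None else math.sqrt(s_hat_prev) + eps
    update = -lr * grad / denom
    return update, (m, s, s_hat)


state = (0.0, 0.0, None)
u1, state = acprop_step(2.0, state, t=1)
u2, state = acprop_step(2.0, state, t=2)
print(u1, u2)
```

The second update divides by the square root of the bias-corrected estimate produced during the first step, not by a quantity that already includes the current gradient.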
- optax.contrib.scale_by_acprop(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-16, eps_root: jax.typing.ArrayLike = 1e-16) optax.GradientTransformation[source]#
Rescale updates according to ACProp (asynchronous version of AdaBelief).
See optax.contrib.acprop() for more details.
- Parameters:
b1 – Decay rate for the exponentially weighted average of grads.
b2 – Decay rate for the exponentially weighted average of variance of grads.
eps – Term added to the denominator to improve numerical stability.
eps_root – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.
- Returns:
A GradientTransformation object.
Complex-valued Optimization#
- optax.contrib.split_real_and_imaginary(inner: optax.GradientTransformation) optax.GradientTransformation[source]#
Splits the real and imaginary components of complex updates into two.
The inner transformation processes real parameters and updates, and the pairs of transformed real updates are merged into complex updates.
Parameters and updates that are real before splitting are passed through unmodified.
- Parameters:
inner – The inner transformation.
- Returns:
An optax.GradientTransformation.
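The split-and-merge idea can be sketched on plain Python numbers. This is a conceptual illustration only: `split_apply_merge` and `clip_to_unit` are hypothetical names, the inner "transformation" is a stateless function rather than an optax.GradientTransformation, and real entries are handed to it without being split, which is this sketch's reading of the description above.

```python
def split_apply_merge(inner, updates):
    """Apply a real-valued function to complex entries part by part."""
    merged = []
    for u in updates:
        if isinstance(u, complex):
            # Split into real and imaginary parts, transform each, then merge.
            merged.append(complex(inner(u.real), inner(u.imag)))
        else:
            # Real entries are not split.
            merged.append(inner(u))
    return merged


def clip_to_unit(x):
    """Stand-in for a real-valued inner transformation."""
    return max(-1.0, min(1.0, x))


result = split_apply_merge(clip_to_unit, [3.0 + 0.5j, -2.0])
print(result)
```

The inner function only ever sees real numbers, which is what lets a transformation written for real parameters be reused for complex ones.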
- class optax.contrib.SplitRealAndImaginaryState(inner_state: optax.OptState)[source]#
Maintains the inner transformation state for split_real_and_imaginary.
Continuous coin betting#
- optax.contrib.cocob(learning_rate: base.ScalarOrSchedule = 1.0, alpha: jax.typing.ArrayLike = 100.0, eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) base.GradientTransformation[source]#
Rescale updates according to the COntinuous COin Betting algorithm.
An algorithm for stochastic subgradient descent that uses a betting (gambling) strategy to find the minimizer of a non-smooth objective function by accessing its subgradients. See Algorithm 2 of the paper listed under References.
- Parameters:
learning_rate – Optional learning rate, e.g. to inject a scheduler.
alpha – Fraction-to-bet parameter of the COCOB optimizer.
eps – Jitter term to avoid dividing by 0.
weight_decay – L2 penalty.
mask – Mask for weight decay.
- Returns:
A GradientTransformation object.
References
Orabona et al, Training Deep Networks without Learning Rates Through Coin Betting, 2017
- class optax.contrib.COCOBState(init_particles: optax.Updates, cumulative_gradients: optax.Updates, scale: optax.Updates, subgradients: optax.Updates, reward: optax.Updates)[source]#
State for COntinuous COin Betting.
D-adaptation#
- optax.contrib.dadapt_adamw(learning_rate: base.ScalarOrSchedule = 1.0, betas: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike] = (0.9, 0.999), eps: jax.typing.ArrayLike = 1e-08, estim_lr0: jax.typing.ArrayLike = 1e-06, weight_decay: jax.typing.ArrayLike = 0.0) base.GradientTransformationExtraArgs[source]#
Learning rate free AdamW by D-Adaptation.
Adapts the baseline learning rate of AdamW automatically by estimating the initial distance to solution in the infinity norm. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.
- Parameters:
learning_rate – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.
betas – Betas for the underlying AdamW Optimizer.
eps – eps for the underlying AdamW Optimizer.
estim_lr0 – Initial (under-)estimate of the learning rate.
weight_decay – AdamW style weight-decay. To use Regular Adam decay, chain with add_decayed_weights.
- Returns:
The corresponding optax.GradientTransformation.
References
Defazio et al, Learning-Rate-Free Learning by D-Adaptation, 2023
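The shape of the recommended learning-rate multiplier (a short warmup followed by a linear decay from 1.0 to 0) can be sketched in plain Python. The function `warmup_then_linear_decay` is a hypothetical helper that only illustrates the curve; it is not an optax schedule and is not meant to be passed to dadapt_adamw as written.

```python
def warmup_then_linear_decay(step, total_steps, warmup_steps):
    """Multiplier that rises linearly to 1.0, then decays linearly to 0.0."""
    if step < warmup_steps:
        return step / warmup_steps
    remaining = max(total_steps - warmup_steps, 1)
    return max(0.0, 1.0 - (step - warmup_steps) / remaining)


# 1000 training steps with a 10% warmup.
multipliers = [warmup_then_linear_decay(s, 1000, 100)
               for s in (0, 50, 100, 550, 1000)]
print(multipliers)
```

Because the peak of the multiplier is 1.0, the step size found by D-Adaptation is used unscaled at the end of warmup and shrinks toward zero by the end of training.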
- class optax.contrib.DAdaptAdamWState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, grad_sum: optax.Updates, estim_lr: jax.typing.ArrayLike, numerator_weighted: jax.typing.ArrayLike, count: jax.typing.ArrayLike)[source]#
State of the GradientTransformation returned by dadapt_adamw.
Differentially Private Aggregate#
- optax.contrib.differentially_private_aggregate(l2_norm_clip: jax.typing.ArrayLike, noise_multiplier: jax.typing.ArrayLike, key: Array | int | None = None, *, seed: int | None = None) optax.GradientTransformation[source]#
Aggregates gradients based on the DPSGD algorithm.
- Parameters:
l2_norm_clip – maximum L2 norm of the per-example gradients.
noise_multiplier – ratio of standard deviation to the clipping norm.
key – random generator key for noise generation.
seed – deprecated, use key instead.
- Returns:
An optax.GradientTransformation.
References
Abadi et al, Deep Learning with Differential Privacy, 2016
Warning
Unlike other transforms, differentially_private_aggregate expects the input updates to have a batch dimension in the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap). It can still be composed with other transformations as long as it is the first in the chain.
Warning
Generic gradient aggregation tools like optax.MultiSteps or optax.apply_every() won’t work correctly with this transformation since the whole point of this transformation is to aggregate gradients in a specific way.
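The aggregation step can be sketched on plain Python lists: clip each per-example gradient to l2_norm_clip, sum, add Gaussian noise with standard deviation noise_multiplier * l2_norm_clip, and average over the batch. The helper `dp_aggregate` is hypothetical and follows the recipe of Abadi et al. as summarized here; details of the optax implementation may differ.

```python
import math
import random


def dp_aggregate(per_example_grads, l2_norm_clip, noise_multiplier, rng):
    """Clip, sum, add noise and average per-example gradients (sketch)."""
    dim = len(per_example_grads[0])
    total = [0.0] * dim
    for grad in per_example_grads:
        norm = math.sqrt(sum(g * g for g in grad))
        # Scale down only the gradients whose norm exceeds the clip value.
        scale = min(1.0, l2_norm_clip / norm) if norm > 0 else 1.0
        for i, g in enumerate(grad):
            total[i] += scale * g
    stddev = noise_multiplier * l2_norm_clip
    noisy = [t + rng.gauss(0.0, stddev) for t in total]
    return [v / len(per_example_grads) for v in noisy]


# Batch dimension on the 0th axis: two examples, two parameters each.
grads = [[3.0, 4.0], [0.3, 0.4]]
mean_grad = dp_aggregate(grads, l2_norm_clip=1.0, noise_multiplier=0.0,
                         rng=random.Random(0))
print(mean_grad)
```

With noise_multiplier set to 0.0 the result is deterministic, which makes the clipping visible: the first gradient (norm 5) is scaled down to norm 1, the second (norm 0.5) is left alone.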
- class optax.contrib.DifferentiallyPrivateAggregateState(rng_key: Array)[source]#
State containing PRNGKey for differentially_private_aggregate.
- optax.contrib.dpsgd(learning_rate: base.ScalarOrSchedule, l2_norm_clip: jax.typing.ArrayLike, noise_multiplier: jax.typing.ArrayLike, seed: int, momentum: jax.typing.ArrayLike | None = None, nesterov: bool = False) base.GradientTransformation[source]#
The DPSGD optimizer.
Differential privacy is a standard for privacy guarantees of algorithms learning from aggregate databases including potentially sensitive information. DPSGD offers protection against a strong adversary with full knowledge of the training mechanism and access to the model’s parameters.
- Parameters:
learning_rate – A fixed global scaling factor.
l2_norm_clip – Maximum L2 norm of the per-example gradients.
noise_multiplier – Ratio of standard deviation to the clipping norm.
seed – Initial seed used for the jax.random.PRNGKey.
momentum – Decay rate used by the momentum term, when it is set to None, then momentum is not used at all.
nesterov – Whether Nesterov momentum is used.
- Returns:
A GradientTransformation.
References
Abadi et al, Deep Learning with Differential Privacy, 2016
Warning
This optax.GradientTransformation expects input updates to have a batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap).
Warning
Generic gradient aggregation tools like optax.MultiSteps or optax.apply_every() won’t work correctly with this transformation since the whole point of this transformation is to aggregate gradients in a specific way.
Distance over Gradients#
- optax.contrib.dog(learning_rate: base.ScalarOrSchedule = 1.0, init_step: tuple[Literal['distance', 'learning_rate', 'heuristic'], jax.typing.ArrayLike] = ('heuristic', 1e-06), eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike | None = None, mask: Any | Callable[[base.Params], Any] | None = None)[source]#
Distance over Gradients (DoG) optimizer.
DoG updates parameters \(x_t\) with stochastic gradients \(g_t\) according to the update rule:
\[\begin{align*}
r_t &= \| x_t - x_0 \| \\
\bar{r}_t &= \max_{k \leq t} r_k \\
G_t &= \sum_{k \leq t} \|g_k\|^2 \\
\eta_t &= \frac{\bar{r}_t}{\sqrt{G_t + \epsilon}} \\
x_{t+1} &= x_{t} - \eta_t\, g_t
\end{align*}\]
- Parameters:
learning_rate – Optional learning rate (potentially varying according to some predetermined scheduler).
init_step – Initial step specification. Consists of a pair (tag, value), where value is a float and tag is a string, which must be one of distance, learning_rate, or heuristic. distance sets the initial distance \(r_0\) (\(r_\epsilon\) in the paper) to the given value. learning_rate sets the initial learning rate \(\eta_0\) to the given value. heuristic sets \(r_0 = \alpha (1 + \|x_0\|)\), where \(\alpha\) is the given value. The suggested value of \(\alpha\) is 1e-6, unless the model uses batch normalization, in which case the suggested value is 1e-4. As discussed in the paper, the value should be small enough that the initial update step does not cause the model to diverge.
eps – Epsilon used for numerical stability, added to the sum of squared norms of gradients.
weight_decay – Strength of the weight decay regularization.
mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the gradient transformation is applied to all parameters.
- Returns:
The corresponding optax.GradientTransformation.
Examples
>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.dog()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...     value, grad = jax.value_and_grad(f)(params)
...     updates, opt_state = solver.update(
...         grad, opt_state, params, value=value)
...     params = optax.apply_updates(params, updates)
...     print('Objective function: ', f(params))
Objective function:  13.99...
Objective function:  13.99...
Objective function:  13.99...
Objective function:  13.99...
Objective function:  13.99...
References
Ivgi et al., DoG is SGD’s Best Friend: A Parameter-Free Dynamic Step Size Schedule, 2023.
Added in version 0.2.3.
Warning
The authors recommend using model averaging with this optimizer.
This optimizer’s init function should receive the actual parameters (not just dummy parameters) when the heuristic initial step is used.
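The heuristic initialization and the first step size it produces can be illustrated in plain Python. This is a minimal sketch, not the optax implementation: the helper names are hypothetical, the `eps` value is a placeholder, and the step-size rule (maximum distance travelled divided by the square root of the accumulated squared gradient norms) follows our reading of the DoG paper.

```python
import math


def l2_norm(vec):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in vec))


def heuristic_initial_distance(x0, alpha=1e-6):
    """r_0 = alpha * (1 + ||x_0||), the `heuristic` rule described above."""
    return alpha * (1.0 + l2_norm(x0))


def dog_step_size(max_dist, sum_sq_grad_norms, eps=1e-8):
    """Assumed DoG rule: eta_t = max_dist / sqrt(sum of ||g_i||^2 + eps)."""
    return max_dist / math.sqrt(sum_sq_grad_norms + eps)


x0 = [1.0, 2.0, 3.0]
grad0 = [2.0 * v for v in x0]  # gradient of f(x) = sum(x ** 2) at x0
r0 = heuristic_initial_distance(x0)
eta0 = dog_step_size(r0, l2_norm(grad0) ** 2)
print(r0, eta0)
```

With these numbers the first step size is well below 1e-6, which is consistent with the objective barely moving from 14.0 in the example above. It also shows why init needs the real parameters: \(\|x_0\|\) enters \(r_0\) directly.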
- class optax.contrib.DoGState(is_init_step: Array, init_params: optax.ArrayTree, max_dist: Array, sum_sq_norm_grads: Array)[source]#
State for DoG optimizer.
- optax.contrib.dowg(learning_rate: base.ScalarOrSchedule = 1.0, init_estim_sq_dist: jax.typing.ArrayLike | None = None, eps: jax.typing.ArrayLike = 0.0001, weight_decay: jax.typing.ArrayLike | None = None, mask: Any | Callable[[base.Params], Any] | None = None)[source]#
Distance over weighted Gradients optimizer.
Examples
>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.dowg()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   updates, opt_state = solver.update(
...       grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: ', f(params))
Objective function:  13.925367
Objective function:  13.872763
Objective function:  13.775433
Objective function:  13.596172
Objective function:  13.268837
References
Khaled et al., DoWG Unleashed: An Efficient Universal Parameter-Free Gradient Descent Method, 2023.
- Parameters:
learning_rate – optional learning rate (potentially varying according to some predetermined scheduler).
init_estim_sq_dist – initial guess of the squared distance to solution.
eps – small value to prevent division by zero in the denominator defining the learning rate, also used as initial guess for the distance to solution if init_estim_sq_dist is None.
weight_decay – Strength of the weight decay regularization.
mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the gradient transformation is applied to all parameters.
- Returns:
The corresponding
optax.GradientTransformation.
Added in version 0.2.3.
Galore#
- optax.contrib.galore(learning_rate: base.ScalarOrSchedule, rank: int = 128, update_proj_gap: int = 200, scale: float = 1.0, base_optimizer: base.GradientTransformation | None = None, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None, weight_dimension_numbers: GaLoreDimNumsOrFn | None = None) base.GradientTransformation[source]#
GaLore: Memory-efficient training via gradient low-rank projection.
GaLore (Gradient Low-Rank Projection) is a memory-efficient training strategy that enables full-parameter learning while reducing optimizer state memory by projecting gradients into a low-rank subspace.
The key insight is that gradients of weight matrices in neural networks often exhibit low-rank structure. GaLore exploits this by:
Computing a low-rank projection matrix P using SVD of the gradient
Projecting gradients to a low-rank subspace: R = P^T @ G (or G @ P)
Maintaining optimizer states in the reduced subspace
Projecting updates back to full space: update = P @ normalized_R
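The four steps above can be sketched with NumPy. This is an illustration under stated assumptions, not the optax code: the function names are hypothetical, only the left projection (R = P^T @ G) is shown, and the base optimizer that would normalize R between steps 2 and 4 is replaced by the identity.

```python
import numpy as np


def galore_project(grad, rank):
    """Builds P from the top-`rank` left singular vectors and returns (P, P^T @ G)."""
    u, _, _ = np.linalg.svd(grad, full_matrices=False)
    proj = u[:, :rank]             # shape (m, rank), orthonormal columns
    low_rank_grad = proj.T @ grad  # shape (rank, n)
    return proj, low_rank_grad


def galore_project_back(proj, low_rank_update):
    """Maps an update from the low-rank subspace back to full shape (m, n)."""
    return proj @ low_rank_update


rng = np.random.default_rng(0)
# A gradient that is exactly rank 2, so a rank-2 projection loses nothing.
grad = rng.normal(size=(6, 2)) @ rng.normal(size=(2, 4))
proj, low_rank_grad = galore_project(grad, rank=2)
restored = galore_project_back(proj, low_rank_grad)
print(low_rank_grad.shape, np.allclose(restored, grad))
```

In an actual run the optimizer state would have the shape of low_rank_grad rather than grad, which is where the memory saving comes from.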
For a weight matrix of shape (m, n) with rank r projection:
Standard Adam stores m + v states: 2 * m * n parameters
GaLore stores: 2 * min(r*n, m*r) + projection matrix
This can achieve up to 65% memory reduction for large linear layers.
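The two counts above can be compared for a concrete layer. The sketch below is arithmetic only. The helper names are hypothetical, and the projector size (r columns, one row per entry of the longer dimension) is an assumption chosen to match the min(r*n, m*r) formula above.

```python
def adam_state_entries(m, n):
    """First and second moment, each of shape (m, n)."""
    return 2 * m * n


def galore_state_entries(m, n, r):
    """Low-rank moments plus the projection matrix."""
    low_rank_states = 2 * min(r * n, m * r)
    # Assumption: the projector has r columns and max(m, n) rows.
    projector = r * max(m, n)
    return low_rank_states + projector


m, n, r = 4096, 4096, 128
adam = adam_state_entries(m, n)
galore = galore_state_entries(m, n, r)
print(adam, galore, galore / adam)
```

The ratio applies to this single layer only. Whole-model savings are smaller because parameters that are not projected keep full-size optimizer states.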
Note
GaLore only projects 2D weight matrices by default. Use weight_dimension_numbers to project higher-dimensional tensors (like attention projections stored as 3D arrays).
Warning
The base_optimizer must be a gradient scaling transformation that does NOT require parameter values. See scale_by_galore for details on compatible vs incompatible optimizers.
Do NOT use: adamw, lamb, lars as base_optimizer.
Use instead: scale_by_adam, scale_by_lion, etc., and configure weight decay via the weight_decay parameter of this function.
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(jnp.square(x['w']))
>>> solver = optax.contrib.galore(learning_rate=0.01, rank=16)
>>> params = {'w': jnp.ones((100, 100)), 'b': jnp.ones((100,))}
>>> print('Objective function: ', f(params))
Objective function:  10000.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: {:.2E}'.format(f(params)))
Objective function: 9.98E+03
Objective function: 9.96E+03
Objective function: 9.94E+03
Objective function: 9.92E+03
Objective function: 9.90E+03
- Using weight decay (equivalent to AdamW behavior):
>>> solver = optax.contrib.galore(
...     learning_rate=0.01,
...     rank=16,
...     weight_decay=0.01,  # Use this, NOT adamw as base_optimizer
... )
- Using a custom base optimizer:
>>> solver = optax.contrib.galore(
...     learning_rate=0.01,
...     rank=16,
...     base_optimizer=optax.scale_by_adam(b1=0.9, b2=0.99),
... )
- Projecting 3D attention weights as 2D matrices:
>>> from optax.contrib import GaLoreDimensionNumbers
>>> # For attention weights shaped (embed_dim, num_heads, head_dim)
>>> dim_nums = {'attn': GaLoreDimensionNumbers(
...     reduction_axis=0,    # embed_dim
...     output_axis=(1, 2),  # heads*head_dim
... )}
>>> solver = optax.contrib.galore(
...     learning_rate=0.01, rank=16, weight_dimension_numbers=dim_nums
... )
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler.
rank – Target rank for the low-rank projection. Lower values save more memory but may slow convergence. Default 128 is a good starting point.
update_proj_gap – Number of steps between projection matrix updates. The projectors are recomputed from the gradient SVD every this many steps to adapt to the changing gradient landscape.
scale – Additional scaling factor for updates.
base_optimizer – The base gradient transformation to apply in the low-rank subspace. Must be a gradient-only transformation like scale_by_adam, NOT an optimizer requiring params like adamw. If None, defaults to optax.scale_by_adam(). If the base optimizer includes a learning rate, set learning_rate=1.0 here to avoid double-scaling.
weight_decay – Strength of decoupled weight decay regularization (as in AdamW). This is applied correctly in full parameter space, unlike weight decay in the base optimizer which would fail.
mask – A tree with same structure as params PyTree, or a Callable that returns such a pytree. Leaves should be booleans indicating whether to apply weight decay to each parameter.
weight_dimension_numbers – Specifies how to treat non-2D tensors as 2D matrices for projection. See scale_by_galore for details.
- Returns:
A GradientTransformation implementing the GaLore optimizer.
References
Zhao et al., GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection, 2024
Mechanize#
- optax.contrib.mechanize(base_optimizer: base.GradientTransformation | base.GradientTransformationExtraArgs, weight_decay: jax.typing.ArrayLike = 0.01, eps: jax.typing.ArrayLike = 1e-08, s_init: jax.typing.ArrayLike = 1e-06, num_betas: int = 6) base.GradientTransformationExtraArgs[source]#
Mechanic - a black box learning rate tuner/optimizer.
Accumulates updates returned by the base_optimizer and learns the scale of the updates (also known as learning rate or step size) to apply on a per iteration basis.
Note that Mechanic does NOT eschew the need for a learning rate schedule: you are free to apply a learning rate schedule with base learning rate set to 1.0 (or any other constant) and Mechanic will learn the right scale factor automatically.
For example, change this:
learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
optimizer = optax.adam(learning_rate_fn)
To:
learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=1.0)
optimizer = optax.adam(learning_rate_fn)
optimizer = optax.contrib.mechanize(optimizer)
As of June, 2023, Mechanic is tested with SGD, Momentum, Adam and Lion as inner optimizers but we expect it to work with almost any first-order optimizer (except for normalized gradient optimizers like LARS or LAMB).
- Parameters:
base_optimizer – Base optimizer to compute updates from.
weight_decay – A scalar weight decay rate. Note that this weight decay is not the same as the weight decay one would use for the base_optimizer. In addition to sometimes helping converge faster, this helps Mechanic reduce the variance between training runs using different seeds. You likely would not need to tune this, the default should work in most cases.
eps – epsilon for mechanic.
s_init – initial scale factor. Default should work almost all the time.
num_betas – unlike traditional exp accumulators (like 1st or 2nd moment of adam), where one has to choose an explicit beta, mechanic has a clever way to automatically learn the right beta for all accumulators. We only provide the range of possible betas, and not the tuned value. For instance, if you set num_betas to 3, it will use betas = [0.9, 0.99, 0.999].
- Returns:
References
Cutkosky et al., Mechanic: A Learning Rate Tuner, 2023
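The num_betas description gives betas = [0.9, 0.99, 0.999] for num_betas=3. One way to generate such a range is sketched below. The function name is hypothetical, and the rule 1 - 10^(-k) is an assumption extrapolated from that single example, not a statement about the implementation.

```python
def mechanic_betas(num_betas):
    """Assumed pattern: the k-th beta is 1 - 10**(-k), for k = 1..num_betas."""
    return [1.0 - 10.0 ** (-k) for k in range(1, num_betas + 1)]


betas = mechanic_betas(3)
print(betas)
```

Under this assumption the default num_betas=6 would span decay rates from 0.9 up to 0.999999, so no single beta has to be tuned by hand.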
- class optax.contrib.MechanicState(base_optimizer_state: optax.OptState, count: jax.typing.ArrayLike, r: jax.typing.ArrayLike, m: jax.typing.ArrayLike, v: jax.typing.ArrayLike, s: jax.typing.ArrayLike, x0: optax.Updates)[source]#
State of the GradientTransformation returned by mechanize.
Madgrad#
- optax.contrib.madgrad(learning_rate: base.ScalarOrSchedule, momentum: float = 0.9, weight_decay: float = 0.0, eps: float = 1e-06) base.GradientTransformation[source]#
The MADGRAD optimizer.
MADGRAD is a general purpose optimizer that matches the performance of SGD+Momentum on vision tasks and Adam on NLP tasks.
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler.
momentum – Momentum parameter (default: 0.9).
weight_decay – Strength of the weight decay regularization (L2).
eps – Term added to the denominator to improve numerical stability.
- Returns:
The corresponding
optax.GradientTransformation.
References
Defazio et al, Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization, 2021.
- optax.contrib.scale_by_madgrad(learning_rate: base.ScalarOrSchedule, momentum: float = 0.9, eps: float = 1e-06) base.GradientTransformation[source]#
Rescale updates according to the MADGRAD algorithm.
MADGRAD is a Dual Averaging method that maintains a weighted sum of gradients and squared gradients to compute adaptive updates. It effectively bridges the gap between the generalization performance of SGD and the convergence speed of adaptive methods like Adam.
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler.
momentum – Momentum parameter (default: 0.9).
eps – Term added to the denominator to improve numerical stability.
- Returns:
A
optax.GradientTransformationobject.
References
Defazio et al, Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization, 2021.
- class optax.contrib.MadgradState(count: Array, grad_sum_sq: optax.Updates, s: optax.Updates, x0: optax.Params)[source]#
State for the MADGRAD optimizer.
Momo#
- optax.contrib.momo(learning_rate: base.ScalarOrSchedule = 1.0, beta: jax.typing.ArrayLike = 0.9, lower_bound: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, adapt_lower_bound: bool = False) base.GradientTransformationExtraArgs[source]#
Adaptive Learning Rates for SGD with momentum.
MoMo typically needs less tuning of learning_rate, because it exploits the fact that a lower bound of the loss (or the optimal value) is known. For most tasks, zero is a lower bound and an accurate estimate of the final loss.
MoMo performs SGD with momentum with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.
Note that one needs to pass the latest (batch) loss value to the update function using the keyword argument value.
- Parameters:
learning_rate – User-specified learning rate. Recommended to be chosen rather large, by default 1.0.
beta – Momentum coefficient (for EMA).
lower_bound – Lower bound of the loss. Zero should be a good choice for many tasks.
weight_decay – Weight-decay parameter.
adapt_lower_bound – If no good guess for the lower bound is available, set this to true, in order to estimate the lower bound on the fly (see the paper for details).
- Returns:
A optax.GradientTransformation object.
Examples
>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: ', f(params))
Objective function:  ...
Objective function:  ...
Objective function:  ...
Objective function:  ...
Objective function:  ...
References
Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023
Added in version 0.2.3.
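The rule min(learning_rate, <adaptive term>) can be illustrated with a plain Polyak step. This is a deliberately simplified sketch with hypothetical names: it drops the momentum averages that MoMo's actual adaptive term uses and keeps only the capped ratio of the loss gap to the squared gradient norm.

```python
def capped_polyak_step(params, grads, loss, learning_rate=1.0, lower_bound=0.0):
    """One gradient step with step size min(learning_rate, Polyak term)."""
    sq_norm = sum(g * g for g in grads)
    # Polyak term: (loss - lower bound) / ||grad||^2; unbounded if the gradient is zero.
    if sq_norm > 0.0:
        adaptive = max(loss - lower_bound, 0.0) / sq_norm
    else:
        adaptive = float('inf')
    step = min(learning_rate, adaptive)
    new_params = [p - step * g for p, g in zip(params, grads)]
    return new_params, step


params = [1.0, 2.0, 3.0]
loss = sum(p * p for p in params)   # f(x) = sum(x ** 2)
grads = [2.0 * p for p in params]
new_params, step = capped_polyak_step(params, grads, loss)
print(step, new_params, sum(p * p for p in new_params))
```

Here the adaptive term (14 / 56 = 0.25) is smaller than learning_rate=1.0, so it determines the step. This is why a large learning_rate acts as a cap and not as the step size itself.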
- class optax.contrib.MomoState(exp_avg: optax.Updates, barf: jax.typing.ArrayLike, gamma: jax.typing.ArrayLike, lb: jax.typing.ArrayLike, count: jax.typing.ArrayLike)[source]#
State of the GradientTransformation returned by momo.
- optax.contrib.momo_adam(learning_rate: base.ScalarOrSchedule = 0.01, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, lower_bound: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, adapt_lower_bound: bool = False) base.GradientTransformationExtraArgs[source]#
Adaptive Learning Rates for Adam(W).
MoMo-Adam typically needs less tuning of learning_rate, because it exploits the fact that a lower bound of the loss (or the optimal value) is known. For most tasks, zero is a lower bound and an accurate estimate of the final loss.
MoMo-Adam performs Adam(W) with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.
Note that one needs to pass the latest (batch) loss value to the update function using the keyword argument value.
- Parameters:
learning_rate – User-specified learning rate. Recommended to be chosen rather large, by default 0.01.
b1 – Exponential decay rate to track the first moment of past gradients.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – eps for the underlying Adam Optimizer.
lower_bound – Lower bound of the loss. Zero should be a good choice for many tasks.
weight_decay – Weight-decay parameter. Momo-Adam performs weight decay in similar fashion to AdamW.
adapt_lower_bound – If no good guess for the lower bound is available, set this to true, in order to estimate the lower bound on the fly (see the paper for details).
- Returns:
A GradientTransformation object.
Examples
>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo_adam()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
...   print('Objective function: ', f(params))
Objective function:  ...
Objective function:  ...
Objective function:  ...
Objective function:  ...
Objective function:  ...
References
Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023
Added in version 0.2.3.
- class optax.contrib.MomoAdamState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, barf: jax.typing.ArrayLike, gamma: jax.typing.ArrayLike, lb: jax.typing.ArrayLike, count: jax.typing.ArrayLike)[source]#
State of the GradientTransformation returned by momo_adam.
Muon#
- optax.contrib.muon(learning_rate: base.ScalarOrSchedule, ns_coeffs: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike] | tuple[tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike], ...] | str = (3.4445, -4.775, 2.0315), ns_steps: jax.typing.ArrayLike = 5, beta: jax.typing.ArrayLike = 0.95, eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike = 0.0, weight_decay_mask: Any | Callable[[base.Params], Any] | None = None, mu_dtype: jax.typing.DTypeLike | None = None, *, nesterov: bool = True, adaptive: bool = False, preconditioning: Literal['frobenius', 'spectral', 'aol', 'schatten'] = 'frobenius', adam_b1: jax.typing.ArrayLike = 0.9, adam_b2: jax.typing.ArrayLike = 0.999, adam_eps_root: jax.typing.ArrayLike = 0.0, adam_weight_decay: jax.typing.ArrayLike = 0.0, adam_learning_rate: base.ScalarOrSchedule | None = None, muon_weight_dimension_numbers: WeightDimNumOrFn | None = None, consistent_rms: jax.typing.ArrayLike | None = None) base.GradientTransformation[source]#
Muon: Momentum Orthogonalized by Newton-Schulz.
Muon is a variant of Shampoo that uses the Newton-Schulz method to orthogonalize the momentum accumulated by the optimizer. Mathematically, it does steepest descent under the Schatten-p norm, for some large p. With p=infty, it is equivalent to Shampoo without accumulation, or steepest descent under the spectral norm.
Note that Muon is currently only defined for 2D parameters, i.e. matrices. This is because the Newton-Schulz iterator expects a matrix as input. The non-2D parameters are instead passed through an AdamW optimizer (using a weight decay of 0 as default).
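The orthogonalization step can be sketched in NumPy. This is a rough illustration, not the optax code: the function name is hypothetical, the coefficients are the defaults from the signature above, and the iteration form X <- a X + (b A + c A^2) X with A = X X^T is assumed from the modded-nanogpt reference. Only the 'frobenius' preconditioning is shown.

```python
import numpy as np


def newton_schulz_orthogonalize(grad, coeffs=(3.4445, -4.775, 2.0315),
                                steps=5, eps=1e-8):
    """Pushes the singular values of a 2D array towards 1 (approximately)."""
    a, b, c = coeffs
    x = grad / (np.linalg.norm(grad) + eps)  # Frobenius rescaling
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T  # work in the wide orientation so x @ x.T is the smaller Gram matrix
    for _ in range(steps):
        gram = x @ x.T
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


rng = np.random.default_rng(0)
u, _ = np.linalg.qr(rng.normal(size=(4, 3)))  # 4x3, orthonormal columns
v, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # 3x3 orthogonal
grad = u @ np.diag([3.0, 2.0, 1.0]) @ v.T     # singular values 3, 2, 1
ortho = newton_schulz_orthogonalize(grad)
print(np.linalg.svd(ortho, compute_uv=False))
```

With these coefficients the singular values do not converge to exactly 1. They land in a band around 1, which is a trade-off for fast convergence in few steps. The iteration needs a matrix input, which is why non-2D parameters are routed to AdamW.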
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
ns_coeffs – Coefficients for the Newton-Schulz method (can be a string indicator for a preset). Existing presets: muon, dion.
ns_steps – Number of Newton-Schulz iterations. Ignored if ns_coeffs is a tuple of tuples.
beta – Decay rate for the exponentially weighted average of grads.
eps – Term added to the denominator to improve numerical stability.
weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
weight_decay_mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip.
mu_dtype – Data type of the momentum accumulator.
nesterov – Whether to use Nesterov momentum.
adaptive – Whether to scale the updates by the dual norm of the original updates. See <https://arxiv.org/abs/2409.20325>
preconditioning – What type of preconditioning to use before NS iterations. Available options are:
- 'frobenius' (default): Use Frobenius rescaling before NS: safe, standard, but degrades orthogonalization quality when using fewer than 5 NS steps.
- 'spectral': Use spectral norm rescaling before NS: much more computationally intensive, but better orthogonalization quality.
- 'aol': Use AOL rescaling to improve orthogonality with little to no overhead; usually allows the user to remove one iterative NS step. See <https://arxiv.org/abs/2512.04632>.
- 'schatten': Use the Schatten-4 norm for rescaling; allows for better performance with little to no extra cost. See <https://arxiv.org/abs/2506.10935>.
adam_b1 – Exponential decay rate for Adam’s first moment estimates.
adam_b2 – Exponential decay rate for Adam’s second moment estimates.
adam_eps_root – Epsilon to stabilize division in Adam, square root version.
adam_weight_decay – Weight decay factor for Adam.
adam_learning_rate – Auxiliary learning rate for the Adam optimizer. If None, the learning rate for Adam defaults to the same as Muon.
muon_weight_dimension_numbers – An optional tree of MuonDimensionNumbers objects, specifying how to reshape the parameters for orthogonalization; otherwise muon parameters are assumed to be 2D matrices. A None value indicates that the parameter is not a muon parameter and will be optimized with Adam. A callable takes the params as input and returns a possibly masked pytree of specs, similar to weight_decay_mask. If not provided, muon is applied to all 2D parameters.
consistent_rms – An optional float to activate consistent RMS scaling. Scales updates by sqrt(max(fan_in, fan_out)) * consistent_rms to make root mean square (RMS) shape-independent, like AdamW. 0.2 is recommended to match AdamW’s empirical RMS. See <https://arxiv.org/abs/2502.16982>. If None, uses width scaling sqrt(max(1, fan_out / fan_in)).
- Returns:
The corresponding GradientTransformation.
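The two scaling rules named under consistent_rms can be compared numerically. The sketch below only evaluates the formulas as written in the parameter description. The function name is hypothetical, and which array axis counts as fan_in or fan_out is left to the caller.

```python
import math


def muon_update_scale(fan_in, fan_out, consistent_rms=None):
    """Scale factor applied to an orthogonalized update, per the text above."""
    if consistent_rms is None:
        # Width scaling: sqrt(max(1, fan_out / fan_in)).
        return math.sqrt(max(1.0, fan_out / fan_in))
    # Consistent RMS scaling: sqrt(max(fan_in, fan_out)) * consistent_rms.
    return math.sqrt(max(fan_in, fan_out)) * consistent_rms


width_scale = muon_update_scale(fan_in=1024, fan_out=4096)
rms_scale = muon_update_scale(fan_in=1024, fan_out=4096, consistent_rms=0.2)
print(width_scale, rms_scale)
```

The consistent-RMS factor grows with the larger layer dimension, while width scaling depends only on the aspect ratio. A learning rate tuned under one rule is therefore unlikely to transfer unchanged to the other.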
References
Jordan, modded-nanogpt: Speedrunning the NanoGPT baseline, 2024
Bernstein et al., Old Optimizer, New Norm: An Anthology, 2024
Liu et al., Muon is Scalable for LLM Training, <https://arxiv.org/abs/2502.16982>, 2025
Boissin et al., Turbo-Muon: Accelerating Orthogonality-Based Optimization with Pre-Conditioning, <https://arxiv.org/abs/2512.04632>, 2025
Ahn et al., Dion: Distributed Orthonormalized Updates, <https://arxiv.org/abs/2504.05295>, 2025
Grishina et al., Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials, <https://arxiv.org/abs/2506.10935>, 2025
Amsel et al., The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm, <https://arxiv.org/pdf/2505.16932>, 2025
- optax.contrib.scale_by_muon(ns_coeffs: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike] | tuple[tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike], ...] = (3.4445, -4.775, 2.0315), ns_steps: jax.typing.ArrayLike = 5, beta: jax.typing.ArrayLike = 0.95, eps: jax.typing.ArrayLike = 1e-08, mu_dtype: jax.typing.DTypeLike | None = None, *, nesterov: bool = True, adaptive: bool = False, preconditioning: Literal['frobenius', 'spectral', 'aol', 'schatten'] = 'frobenius', weight_dimension_numbers: WeightDimNumOrFn | None = None) base.GradientTransformation[source]#
Rescale updates according to the Muon algorithm.
Muon is a variant of Shampoo that uses the Newton-Schulz method to orthogonalize the momentum accumulated by the optimizer. Mathematically, it does steepest descent under the Schatten-p norm, for some large p. With p=infty, it is equivalent to Shampoo without accumulation, or steepest descent under the spectral norm.
- Parameters:
ns_coeffs – Coefficients for the Newton-Schulz method.
ns_steps – Number of Newton-Schulz iterations. Ignored if ns_coeffs is a tuple of tuples.
beta – Decay rate for the exponentially weighted average of grads.
eps – Term added to denominators to improve numerical stability.
mu_dtype – Data type of the momentum accumulator.
nesterov – Whether to use Nesterov momentum.
adaptive – Whether to scale the updates by the dual norm of the original updates. See <https://arxiv.org/abs/2409.20325>
preconditioning – What type of preconditioning to use before NS iterations. Available options are:
- 'frobenius' (default): Use Frobenius rescaling before NS.
- 'spectral': Use spectral norm rescaling before NS.
- 'aol': Use AOL rescaling to improve orthogonality.
- 'schatten': Use the Schatten-4 norm for rescaling.
weight_dimension_numbers – An optional tree of MuonDimensionNumbers objects with the same structure as the params, specifying how to reshape the parameters before and after the orthogonalization, or a callable returning such a tree. None implies that all parameters are 2D matrices.
- Returns:
A GradientTransformation object.
References
Jordan, modded-nanogpt: Speedrunning the NanoGPT baseline, 2024
Bernstein et al., Old Optimizer, New Norm: An Anthology, 2024
Liu et al., Muon is Scalable for LLM Training, <https://arxiv.org/abs/2502.16982>, 2025
Boissin et al., Turbo-Muon: Accelerating Orthogonality-Based Optimization with Pre-Conditioning, <https://arxiv.org/abs/2512.04632>, 2025
Ahn et al., Dion: Distributed Orthonormalized Updates, <https://arxiv.org/abs/2504.05295>, 2025
Grishina et al., Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials, <https://arxiv.org/abs/2506.10935>, 2025
Amsel et al., The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm, <https://arxiv.org/pdf/2505.16932>, 2025
- class optax.contrib.MuonState(count: jax.typing.ArrayLike, mu: optax.Updates, ns_coeffs: jax.typing.ArrayLike)[source]#
State for the Muon algorithm.
Prodigy#
- optax.contrib.prodigy(learning_rate: base.ScalarOrSchedule = 1.0, betas: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike] = (0.9, 0.999), beta3: jax.typing.ArrayLike | None = None, eps: jax.typing.ArrayLike = 1e-08, estim_lr0: jax.typing.ArrayLike = 1e-06, estim_lr_coef: jax.typing.ArrayLike = 1.0, weight_decay: jax.typing.ArrayLike = 0.0, safeguard_warmup: bool = False) base.GradientTransformationExtraArgs[source]#
Learning rate free AdamW with Prodigy.
Implementation of the Prodigy method from “Prodigy: An Expeditiously Adaptive Parameter-Free Learner”, a version of D-Adapt AdamW that adapts the baseline learning rate faster by using a weighting of the gradients that places higher weights on more recent gradients. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.
- Parameters:
learning_rate – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.
betas – Betas for the underlying AdamW Optimizer.
beta3 – Optional momentum parameter for estimation of D.
eps – eps for the underlying AdamW Optimizer.
estim_lr0 – Initial (under-)estimate of the learning rate.
estim_lr_coef – LR estimates are multiplied by this parameter.
weight_decay – AdamW style weight-decay. To use Regular Adam decay, chain with add_decayed_weights.
safeguard_warmup – Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default.
- Returns:
A optax.GradientTransformation object.
References
Mishchenko et al., Prodigy: An Expeditiously Adaptive Parameter-Free Learner, 2023
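The schedule recommended for learning_rate (a linear decay from 1.0 to 0 combined with a short warmup) has the shape sketched below in plain Python. The function name and the choice to prepend a linear warmup from 0 are assumptions. An optax schedule would express the same shape with jnp operations so it can be traced.

```python
def warmup_linear_decay(step, total_steps, warmup_fraction=0.1):
    """Multiplier in [0, 1]: linear warmup to 1.0, then linear decay to 0."""
    warmup_steps = round(warmup_fraction * total_steps)
    if step < warmup_steps:
        return step / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    return max(0.0, 1.0 - (step - warmup_steps) / decay_steps)


total = 1000
samples = [warmup_linear_decay(s, total) for s in (0, 50, 100, 550, 1000)]
print(samples)
```

The peak value is 1.0 because Prodigy estimates the step size itself. The schedule only modulates that estimate over the course of training.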
- class optax.contrib.ProdigyState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, grad_sum: optax.Updates, params0: optax.Updates, estim_lr: jax.typing.ArrayLike, numerator_weighted: jax.typing.ArrayLike, count: jax.typing.ArrayLike)[source]#
State of the GradientTransformation returned by prodigy.
Schedule-Free#
- optax.contrib.schedule_free(base_optimizer: base.GradientTransformation, learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, weight_lr_power: jax.typing.ArrayLike = 2.0, state_dtype: jax.typing.DTypeLike | None = None) base.GradientTransformationExtraArgs[source]#
Turn base_optimizer schedule_free.
Accumulates updates returned by the base_optimizer w/o Momentum and replaces the momentum of an underlying optimizer with a combination of interpolation and averaging. In the case of gradient descent the update is
\[\begin{align*} y_{t} & = (1-\beta_1)z_{t} + \beta_1 x_{t},\\ z_{t+1} & =z_{t}-\gamma\nabla f(y_{t}),\\ x_{t+1} & =\left(1-\frac{1}{t}\right)x_{t}+\frac{1}{t}z_{t+1}, \end{align*}\]
Here \(x\) is the sequence that evaluations of test/val loss should occur at, which differs from the primary iterates \(z\) and the gradient evaluation locations \(y\). The updates to \(z\) correspond to the underlying optimizer, in this case a simple gradient step. Note that \(\beta_1\) corresponds to b1 in the code and \(\gamma\) to the learning rate.
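The three updates can be traced in a few lines of plain Python. This is a minimal sketch of the gradient-descent case exactly as written in the equations above, with t starting at 1. The function name is hypothetical, and the sketch ignores warmup and weight_lr_power.

```python
def schedule_free_gd(grad_fn, x0, lr=0.1, b1=0.9, num_steps=200):
    """Runs the y / z / x recursion and returns the final (x, z)."""
    z = list(x0)  # primary iterates
    x = list(x0)  # averaged iterates, used for evaluation
    for t in range(1, num_steps + 1):
        # y_t: interpolation between z_t and x_t, where the gradient is taken.
        y = [(1.0 - b1) * zi + b1 * xi for zi, xi in zip(z, x)]
        g = grad_fn(y)
        # z_{t+1}: plain gradient step from z_t using the gradient at y_t.
        z = [zi - lr * gi for zi, gi in zip(z, g)]
        # x_{t+1}: running average with weight 1/t on the new z.
        c = 1.0 / t
        x = [(1.0 - c) * xi + c * zi for xi, zi in zip(x, z)]
    return x, z


def f(v):
    return sum(vi * vi for vi in v)


def grad_f(v):
    return [2.0 * vi for vi in v]


start = [1.0, 2.0, 3.0]
x_final, z_final = schedule_free_gd(grad_f, start)
print(f(start), f(x_final))
```

Only z and x are carried between iterations, and y is recomputed from them each step. The loss is reported at x, not at z or y.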
As the name suggests, Schedule-Free learning does not require a decreasing learning rate schedule, yet typically out-performs, or at worst matches, SOTA schedules such as cosine-decay and linear decay. Only two sequences need to be stored at a time (the third can be computed from the other two on the fly) so this method has the same memory requirements as the base optimizer (parameter buffer + momentum).
In practice, the authors recommend tuning \(\beta_1\), warmup_steps and peak_lr for each problem separately. The default for \(\beta_1\) is 0.9 but 0.95 and 0.98 may also work well. Schedule-Free can be wrapped on top of any optax optimizer. At test time, the parameters should be evaluated using optax.contrib.schedule_free_eval_params() as presented below.
For example, change this:
learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=b1)
To:
learning_rate_fn = optax.warmup_constant_schedule(peak_value=retuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=0.)
optimizer = optax.contrib.schedule_free(optimizer, learning_rate_fn, b1=b1)
...
params_for_eval = optax.contrib.schedule_free_eval_params(state, params)
Note in particular that it is important to switch off the momentum of the base optimizer. As of Apr, 2024, schedule_free is tested with SGD and Adam.
- Parameters:
base_optimizer – Base optimizer to compute updates from.
learning_rate – learning_rate schedule w/o decay but with warmup.
b1 – beta_1 parameter in the y update.
weight_lr_power – we downweight the weight of averaging using this. This is especially helpful in early iterations during warmup.
state_dtype – dtype for z sequence in the schedule free method.
- Returns:
References
Defazio et al., The Road Less Scheduled, 2024
Defazio et al., Schedule-Free Learning - A New Way to Train, 2024
Warning
The current implementation requires the parameter b1 to be strictly positive.
- optax.contrib.schedule_free_adamw(learning_rate: jax.typing.ArrayLike = 0.0025, warmup_steps: int | None = None, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike = 0.0, weight_lr_power: jax.typing.ArrayLike = 2.0, state_dtype: jax.typing.DTypeLike | None = None) base.GradientTransformationExtraArgs[source]#
Schedule-Free wrapper for AdamW.
Shortcut example for using schedule_free with AdamW, which is a common use case. Note that this is just an example, and other use cases are possible, e.g. using a weight decay mask, Nesterov momentum, etc. Note also that the EMA parameter of the schedule free method (b1) must be strictly positive.
- Parameters:
learning_rate – AdamW learning rate.
warmup_steps – positive integer, the length of the linear warmup.
b1 – beta_1 parameter in the y update.
b2 – Exponential decay rate to track the second moment of past gradients.
eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.
weight_decay – Strength of the weight decay regularization.
weight_lr_power – we downweight the weight of averaging using this. This is especially helpful in early iterations during warmup.
state_dtype – dtype for z sequence in the schedule free method.
- Returns:
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.contrib.schedule_free_adamw(1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   eval_params = optax.contrib.schedule_free_eval_params(
...       opt_state, params)
...   print('Objective function: {:.2E}'.format(f(eval_params)))
Objective function: 5.00E+00
Objective function: 3.05E+00
Objective function: 1.73E+00
Objective function: 8.94E-01
Objective function: 4.13E-01
Note
Note that optax.scale_by_adam() with b1=0 stores in its state an unused first moment always equal to zero. To avoid this waste of memory, we replace optax.scale_by_adam() with b1=0 by the equivalent optax.scale_by_rms() with eps_in_sqrt=False, bias_correction=True.
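The effect of weight_lr_power can be illustrated with a small pure-Python sketch. It assumes that the averaging weight at each step is the largest learning rate seen so far raised to weight_lr_power, and that the newest iterate enters the running average with coefficient weight / sum of weights. The function name and the linear warmup formula below are hypothetical and chosen only for illustration; this is not the library's implementation.

```python
def averaging_coefficients(peak_lr, warmup_steps, num_steps, weight_lr_power):
    """Return the coefficient given to the newest iterate at each step."""
    coefficients = []
    max_lr = 0.0
    weight_sum = 0.0
    for step in range(num_steps):
        # Hypothetical linear warmup that reaches peak_lr after warmup_steps.
        lr = peak_lr * min(1.0, (step + 1) / warmup_steps)
        max_lr = max(max_lr, lr)
        # Assumed weighting rule: largest learning rate so far, to a power.
        weight = max_lr ** weight_lr_power
        weight_sum += weight
        coefficients.append(weight / weight_sum)
    return coefficients


# Power 0 gives a plain uniform average: coefficient 1 / (step + 1).
uniform = averaging_coefficients(1.0, 4, 6, weight_lr_power=0.0)
# Power 2 gives later (larger learning rate) iterates more weight,
# so iterates produced during warmup count for less in the average.
weighted = averaging_coefficients(1.0, 4, 6, weight_lr_power=2.0)
```

Under these assumptions, a larger weight_lr_power makes the average forget the warmup iterates faster, while a power of zero treats every iterate equally.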
- optax.contrib.schedule_free_eval_params(state: optax.OptState, params: optax.Params)[source]#
Params for evaluation of optax.contrib.schedule_free().
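A minimal scalar sketch of what the evaluation parameters represent, assuming the schedule-free method trains on the interpolation y = (1 - b1) * z + b1 * x between the base optimizer sequence z and the averaged iterate x, and evaluates at x. The helper names are hypothetical, and the library operates on pytrees rather than scalars. The division by b1 also suggests why b1 must be strictly positive.

```python
def interpolate(x, z, b1):
    """Point at which gradients are taken: y = (1 - b1) * z + b1 * x."""
    return (1.0 - b1) * z + b1 * x


def recover_eval_point(y, z, b1):
    """Invert the interpolation to recover the averaged iterate x."""
    if b1 <= 0.0:
        raise ValueError("b1 must be strictly positive to recover x.")
    return (y - (1.0 - b1) * z) / b1


x, z, b1 = 0.3, -1.2, 0.9
y = interpolate(x, z, b1)
x_recovered = recover_eval_point(y, z, b1)
```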
- optax.contrib.schedule_free_sgd(learning_rate: jax.typing.ArrayLike = 1.0, warmup_steps: int | None = None, b1: jax.typing.ArrayLike = 0.9, weight_decay: jax.typing.ArrayLike | None = None, weight_lr_power: jax.typing.ArrayLike = 2.0, state_dtype: jax.typing.DTypeLike | None = None) base.GradientTransformationExtraArgs[source]#
Schedule-Free wrapper for SGD.
Shortcut example for using schedule_free with SGD, which is a common use case. Note that this is just an example, and other use cases are possible, e.g. using a weight decay mask. Note also that the EMA parameter of the schedule free method (b1) must be strictly positive.
- Parameters:
learning_rate – SGD learning rate.
warmup_steps – positive integer, the length of the linear warmup.
b1 – beta_1 parameter in the y update.
weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.
weight_lr_power – Exponent used to downweight the averaging weights. This is especially helpful in early iterations during warmup.
state_dtype – dtype for z sequence in the schedule free method.
- Returns:
Examples
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.contrib.schedule_free_sgd()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   grad = jax.grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params)
...   params = optax.apply_updates(params, updates)
...   eval_params = optax.contrib.schedule_free_eval_params(
...       opt_state, params)
...   print('Objective function: {:.2E}'.format(f(eval_params)))
Objective function: 1.40E+01
Objective function: 1.75E-14
Objective function: 9.96E-01
Objective function: 8.06E-01
Objective function: 2.41E-01
- class optax.contrib.ScheduleFreeState(b1: jax.typing.ArrayLike, weight_sum: jax.typing.ArrayLike, step_count: jax.typing.ArrayLike, max_lr: jax.typing.ArrayLike, base_optimizer_state: optax.OptState, z: optax.Params)[source]#
State for schedule_free.
Sophia#
- optax.contrib.hutchinson_estimator_diag_hessian(random_seed: Array | None = None)[source]#
Returns a GradientTransformationExtraArgs computing the Hessian diagonal.
The Hessian diagonal is estimated using Hutchinson’s estimator, which is unbiased but has high variance.
- Parameters:
random_seed – key used to generate random vectors.
- Returns:
GradientTransformationExtraArgs
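The idea behind Hutchinson's estimator can be sketched in a few lines of JAX: for a random probe vector v with independent entries of zero mean and unit variance, the expectation of v * (H v) equals the diagonal of the Hessian H. The helper below is a hypothetical standalone illustration for flat array parameters using Rademacher probes; it is not the library's implementation and does not return a GradientTransformationExtraArgs.

```python
import jax
import jax.numpy as jnp


def hutchinson_diag_estimate(obj_fn, params, key, num_samples=1):
    """Average v * (H v) over Rademacher probes v (flat array params only)."""

    def single_probe(probe_key):
        v = jax.random.rademacher(probe_key, params.shape, dtype=params.dtype)
        # Hessian-vector product via forward-over-reverse differentiation.
        _, hvp = jax.jvp(jax.grad(obj_fn), (params,), (v,))
        return v * hvp

    keys = jax.random.split(key, num_samples)
    return jnp.mean(jax.vmap(single_probe)(keys), axis=0)


scales = jnp.array([1.0, 2.0, 3.0])
obj_fn = lambda x: jnp.sum(scales * x ** 2)  # Hessian is diag(2 * scales)
estimate = hutchinson_diag_estimate(
    obj_fn, jnp.array([0.5, -1.0, 2.0]), jax.random.PRNGKey(0), num_samples=4)
```

For this objective the Hessian is already diagonal, so every probe returns the exact diagonal. With off-diagonal curvature, individual probes are noisy, which is the high variance mentioned above.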
- optax.contrib.sophia(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.965, b2: jax.typing.ArrayLike = 0.99, eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike = 0.0001, weight_decay_mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, gamma: jax.typing.ArrayLike = 0.01, clip_threshold: Optional[jax.typing.ArrayLike] = 1.0, update_interval: jax.typing.ArrayLike = 10, hessian_diagonal_fn: Union[base.GradientTransformation, base.GradientTransformationExtraArgs] = (<function hutchinson_estimator_diag_hessian.<locals>.init_fn>, <function hutchinson_estimator_diag_hessian.<locals>.update_fn>), mu_dtype: Optional[Any] = None, verbose: bool = False, print_win_rate_every_n_steps: jax.typing.ArrayLike = 0) base.GradientTransformationExtraArgs[source]#
Sophia optimizer.
A separate GradientTransformation is required through the argument hessian_diagonal_fn to compute the diagonal of the Hessian. Any extra arguments required by the hessian_diagonal_fn’s update function can be passed through sophia’s update function as trailing keyword arguments (**kwargs). The default hessian_diagonal_fn is Hutchinson’s estimator and needs the objective function as an extra argument, obj_fn. obj_fn must accept params as its only argument and return only a scalar (the loss).
For example, assuming your experiment’s loss function is loss_fn(params, batch) -> loss, aux that takes multiple arguments and returns multiple outputs, we must modify it to loss_fn(params) -> loss:
obj_fn = lambda params: loss_fn(params, batch)[0]
where batch is the current step’s batch.
Then it can be passed to sophia’s update function (which will pass it to the hessian_diagonal_fn’s update function):
updates, state = sophia.update(updates, state, params, obj_fn=obj_fn)
Optionally, you can write your own GradientTransformation to compute the hessian diagonal. Use this file’s hutchinson_estimator_diag_hessian function as an example. If you are using more than one device, be sure the hessian diagonal function properly averages the hessian diagonal across devices. The default hessian_diagonal_fn does not do this, and would cause params to diverge from each other across devices if using pmap for example.
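The wrapping step described above can be sketched as follows. The loss function, parameters, and batch are hypothetical stand-ins; only the closure pattern is the point. The sketch checks that the wrapped function takes params alone and returns a scalar, and does not call sophia itself.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, batch):
    """Hypothetical loss returning (loss, aux)."""
    inputs, targets = batch
    preds = inputs @ params
    loss = jnp.mean((preds - targets) ** 2)
    return loss, {"preds": preds}


params = jnp.array([0.5, -0.5])
batch = (jnp.array([[1.0, 2.0], [3.0, 4.0]]), jnp.array([1.0, 2.0]))

# Close over the current batch and drop the auxiliary output.
obj_fn = lambda p: loss_fn(p, batch)[0]

loss_value = obj_fn(params)       # scalar loss
grads = jax.grad(obj_fn)(params)  # differentiable in params alone
```

Because obj_fn closes over batch, it needs to be rebuilt at every step so that it refers to that step's batch.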
- Parameters:
learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().
b1 – Exponential decay rate for the first moment estimates.
b2 – Exponential decay rate for the hessian diagonal estimates. Keep in mind that the effective b2 is 1 - (1 - b2) / update_interval, e.g. the default b2 of 0.99 is effectively 0.999 because the default update_interval is 10.
eps – Small constant to avoid division by zero.
weight_decay – Rate at which to decay weights.
weight_decay_mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.
gamma – Normalizing constant for the hessian diagonal.
clip_threshold – Threshold for clipping updates.
update_interval – Interval for updating the hessian diagonal.
hessian_diagonal_fn – GradientTransformation that computes the diagonal of the Hessian. Default is Hutchinson’s estimator (sophia-h). If using more than one device, be sure this function properly averages the hessian diagonal across devices.
mu_dtype – dtype of the first moment estimates.
verbose – If True, print win rate every n steps.
print_win_rate_every_n_steps – Print sophia win rate every n steps for diagnostic purposes. Authors state this value should stay between 0.1 and 0.5 during training. If win rate is too low, try increasing gamma. 0 to turn off.
- Returns:
optax.GradientTransformationExtraArgs
References
Liu et al., Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training, 2023
Note
We use a Rademacher vector to estimate the diagonal of the Hessian, contrary to the original implementation which uses a normal random vector.
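The effective b2 formula given in the b2 parameter description can be checked with a short calculation. The helper name is hypothetical; the formula is the one stated above.

```python
def effective_b2(b2, update_interval):
    """Effective decay rate when the Hessian diagonal is refreshed
    only once every update_interval steps."""
    return 1.0 - (1.0 - b2) / update_interval


# Defaults from the signature above: b2=0.99, update_interval=10.
default_effective = effective_b2(0.99, 10)
# With update_interval=1 the effective rate is b2 itself.
every_step = effective_b2(0.99, 1)
```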