🔧 Contrib#

Algorithms or wrappers that do not (yet) meet the Inclusion Criteria or are not supported by the main library.

acprop(learning_rate[, b1, b2, eps, ...])

The ACProp optimizer.

ademamix(learning_rate[, b1, b2, b3, alpha, ...])

AdEMAMix.

adopt(learning_rate, b1, b2, eps, mu_dtype, ...)

ADOPT (Adaptive Optimization with Provable Theoretical guarantees).

simplified_ademamix(learning_rate[, b1, b2, ...])

Simplified AdEMAMix.

cocob([learning_rate, alpha, eps, ...])

Rescale updates according to the COntinuous COin Betting algorithm.

COCOBState(init_particles, ...)

State for COntinuous COin Betting.

dadapt_adamw([learning_rate, betas, eps, ...])

Learning rate free AdamW by D-Adaptation.

DAdaptAdamWState(exp_avg, exp_avg_sq, ...)

State of the GradientTransformation returned by dadapt_adamw.

differentially_private_aggregate(...[, key, ...])

Aggregates gradients based on the DPSGD algorithm.

DifferentiallyPrivateAggregateState(rng_key)

State containing PRNGKey for differentially_private_aggregate.

dog([learning_rate, init_step, eps, ...])

Distance over Gradients (DoG) optimizer.

DoGState(is_init_step, init_params, ...)

State for DoG optimizer.

dowg([learning_rate, init_estim_sq_dist, ...])

Distance over weighted Gradients optimizer.

DoWGState(init_params, ...)

State for DoWG optimizer.

dpsgd(learning_rate, l2_norm_clip, ...[, ...])

The DPSGD optimizer.

galore(learning_rate[, rank, ...])

GaLore: Memory-efficient training via gradient lowrank projection.

madgrad(learning_rate[, momentum, ...])

The MADGRAD optimizer.

MadgradState(count, grad_sum_sq, s, x0)

State for the MADGRAD optimizer.

mechanize(base_optimizer[, weight_decay, ...])

Mechanic - a black box learning rate tuner/optimizer.

MechanicState(base_optimizer_state, count, ...)

State of the GradientTransformation returned by mechanize.

momo([learning_rate, beta, lower_bound, ...])

Adaptive Learning Rates for SGD with momentum.

MomoState(exp_avg, barf, gamma, lb, count)

State of the GradientTransformation returned by momo.

momo_adam([learning_rate, b1, b2, eps, ...])

Adaptive Learning Rates for Adam(W).

MomoAdamState(exp_avg, exp_avg_sq, barf, ...)

State of the GradientTransformation returned by momo_adam.

muon(learning_rate[, ns_coeffs, ns_steps, ...])

Muon: Momentum Orthogonalized by Newton-Schulz.

MuonState(count, mu, ns_coeffs)

State for the Muon algorithm.

prodigy([learning_rate, betas, beta3, eps, ...])

Learning rate free AdamW with Prodigy.

ProdigyState(exp_avg, exp_avg_sq, grad_sum, ...)

State of the GradientTransformation returned by prodigy.

sam(optimizer, adv_optimizer[, sync_period, ...])

Implementation of SAM (Sharpness Aware Minimization).

SAMState(steps_since_sync, opt_state, ...)

State of GradientTransformation returned by sam.

schedule_free(base_optimizer, learning_rate)

Turn base_optimizer schedule_free.

schedule_free_adamw([learning_rate, ...])

Schedule-Free wrapper for AdamW.

schedule_free_eval_params(state, params)

Params for evaluation of optax.contrib.schedule_free().

schedule_free_sgd([learning_rate, ...])

Schedule-Free wrapper for SGD.

ScheduleFreeState(b1, weight_sum, ...)

State for schedule_free.

sophia(learning_rate, b1, b2, eps, ...)

Sophia optimizer.

SophiaState(count, mu, nu, hessian_fn_state)

State for Sophia Optimizer.

split_real_and_imaginary(inner)

Splits the real and imaginary components of complex updates into two.

SplitRealAndImaginaryState(inner_state)

Maintains the inner transformation state for split_real_and_imaginary.

AdEMAMix#

optax.contrib.ademamix(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, b3: base.ScalarOrSchedule = 0.9999, alpha: base.ScalarOrSchedule = 5.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: Any | None = None, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) base.GradientTransformation[source]#

AdEMAMix.

AdEMAMix (Adaptive EMA Mixture) is AdamW with a mixture of two momentum terms to better take advantage of historical gradients.

Both SGD with momentum (SGD+M) and Adam incorporate momentum using Exponential Moving Averages (EMAs) of past gradients.

Let \(\eta\) represent the learning rate and \(\beta_1, \beta_2, \beta_3, \alpha, \varepsilon, \bar{\varepsilon}\) represent the arguments b1, b2, b3, alpha, eps and eps_root respectively. Let \(\lambda\) be the weight decay and \(\theta^{(t)}\) the parameter vector at time \(t\).

The init function of this optimizer initializes an internal state \(S^{(0)} := (m_1^{(0)}, m_2^{(0)}, \nu^{(0)}) = (0, 0, 0)\), representing initial estimates for the fast and slow EMAs of the first moment along with the second moment estimate. In practice, these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g^{(t)}\), the optimizer state \(S^{(t-1)}\) and the parameters \(\theta^{(t-1)}\). It then computes the updated parameters \(\theta^{(t)}\) and the new state \(S^{(t)}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + (1-\beta_1) \cdot g^{(t)} \\ m_2^{(t)} &\leftarrow \beta_3 \cdot m_2^{(t-1)} + (1-\beta_3) \cdot g^{(t)} \\ \nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot {g^{(t)}}^2 \\ \hat{m}_1^{(t)} &\leftarrow m_1^{(t)} / {(1-\beta_1^t)} \\ \hat{\nu}^{(t)} &\leftarrow \nu^{(t)} / {(1-\beta_2^t)} \\ \theta^{(t)} &\leftarrow \theta^{(t-1)} - \eta \cdot \left( \frac{\hat{m}_1^{(t)} + \alpha m_2^{(t)}}{\sqrt{\hat{\nu}^{(t)} + \bar{\varepsilon}} + \varepsilon} + \lambda \theta^{(t-1)} \right) \\ S^{(t)} &\leftarrow (m_1^{(t)}, m_2^{(t)}, \nu^{(t)}). \end{align*}\]
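To make the recursion concrete, the following is a minimal scalar sketch of one step of the update rule in plain Python. It is not the optax implementation: the function name `ademamix_step` and the tuple layout of the state are illustrative assumptions, and `alpha` and `b3` are treated as constants rather than schedules.

```python
import math


def ademamix_step(theta, g, state, t, lr=0.01, b1=0.9, b2=0.999, b3=0.9999,
                  alpha=5.0, eps=1e-8, eps_root=0.0, weight_decay=0.0):
    """One AdEMAMix step on a scalar parameter (illustrative sketch only)."""
    m1, m2, nu = state
    m1 = b1 * m1 + (1.0 - b1) * g        # fast EMA of the gradient
    m2 = b3 * m2 + (1.0 - b3) * g        # slow EMA, used without bias correction
    nu = b2 * nu + (1.0 - b2) * g * g    # second moment estimate
    m1_hat = m1 / (1.0 - b1 ** t)        # bias-corrected fast EMA
    nu_hat = nu / (1.0 - b2 ** t)        # bias-corrected second moment
    direction = (m1_hat + alpha * m2) / (math.sqrt(nu_hat + eps_root) + eps)
    theta = theta - lr * (direction + weight_decay * theta)
    return theta, (m1, m2, nu)


theta, state = 1.0, (0.0, 0.0, 0.0)
for t in range(1, 6):                    # t starts at 1, as in the equations
    grad = 2.0 * theta                   # gradient of f(x) = x**2
    theta, state = ademamix_step(theta, grad, state, t)
print(theta)
```

Early in training the slow EMA contributes little, so the first steps behave much like Adam; its influence grows as it accumulates gradient history.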

Note

AdEMAMix relies on leveraging very old gradients. The method is therefore best suited to settings with a large number of iterations. The paper reports on this effect in Appendix C.1.5, showing how smaller values of b3 (e.g. b3 = 0.999) can be better for low-iteration scenarios. Moreover, retaining gradient information over many thousands of steps can pose a problem in domains requiring fast adaptation to a sudden distribution shift, or more generally when the distribution is non-stationary.
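As a rough illustration of how far back each EMA looks, the half-life of an EMA with decay rate β (the number of steps after which a gradient's weight has halved) is ln(0.5)/ln(β). The helper name `ema_half_life` below is hypothetical and only serves this back-of-the-envelope calculation.

```python
import math


def ema_half_life(beta):
    """Steps after which a gradient's weight in an EMA with decay `beta` halves."""
    return math.log(0.5) / math.log(beta)


for b3 in (0.999, 0.9999):
    print(f"b3={b3}: half-life of about {ema_half_life(b3):.0f} steps")
```

The slow EMA at b3 = 0.9999 has a half-life of roughly 6900 steps, against roughly 690 steps for b3 = 0.999, which gives some intuition for why runs of only a few thousand iterations may favor the smaller value.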

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(jnp.square(x))  # simple quadratic function
>>> solver = optax.contrib.ademamix(learning_rate=0.01)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.38E+01
Objective function: 1.36E+01
Objective function: 1.35E+01
Objective function: 1.34E+01

References

Pagliardini et al, The AdEMAMix Optimizer: Better, Faster, Older, 2024

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 – Exponential decay rate to track the fast EMA.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • b3 – Exponential decay rate to track the slow EMA.

  • alpha – Mixing coefficient in the linear combination of the fast and slow EMAs.

  • eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

Returns:

The corresponding GradientTransformation.

See also

See the related functions optax.adam(), optax.nadamw(), as well as the example Recreate AdeMAMix Rosenbrock Plot from Paper for a use case.

optax.contrib.scale_by_ademamix(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, b3: base.ScalarOrSchedule = 0.9999, alpha: base.ScalarOrSchedule = 6.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: jax.typing.DTypeLike | None = None) base.GradientTransformation[source]#

Scale updates according to the Ademamix algorithm.

See optax.contrib.ademamix() for a full description of the algorithm.

References

Pagliardini et al, The AdEMAMix Optimizer: Better, Faster, Older, 2024

Parameters:
  • b1 – Exponential decay rate to track the fast EMA.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • b3 – Exponential decay rate to track the slow EMA.

  • alpha – Mixing coefficient in the linear combination of the fast and slow EMAs.

  • eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

Returns:

The corresponding GradientTransformation.

class optax.contrib.ScaleByAdemamixState(count: jax.typing.ArrayLike, count_m2: jax.typing.ArrayLike, m1: optax.Updates, m2: optax.Updates, nu: optax.Updates)[source]#

State for the Ademamix algorithm.

count#

iteration of the algorithm used to update the fast EMA and second moment.

Type:

jax.typing.ArrayLike

count_m2#

iteration of the algorithm used to update the slow EMA and alpha.

Type:

jax.typing.ArrayLike

m1#

fast EMA of the first moment

Type:

base.Updates

m2#

slow EMA of the first moment

Type:

base.Updates

nu#

estimate of the second moment

Type:

base.Updates

Simplified AdEMAMix#

optax.contrib.simplified_ademamix(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.99, b2: jax.typing.ArrayLike = 0.95, alpha: base.ScalarOrSchedule = 0.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) base.GradientTransformation[source]#

Simplified AdEMAMix.

Simplified AdEMAMix (Adaptive EMA Mixture) is a simplified version of AdEMAMix that eliminates the need for maintaining two separate momentum buffers and removes the requirement for scheduling the mixing parameter \(\alpha\). Setting \(\alpha = 0\) recovers the standard Adam optimizer, subject to appropriate transformations of \(\eta\) and \(\beta_1\).

Let \(\eta\) represent the learning rate and \(\beta_1, \beta_2, \alpha, \varepsilon, \bar{\varepsilon}\) represent the arguments b1, b2, alpha, eps and eps_root respectively. Let \(\lambda\) be the weight decay and \(\theta^{(t)}\) the parameter vector at time \(t\).

The init function of this optimizer initializes an internal state \(S^{(0)} := (m_1^{(0)}, \nu^{(0)}) = (0, 0)\), representing initial estimates for the EMA of the first moment along with the second moment estimate. In practice, these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g^{(t)}\), the optimizer state \(S^{(t-1)}\) and the parameters \(\theta^{(t-1)}\). It then computes the updated parameters \(\theta^{(t)}\) and the new state \(S^{(t)}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_1^{(t)} &\leftarrow \beta_1 \cdot m_1^{(t-1)} + g^{(t)} \\ \nu^{(t)} &\leftarrow \beta_2 \cdot \nu^{(t-1)} + (1-\beta_2) \cdot {g^{(t)}}^2 \\ \hat{\nu}^{(t)} &\leftarrow \nu^{(t)} / {(1-\beta_2^t)} \\ \theta^{(t)} &\leftarrow \theta^{(t-1)} - \eta \cdot \left( \frac{m_1^{(t)} + \alpha g^{(t)}}{\sqrt{\hat{\nu}^{(t)} + \bar{\varepsilon}} + \varepsilon} + \lambda \theta^{(t-1)} \right) \\ S^{(t)} &\leftarrow (m_1^{(t)}, \nu^{(t)}). \end{align*}\]
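The single-buffer update can be sketched on a scalar parameter in plain Python as follows. This is an illustration of the recursion rather than the optax code: the name `simplified_ademamix_step` and the state tuple are assumptions, and `alpha` is taken to be a constant.

```python
import math


def simplified_ademamix_step(theta, g, state, t, lr=0.01, b1=0.99, b2=0.95,
                             alpha=0.0, eps=1e-8, eps_root=0.0,
                             weight_decay=0.0):
    """One Simplified AdEMAMix step on a scalar parameter (sketch only)."""
    m, nu = state
    m = b1 * m + g                       # accumulator without the (1 - b1) factor
    nu = b2 * nu + (1.0 - b2) * g * g    # second moment estimate
    nu_hat = nu / (1.0 - b2 ** t)        # bias-corrected second moment
    direction = (m + alpha * g) / (math.sqrt(nu_hat + eps_root) + eps)
    theta = theta - lr * (direction + weight_decay * theta)
    return theta, (m, nu)


theta, state = 1.0, (0.0, 0.0)
for t in range(1, 6):
    grad = 2.0 * theta                   # gradient of f(x) = x**2
    theta, state = simplified_ademamix_step(theta, grad, state, t)
print(theta)
```

Because the accumulator is not scaled by (1 - b1), its magnitude grows over the first steps, so successive updates get larger for a fixed learning rate. This is one reason the learning rate needs to be rescaled when comparing against Adam.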

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(jnp.square(x))  # simple quadratic function
>>> solver = optax.contrib.simplified_ademamix(learning_rate=0.01)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.39E+01
Objective function: 1.36E+01
Objective function: 1.33E+01
Objective function: 1.28E+01
Objective function: 1.23E+01

References

Morwani et al, Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants, 2025

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 – Exponential decay rate to track the EMA.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • alpha – Mixing coefficient for the current gradient and EMA.

  • eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

  • weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

Returns:

The corresponding GradientTransformation.

See also

See the related functions optax.adam(), optax.nadamw(), as well as the example Recreate AdeMAMix Rosenbrock Plot from Paper.

optax.contrib.scale_by_simplified_ademamix(b1: jax.typing.ArrayLike = 0.99, b2: jax.typing.ArrayLike = 0.95, alpha: base.ScalarOrSchedule = 0.0, eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0) base.GradientTransformation[source]#

Scale updates according to the Simplified AdEMAMix optimizer.

See optax.contrib.simplified_ademamix() for a full description.

References

Morwani et al, Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants, 2025

Parameters:
  • b1 – Exponential decay rate to track the EMA.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • alpha – Mixing coefficient for the current gradient and EMA.

  • eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • eps_root – A small constant applied to denominator inside the square root (as in RMSProp), to avoid dividing by zero when rescaling. This is needed for instance when computing (meta-)gradients through Adam.

Returns:

The corresponding GradientTransformation.

class optax.contrib.ScaleBySimplifiedAdEMAMixState(t: jax.typing.ArrayLike, m: optax.Updates, n: optax.Updates)[source]#

State for the Simplified AdEMAMix optimizer.

t#

iteration count

Type:

jax.typing.ArrayLike

m#

EMA

Type:

base.Updates

n#

second moment estimate

Type:

base.Updates

ADOPT#

optax.contrib.adopt(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.9999, eps: jax.typing.ArrayLike = 1e-06, mu_dtype: Optional[Any] = None, *, nesterov: bool = False, use_clipping: bool = True, clip_value_fn: Callable[[jnp.ndarray], jnp.ndarray] = <function <lambda>>) base.GradientTransformationExtraArgs[source]#

ADOPT (Adaptive Optimization with Provable Theoretical guarantees).

ADOPT is a modified version of Adam that may improve the robustness of Adam with respect to the choice of beta2. This implementation includes an optional clipping operation to improve stability, especially in early training stages.

The key difference from Adam is that ADOPT modifies the update rule to avoid potential instability issues, particularly when some gradient elements are near zero at initialization. Clipping is enabled by default and mainly matters in the early stages of training.
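The effect of the clipping step can be sketched on a scalar gradient stream in plain Python. The recursion below follows a reading of the ADOPT paper and is an assumption about the algorithm, not a transcription of the optax source: the first gradient only initializes the second moment, each later gradient is normalized by the previous second moment, and the normalized value is clipped to the step index raised to the power 0.25 before it enters the momentum. The function name `adopt_updates` is hypothetical.

```python
import math


def adopt_updates(grads, lr=0.003, b1=0.9, b2=0.9999, eps=1e-6,
                  use_clipping=True):
    """Scalar ADOPT-style updates for a list of gradients (sketch only)."""
    nu = grads[0] ** 2                   # step 0 only initializes the second moment
    m = 0.0
    updates = [0.0]                      # no parameter change on step 0
    for t, g in enumerate(grads[1:], start=1):
        normed = g / max(math.sqrt(nu), eps)   # uses nu from the previous step
        if use_clipping:
            c = t ** 0.25                # default clip value at step t
            normed = max(-c, min(c, normed))
        m = b1 * m + (1.0 - b1) * normed
        updates.append(-lr * m)
        nu = b2 * nu + (1.0 - b2) * g * g
    return updates


# A near-zero first gradient followed by a large one.
clipped = adopt_updates([1e-3, 10.0])
unclipped = adopt_updates([1e-3, 10.0], use_clipping=False)
print(clipped[1], unclipped[1])
```

In this toy stream the unclipped update is several orders of magnitude larger than the clipped one, which illustrates the kind of early instability the clipping is meant to prevent.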

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler.

  • b1 – Exponential decay rate to track the first moment of past gradients.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – A small constant applied to denominator outside of the square root to avoid dividing by zero when rescaling.

  • mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • nesterov – Whether to use Nesterov momentum.

  • use_clipping – Whether to apply clipping to improve stability. Recommended to keep as True, especially for training from scratch.

  • clip_value_fn – A function that takes a step index and returns a clipping value. Default is \(x^{0.25}\).

Returns:

The corresponding optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.contrib.adopt(learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01
Objective function: 1.40E+01

References

Taniguchi et al, ADOPT: Modified Adam Can Converge with Any beta2 with the Optimal Rate, NeurIPS 2024

optax.contrib.scale_by_adopt(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.9999, eps: jax.typing.ArrayLike = 1e-06, mu_dtype: str | type[~typing.Any] | ~numpy.dtype | ~jax._src.typing.SupportsDType | None = None, *, nesterov: bool = False, use_clipping: bool = True, clip_value_fn: ~typing.Callable[[~jax.jaxlib._jax.Array], ~jax.jaxlib._jax.Array] = <function <lambda>>) optax.GradientTransformation[source]#

Rescale updates according to the ADOPT algorithm.

Parameters:
  • b1 – Decay rate for the exponentially weighted average of grads.

  • b2 – Decay rate for the exponentially weighted average of squared grads.

  • eps – Term added to the denominator to improve numerical stability.

  • mu_dtype – Optional dtype to be used for the first order accumulator; if None then the dtype is inferred from params and updates.

  • nesterov – Whether to use Nesterov momentum.

  • use_clipping – Whether to use gradient clipping to improve stability. When enabled, the clipping value is determined by the clip_value_fn.

  • clip_value_fn – A function that takes a step index and returns a clipping value. Default is \(x^{0.25}\).

Returns:

An optax.GradientTransformation object.

Asynchronous-centering-Prop#

optax.contrib.acprop(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-16, eps_root: jax.typing.ArrayLike = 1e-16, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) base.GradientTransformation[source]#

The ACProp optimizer.

Follows the implementation from the original repo in PyTorch: juntang-zhuang/ACProp-Optimizer.

ACProp is an adaptive optimizer that combines centering of second momentum and asynchronous update. For the update at step t, the denominator uses information up to step t-1, while the numerator uses the gradient at step t.

Let \(\alpha_t\) represent the learning rate and \(\beta_1, \beta_2\), \(\varepsilon\), \(\bar{\varepsilon}\) represent the arguments b1, b2, eps and eps_root respectively. The learning rate is indexed by \(t\) since the learning rate may also be provided by a schedule function.

The init function of this optimizer initializes an internal state \(S_0 := (m_0, s_0) = (0, 0)\), representing initial estimates for the first and second moments. In practice these values are stored as pytrees containing all zeros, with the same shape as the model updates. At step \(t\), the update function of this optimizer takes as arguments the incoming gradients \(g_t\) and optimizer state \(S_t\) and computes updates \(u_t\) and new state \(S_{t+1}\). Thus, for \(t > 0\), we have,

\[\begin{align*} m_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\ s_t &\leftarrow \beta_2 \cdot s_{t-1} + (1-\beta_2) \cdot (g_t - m_t)^2 + \bar{\varepsilon} \\ \hat{s}_t &\leftarrow s_t / {(1-\beta_2^t)} \\ u_t &\leftarrow -\alpha_t \cdot g_t / \left(\sqrt{\hat{s}_{t-1}} + \varepsilon \right) \\ S_t &\leftarrow (m_t, s_t). \end{align*}\]
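The asynchronous structure (numerator from step t, denominator from step t-1) can be sketched on a scalar gradient stream in plain Python. This is an illustration, not the optax implementation: the name `acprop_updates` is hypothetical, weight decay is omitted, and because no previous estimate exists on the very first step, the sketch assumes a unit denominator there.

```python
import math


def acprop_updates(grads, lr=1e-3, b1=0.9, b2=0.999, eps=1e-16,
                   eps_root=1e-16):
    """Scalar ACProp-style updates for a list of gradients (sketch only)."""
    m, s = 0.0, 0.0
    s_hat_prev = None                    # bias-corrected s from the previous step
    updates = []
    for t, g in enumerate(grads, start=1):
        # Denominator uses information up to step t-1 only.
        # Assumption: a unit denominator on the first step.
        denom = 1.0 if s_hat_prev is None else math.sqrt(s_hat_prev) + eps
        updates.append(-lr * g / denom)
        m = b1 * m + (1.0 - b1) * g                        # first moment
        s = b2 * s + (1.0 - b2) * (g - m) ** 2 + eps_root  # centered second moment
        s_hat_prev = s / (1.0 - b2 ** t)                   # stored for step t+1
    return updates


updates = acprop_updates([1.0, 1.0])
print(updates)
```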
Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 – Exponential decay rate to track the first moment of past gradients.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – Term added to the denominator to improve numerical stability.

  • eps_root – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.

  • weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the Adam gradient transformations are applied to all parameters.

Returns:

The corresponding GradientTransformation.

References

Zhuang et al, Momentum Centering and Asynchronous Update for Adaptive Gradient Methods, 2021

optax.contrib.scale_by_acprop(b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-16, eps_root: jax.typing.ArrayLike = 1e-16) optax.GradientTransformation[source]#

Rescale updates according to ACProp (asynchronous version of AdaBelief).

See optax.contrib.acprop() for more details.

Parameters:
  • b1 – Decay rate for the exponentially weighted average of grads.

  • b2 – Decay rate for the exponentially weighted average of variance of grads.

  • eps – Term added to the denominator to improve numerical stability.

  • eps_root – Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero.

Returns:

A GradientTransformation object.

Complex-valued Optimization#

optax.contrib.split_real_and_imaginary(inner: optax.GradientTransformation) optax.GradientTransformation[source]#

Splits the real and imaginary components of complex updates into two.

The inner transformation processes real parameters and updates, and the pairs of transformed real updates are merged into complex updates.

Parameters and updates that are real before splitting are passed through unmodified.
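The idea can be sketched for a single scalar update in plain Python. The wrapper below is a conceptual illustration only: `split_real_and_imaginary_sketch` and `clip_to_unit` are hypothetical names, and the real transformation operates on pytrees of arrays and carries the inner state.

```python
def split_real_and_imaginary_sketch(inner):
    """Wrap a real-valued update rule so it also accepts complex updates."""
    def wrapped(update):
        if isinstance(update, complex):
            # Process the two real components independently, then merge them.
            return complex(inner(update.real), inner(update.imag))
        return inner(update)             # real inputs reach the inner rule unchanged
    return wrapped


def clip_to_unit(x):
    """Hypothetical inner rule that is only defined for real inputs."""
    return max(-1.0, min(1.0, x))


clip = split_real_and_imaginary_sketch(clip_to_unit)
print(clip(3.0 - 0.5j))
print(clip(-2.0))
```

An elementwise rule such as clipping has no natural definition on complex numbers, which is the kind of inner transformation this wrapper is intended for.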

Parameters:

inner – The inner transformation.

Returns:

An optax.GradientTransformation.

class optax.contrib.SplitRealAndImaginaryState(inner_state: optax.OptState)[source]#

Maintains the inner transformation state for split_real_and_imaginary.

Continuous coin betting#

optax.contrib.cocob(learning_rate: base.ScalarOrSchedule = 1.0, alpha: jax.typing.ArrayLike = 100.0, eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None) base.GradientTransformation[source]#

Rescale updates according to the COntinuous COin Betting algorithm.

Algorithm for stochastic subgradient descent. Uses a gambling algorithm to find the minimizer of a non-smooth objective function by accessing its subgradients. All we need is a good gambling strategy. See Algorithm 2 of the paper listed under References.

Parameters:
  • learning_rate – optional learning rate to e.g. inject some scheduler

  • alpha – fraction to bet parameter of the COCOB optimizer

  • eps – jitter term to avoid dividing by 0

  • weight_decay – L2 penalty

  • mask – mask for weight decay

Returns:

A GradientTransformation object.

References

Orabona et al, Training Deep Networks without Learning Rates Through Coin Betting, 2017

class optax.contrib.COCOBState(init_particles: optax.Updates, cumulative_gradients: optax.Updates, scale: optax.Updates, subgradients: optax.Updates, reward: optax.Updates)[source]#

State for COntinuous COin Betting.

D-adaptation#

optax.contrib.dadapt_adamw(learning_rate: base.ScalarOrSchedule = 1.0, betas: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike] = (0.9, 0.999), eps: jax.typing.ArrayLike = 1e-08, estim_lr0: jax.typing.ArrayLike = 1e-06, weight_decay: jax.typing.ArrayLike = 0.0) base.GradientTransformationExtraArgs[source]#

Learning rate free AdamW by D-Adaptation.

Adapts the baseline learning rate of AdamW automatically by estimating the initial distance to solution in the infinity norm. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.
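A schedule that treats 1.0 as the base value can be sketched as a plain Python function of the step count. The name `warmup_linear_decay` and its arguments are illustrative assumptions. Real code would presumably build the equivalent from optax's schedule utilities so that it also works on traced step counts.

```python
def warmup_linear_decay(step, total_steps, warmup_fraction=0.1):
    """Multiplier in [0, 1]: linear warmup to 1.0, then linear decay to 0."""
    warmup_steps = int(total_steps * warmup_fraction)
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps       # ramp up from 0 to the base value 1.0
    remaining = max(1, total_steps - warmup_steps)
    return max(0.0, (total_steps - step) / remaining)


for step in (0, 50, 100, 550, 1000):
    print(step, warmup_linear_decay(step, total_steps=1000))
```

The multiplier peaks at 1.0 rather than at a tuned learning rate, since the step size itself is estimated by D-Adaptation.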

Parameters:
  • learning_rate – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.

  • betas – Betas for the underlying AdamW Optimizer.

  • eps – eps for the underlying AdamW Optimizer.

  • estim_lr0 – Initial (under-)estimate of the learning rate.

  • weight_decay – AdamW style weight-decay. To use Regular Adam decay, chain with add_decayed_weights.

Returns:

The corresponding optax.GradientTransformation.

References

Defazio et al, Learning-Rate-Free Learning by D-Adaptation, 2023

class optax.contrib.DAdaptAdamWState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, grad_sum: optax.Updates, estim_lr: jax.typing.ArrayLike, numerator_weighted: jax.typing.ArrayLike, count: jax.typing.ArrayLike)[source]#

State of the GradientTransformation returned by dadapt_adamw.

Differentially Private Aggregate#

optax.contrib.differentially_private_aggregate(l2_norm_clip: jax.typing.ArrayLike, noise_multiplier: jax.typing.ArrayLike, key: Array | int | None = None, *, seed: int | None = None) optax.GradientTransformation[source]#

Aggregates gradients based on the DPSGD algorithm.

Parameters:
  • l2_norm_clip – maximum L2 norm of the per-example gradients.

  • noise_multiplier – ratio of standard deviation to the clipping norm.

  • key – random generator key for noise generation.

  • seed – deprecated, use key instead.

Returns:

An optax.GradientTransformation.

References

Abadi et al, Deep Learning with Differential Privacy, 2016

Warning

Unlike other transforms, differentially_private_aggregate expects the input updates to have a batch dimension in the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap). It can still be composed with other transformations as long as it is the first in the chain.

Warning

Generic gradient aggregation tools like optax.MultiSteps or optax.apply_every() won’t work correctly with this transformation since the whole point of this transformation is to aggregate gradients in a specific way.
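The aggregation can be sketched in plain Python on lists of per-example gradients. This is a conceptual illustration, not the optax code: the name `dp_aggregate` is hypothetical, the noise standard deviation is taken as `l2_norm_clip * noise_multiplier` as the parameter description suggests, and averaging over the batch after adding noise is an assumption based on the DPSGD algorithm of Abadi et al.

```python
import math
import random


def dp_aggregate(per_example_grads, l2_norm_clip, noise_multiplier, rng):
    """Clip each example's gradient, sum, add Gaussian noise, then average."""
    batch_size = len(per_example_grads)
    total = [0.0] * len(per_example_grads[0])
    for grad in per_example_grads:
        norm = math.sqrt(sum(x * x for x in grad))
        # Scale down only gradients whose L2 norm exceeds the clip value.
        scale = min(1.0, l2_norm_clip / norm) if norm > 0 else 1.0
        for i, x in enumerate(grad):
            total[i] += scale * x
    noise_std = l2_norm_clip * noise_multiplier
    return [(s + rng.gauss(0.0, noise_std)) / batch_size for s in total]


grads = [[3.0, 4.0], [0.3, 0.4]]         # L2 norms 5.0 and 0.5
no_noise = dp_aggregate(grads, 1.0, 0.0, random.Random(0))
noisy = dp_aggregate(grads, 1.0, 1.1, random.Random(0))
print(no_noise, noisy)
```

The clipping is applied per example before any summation, which is why the input needs a leading batch dimension and why pre-averaged gradients cannot be used.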

class optax.contrib.DifferentiallyPrivateAggregateState(rng_key: Array)[source]#

State containing PRNGKey for differentially_private_aggregate.

optax.contrib.dpsgd(learning_rate: base.ScalarOrSchedule, l2_norm_clip: jax.typing.ArrayLike, noise_multiplier: jax.typing.ArrayLike, seed: int, momentum: jax.typing.ArrayLike | None = None, nesterov: bool = False) base.GradientTransformation[source]#

The DPSGD optimizer.

Differential privacy is a standard for privacy guarantees of algorithms learning from aggregate databases including potentially sensitive information. DPSGD offers protection against a strong adversary with full knowledge of the training mechanism and access to the model’s parameters.

Parameters:
  • learning_rate – A fixed global scaling factor.

  • l2_norm_clip – Maximum L2 norm of the per-example gradients.

  • noise_multiplier – Ratio of standard deviation to the clipping norm.

  • seed – Initial seed used for the jax.random.PRNGKey

  • momentum – Decay rate used by the momentum term. When set to None, momentum is not used at all.

  • nesterov – Whether Nesterov momentum is used.

Returns:

An optax.GradientTransformation.

References

Abadi et al, Deep Learning with Differential Privacy, 2016

Warning

This optax.GradientTransformation expects input updates to have a batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap).

Warning

Generic gradient aggregation tools like optax.MultiSteps or optax.apply_every() won’t work correctly with this transformation since the whole point of this transformation is to aggregate gradients in a specific way.

Distance over Gradients#

optax.contrib.dog(learning_rate: base.ScalarOrSchedule = 1.0, init_step: tuple[Literal['distance', 'learning_rate', 'heuristic'], jax.typing.ArrayLike] = ('heuristic', 1e-06), eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike | None = None, mask: Any | Callable[[base.Params], Any] | None = None)[source]#

Distance over Gradients (DoG) optimizer.

DoG updates parameters \(x_t\) with stochastic gradients \(g_t\) according to the update rule:

\[\begin{align*} r_t &= \| x_t - x_0 \| \\ \bar{r}_t &= \max_{k \leq t} r_k \\ G_t &= \sum_{k \leq t} \|g_k\|^2 \\ \eta_t &= \frac{\bar{r}_t}{\sqrt{G_t + \epsilon}} \\ x_{t+1} & = x_{t} - \eta_t\, g_t \end{align*}\]
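The update rule can be sketched in plain Python on a list of parameters. This is an illustration rather than the optax implementation: the name `dog_run` is hypothetical, the learning rate multiplier and weight decay are omitted, and the running maximum distance is seeded with the heuristic initial value \(\alpha (1 + \|x_0\|)\) so that the first step is non-zero.

```python
import math


def l2_norm(v):
    """Euclidean norm of a list of floats."""
    return math.sqrt(sum(c * c for c in v))


def dog_run(x0, grad_fn, steps, alpha=1e-6, eps=1e-8):
    """Run `steps` DoG iterations from `x0` (sketch only)."""
    x = list(x0)
    r_bar = alpha * (1.0 + l2_norm(x0))  # heuristic initial distance
    g_sum = 0.0                          # running sum of squared gradient norms
    for _ in range(steps):
        g = grad_fn(x)
        r_bar = max(r_bar, l2_norm([a - b for a, b in zip(x, x0)]))
        g_sum += l2_norm(g) ** 2
        eta = r_bar / math.sqrt(g_sum + eps)   # distance over gradients
        x = [a - eta * c for a, c in zip(x, g)]
    return x


def f(x):
    return sum(c * c for c in x)         # simple quadratic function


x0 = [1.0, 2.0, 3.0]
x = dog_run(x0, lambda p: [2.0 * c for c in p], steps=5)
print(f(x0), f(x))
```

With the suggested initial value of 1e-6 the first steps are deliberately tiny, so the objective barely moves over a handful of iterations; the step size grows as the iterates move away from the starting point.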
Parameters:
  • learning_rate – optional learning rate (potentially varying according to some predetermined scheduler).

  • init_step – Initial step specification. Consists of a pair (tag, value), where value is a float and tag is a string, which must be one of distance, learning_rate, or heuristic. distance sets the initial distance \(r_0\) (\(r_\epsilon\) in the paper) to the given value. learning_rate sets the initial learning rate \(\eta_0\) to the given value. heuristic sets \(r_0 = \alpha (1 + \|x_0\|)\), where \(\alpha\) is the given value. The suggested value of \(\alpha\) is 1e-6, unless the model uses batch normalization, in which case the suggested value is 1e-4. As discussed in the paper, the value should be small enough to ensure that the initial update step will be small enough to not cause the model to diverge.

  • eps – epsilon used for numerical stability - added to the sum of squared norm of gradients.

  • weight_decay – Strength of the weight decay regularization.

  • mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the gradient transformation is applied to all parameters.

Returns:

The corresponding optax.GradientTransformation.
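
As a reading aid, the update rule and the heuristic initial step described above can be transcribed into plain NumPy. This is a sketch under stated assumptions, not the optax implementation: dog_sketch and its argument names are hypothetical, weight decay and the learning_rate multiplier are omitted, and the initial distance uses the heuristic \(r_0 = \alpha (1 + \|x_0\|)\).

```python
import numpy as np


def dog_sketch(grad_fn, x0, num_steps, alpha=1e-6, eps=1e-8):
  """Plain NumPy transcription of the DoG update rule (illustrative only)."""
  x = x0.copy()
  # 'heuristic' initial distance: r_0 = alpha * (1 + ||x_0||).
  max_dist = alpha * (1.0 + np.linalg.norm(x0))
  sum_sq_norm_grads = 0.0
  for _ in range(num_steps):
    g = grad_fn(x)
    max_dist = max(max_dist, np.linalg.norm(x - x0))   # running max of r_t
    sum_sq_norm_grads += float(np.dot(g, g))           # G_t
    eta = max_dist / np.sqrt(sum_sq_norm_grads + eps)  # eta_t
    x = x - eta * g
  return x


x0 = np.array([1.0, 2.0, 3.0])
x = dog_sketch(lambda v: 2.0 * v, x0, num_steps=5)  # gradient of sum(v ** 2)
print(np.sum(x0 ** 2), np.sum(x ** 2))
```

Because the initial distance is tiny, the first few steps barely move the parameters, so the objective decreases only slightly over the first iterations.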

Examples

>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.dog()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  updates, opt_state = solver.update(
...    grad, opt_state, params, value=value)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: ', f(params))
Objective function:  13.99...
Objective function:  13.99...
Objective function:  13.99...
Objective function:  13.99...
Objective function:  13.99...

References

Ivgi et al., DoG is SGD’s Best Friend: A Parameter-Free Dynamic Step Size Schedule, 2023.

Added in version 0.2.3.

Warning

The authors recommend using model averaging with this optimizer.

This optimizer’s init function should receive the actual parameters (not just dummy parameters) when the heuristic initial step is used.

class optax.contrib.DoGState(is_init_step: Array, init_params: optax.ArrayTree, max_dist: Array, sum_sq_norm_grads: Array)[source]#

State for DoG optimizer.

optax.contrib.dowg(learning_rate: base.ScalarOrSchedule = 1.0, init_estim_sq_dist: jax.typing.ArrayLike | None = None, eps: jax.typing.ArrayLike = 0.0001, weight_decay: jax.typing.ArrayLike | None = None, mask: Any | Callable[[base.Params], Any] | None = None)[source]#

Distance over weighted Gradients optimizer.

Examples

>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.dowg()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  updates, opt_state = solver.update(
...    grad, opt_state, params, value=value)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: ', f(params))
Objective function:  13.925367
Objective function:  13.872763
Objective function:  13.775433
Objective function:  13.596172
Objective function:  13.268837

References

Khaled et al., DoWG Unleashed: An Efficient Universal Parameter-Free Gradient Descent Method, 2023.

Parameters:
  • learning_rate – optional learning rate (potentially varying according to some predetermined scheduler).

  • init_estim_sq_dist – initial guess of the squared distance to solution.

  • eps – small value to prevent division by zero in the denominator defining the learning rate; also used as the initial guess for the distance to the solution if init_estim_sq_dist is None.

  • weight_decay – Strength of the weight decay regularization.

  • mask – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the gradient transformation is applied to all parameters.

Returns:

The corresponding optax.GradientTransformation.

Added in version 0.2.3.

class optax.contrib.DoWGState(init_params: optax.ArrayTree, weighted_sq_norm_grads: Array, estim_sq_dist: Array)[source]#

State for DoWG optimizer.

Galore#

optax.contrib.galore(learning_rate: base.ScalarOrSchedule, rank: int = 128, update_proj_gap: int = 200, scale: float = 1.0, base_optimizer: base.GradientTransformation | None = None, weight_decay: jax.typing.ArrayLike = 0.0, mask: Any | Callable[[base.Params], Any] | None = None, weight_dimension_numbers: GaLoreDimNumsOrFn | None = None) base.GradientTransformation[source]#

GaLore: Memory-efficient training via gradient lowrank projection.

GaLore (Gradient Low-Rank Projection) is a memory-efficient training strategy that enables full-parameter learning while reducing optimizer state memory by projecting gradients into a low-rank subspace.

The key insight is that gradients of weight matrices in neural networks often exhibit low-rank structure. GaLore exploits this by:

  1. Computing a low-rank projection matrix P using SVD of the gradient

  2. Projecting gradients to a low-rank subspace: R = P^T @ G (or G @ P)

  3. Maintaining optimizer states in the reduced subspace

  4. Projecting updates back to full space: update = P @ normalized_R
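
The four steps above can be sketched in NumPy for a single matrix. This is only an illustration under simplifying assumptions: the gradient is built to be exactly rank r so that the projection is lossless, the base optimizer in step 3 is replaced by the identity, and only the left projection (R = P^T @ G) is shown.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, r = 8, 6, 2
# A gradient that is exactly rank r, so projecting to rank r loses nothing.
G = rng.normal(size=(m, r)) @ rng.normal(size=(r, n))

# Step 1: projection matrix from the top-r left singular vectors.
U, _, _ = np.linalg.svd(G, full_matrices=False)
P = U[:, :r]              # shape (m, r)

# Step 2: project the gradient into the low-rank subspace.
R = P.T @ G               # shape (r, n)

# Step 3: a real optimizer would keep its state for R here (identity in this sketch).
normalized_R = R

# Step 4: project the update back to the full space.
update = P @ normalized_R  # shape (m, n)

print(R.shape, update.shape, np.allclose(update, G))  # (2, 6) (8, 6) True
```

With a real gradient the reconstruction is only approximate, and the optimizer state lives on the (r, n) array rather than the (m, n) one.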

For a weight matrix of shape (m, n) with rank r projection:

  • Standard Adam stores first- and second-moment states: 2 * m * n parameters

  • GaLore stores: 2 * min(r*n, m*r) + projection matrix

This can achieve up to 65% memory reduction for large linear layers.
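
To make the counts above concrete, here is a small worked example for one square layer. The sizes are hypothetical, and the projection matrix is assumed to have shape (m, r), which is unambiguous only because m equals n here.

```python
m, n, r = 4096, 4096, 128  # hypothetical square layer and rank

adam_states = 2 * m * n                   # first and second moments at full size
galore_moments = 2 * min(r * n, m * r)    # moments kept in the low-rank subspace
projection = m * r                        # assumed (m, r) projection matrix
galore_states = galore_moments + projection

saving = 1.0 - galore_states / adam_states
print(adam_states, galore_states, f"{saving:.1%}")  # 33554432 1572864 95.3%
```

This counts optimizer state for a single layer only. End-to-end savings, such as the figure quoted above, also depend on the chosen rank and on the parameters that are not projected.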

Note

GaLore only projects 2D weight matrices by default. Use weight_dimension_numbers to project higher-dimensional tensors (like attention projections stored as 3D arrays).

Warning

The base_optimizer must be a gradient scaling transformation that does NOT require parameter values. See scale_by_galore for details on compatible vs incompatible optimizers.

Do NOT use: adamw, lamb, lars as base_optimizer.

Use instead: scale_by_adam, scale_by_lion, etc., and configure weight decay via the weight_decay parameter of this function.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(jnp.square(x['w']))
>>> solver = optax.contrib.galore(learning_rate=0.01, rank=16)
>>> params = {'w': jnp.ones((100, 100)), 'b': jnp.ones((100,))}
>>> print('Objective function: ', f(params))
Objective function:  10000.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: {:.2E}'.format(f(params)))
Objective function: 9.98E+03
Objective function: 9.96E+03
Objective function: 9.94E+03
Objective function: 9.92E+03
Objective function: 9.90E+03
Using weight decay (equivalent to AdamW behavior):
>>> solver = optax.contrib.galore(
...     learning_rate=0.01,
...     rank=16,
...     weight_decay=0.01,  # Use this, NOT adamw as base_optimizer
... )
Using a custom base optimizer:
>>> solver = optax.contrib.galore(
...     learning_rate=0.01,
...     rank=16,
...     base_optimizer=optax.scale_by_adam(b1=0.9, b2=0.99),
... )
Projecting 3D attention weights as 2D matrices:
>>> from optax.contrib import GaLoreDimensionNumbers
>>> # For attention weights shaped (embed_dim, num_heads, head_dim)
>>> dim_nums = {'attn': GaLoreDimensionNumbers(
...     reduction_axis=0,      # embed_dim
...     output_axis=(1, 2),    # heads*head_dim
... )}
>>> solver = optax.contrib.galore(
...     learning_rate=0.01, rank=16, weight_dimension_numbers=dim_nums
... )
Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler.

  • rank – Target rank for the low-rank projection. Lower values save more memory but may slow convergence. Default 128 is a good starting point.

  • update_proj_gap – Number of steps between projection matrix updates. The projectors are recomputed from the gradient SVD every this many steps to adapt to the changing gradient landscape.

  • scale – Additional scaling factor for updates.

  • base_optimizer – The base gradient transformation to apply in the low-rank subspace. Must be a gradient-only transformation like scale_by_adam, NOT an optimizer requiring params like adamw. If None, defaults to optax.scale_by_adam(). If the base optimizer includes a learning rate, set learning_rate=1.0 here to avoid double-scaling.

  • weight_decay – Strength of decoupled weight decay regularization (as in AdamW). This is applied correctly in full parameter space, unlike weight decay in the base optimizer which would fail.

  • mask – A tree with same structure as params PyTree, or a Callable that returns such a pytree. Leaves should be booleans indicating whether to apply weight decay to each parameter.

  • weight_dimension_numbers – Specifies how to treat non-2D tensors as 2D matrices for projection. See scale_by_galore for details.

Returns:

A GradientTransformation implementing the GaLore optimizer.

References

Zhao et al., GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection, 2024

Mechanize#

optax.contrib.mechanize(base_optimizer: base.GradientTransformation | base.GradientTransformationExtraArgs, weight_decay: jax.typing.ArrayLike = 0.01, eps: jax.typing.ArrayLike = 1e-08, s_init: jax.typing.ArrayLike = 1e-06, num_betas: int = 6) base.GradientTransformationExtraArgs[source]#

Mechanic - a black box learning rate tuner/optimizer.

Accumulates updates returned by the base_optimizer and learns the scale of the updates (also known as learning rate or step size) to apply on a per iteration basis.

Note that Mechanic does NOT eschew the need for a learning rate schedule. You are free to apply a learning rate schedule with the base learning rate set to 1.0 (or any other constant), and Mechanic will learn the right scale factor automatically.

For example, change this:

learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
optimizer = optax.adam(learning_rate_fn)

To:

learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=1.0)
optimizer = optax.adam(learning_rate_fn)
optimizer = optax.contrib.mechanize(optimizer)

As of June 2023, Mechanic is tested with SGD, Momentum, Adam and Lion as inner optimizers, but we expect it to work with almost any first-order optimizer (except for normalized gradient optimizers like LARS or LAMB).

Parameters:
  • base_optimizer – Base optimizer to compute updates from.

  • weight_decay – A scalar weight decay rate. Note that this weight decay is not the same as the weight decay one would use for the base_optimizer. In addition to sometimes helping converge faster, this helps Mechanic reduce the variance between training runs using different seeds. You likely would not need to tune this, the default should work in most cases.

  • eps – epsilon for mechanic.

  • s_init – initial scale factor. Default should work almost all the time.

  • num_betas – unlike traditional exp accumulators (like 1st or 2nd moment of adam), where one has to choose an explicit beta, mechanic has a clever way to automatically learn the right beta for all accumulators. We only provide the range of possible betas, and not the tuned value. For instance, if you set num_betas to 3, it will use betas = [0.9, 0.99, 0.999].
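
The num_betas example above suggests a simple pattern for the range of betas. The sketch below assumes the k-th beta is 1 - 10^(-k); the text only confirms this for num_betas = 3, and beta_range is a hypothetical helper, not part of optax.

```python
def beta_range(num_betas):
  # Assumed pattern: the k-th beta is 1 - 10**(-k), for k = 1..num_betas.
  return [1.0 - 0.1 ** k for k in range(1, num_betas + 1)]


print([round(b, 6) for b in beta_range(3)])  # [0.9, 0.99, 0.999]
```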

Returns:

An optax.GradientTransformation.

References

Cutkosky et al., Mechanic: A Learning Rate Tuner, 2023

class optax.contrib.MechanicState(base_optimizer_state: optax.OptState, count: jax.typing.ArrayLike, r: jax.typing.ArrayLike, m: jax.typing.ArrayLike, v: jax.typing.ArrayLike, s: jax.typing.ArrayLike, x0: optax.Updates)[source]#

State of the GradientTransformation returned by mechanize.

Madgrad#

optax.contrib.madgrad(learning_rate: base.ScalarOrSchedule, momentum: float = 0.9, weight_decay: float = 0.0, eps: float = 1e-06) base.GradientTransformation[source]#

The MADGRAD optimizer.

MADGRAD is a general purpose optimizer that matches the performance of SGD+Momentum on vision tasks and Adam on NLP tasks.

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler.

  • momentum – Momentum parameter (default: 0.9).

  • weight_decay – Strength of the weight decay regularization (L2).

  • eps – Term added to the denominator to improve numerical stability.

Returns:

The corresponding optax.GradientTransformation.

References

Defazio et al, Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization, 2021.

optax.contrib.scale_by_madgrad(learning_rate: base.ScalarOrSchedule, momentum: float = 0.9, eps: float = 1e-06) base.GradientTransformation[source]#

Rescale updates according to the MADGRAD algorithm.

MADGRAD is a Dual Averaging method that maintains a weighted sum of gradients and squared gradients to compute adaptive updates. It effectively bridges the gap between the generalization performance of SGD and the convergence speed of adaptive methods like Adam.

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler.

  • momentum – Momentum parameter (default: 0.9).

  • eps – Term added to the denominator to improve numerical stability.

Returns:

An optax.GradientTransformation object.

References

Defazio et al, Adaptivity without Compromise: A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic Optimization, 2021.

class optax.contrib.MadgradState(count: Array, grad_sum_sq: optax.Updates, s: optax.Updates, x0: optax.Params)[source]#

State for the MADGRAD optimizer.

Momo#

optax.contrib.momo(learning_rate: base.ScalarOrSchedule = 1.0, beta: jax.typing.ArrayLike = 0.9, lower_bound: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, adapt_lower_bound: bool = False) base.GradientTransformationExtraArgs[source]#

Adaptive Learning Rates for SGD with momentum.

MoMo typically needs less tuning of learning_rate, because it exploits the fact that a lower bound of the loss (or the optimal value) is known. For most tasks, zero is a lower bound and an accurate estimate of the final loss.

MoMo performs SGD with momentum with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.

Note that one needs to pass the latest (batch) loss value to the update function using the keyword argument value.
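
To illustrate why the loss value is needed, the sketch below uses the classical Polyak step size as a stand-in for the adaptive term. This is a simplification: it ignores the momentum averages that MoMo maintains, and polyak_capped_step is a hypothetical helper, not an optax function.

```python
import numpy as np


def polyak_capped_step(value, grad, learning_rate=1.0, lower_bound=0.0):
  # Classical Polyak step size (loss gap over squared gradient norm),
  # capped from above by the user-specified learning rate.
  adaptive = max(value - lower_bound, 0.0) / float(np.dot(grad, grad))
  return min(learning_rate, adaptive)


x = np.array([1.0, 2.0, 3.0])
value, grad = np.sum(x ** 2), 2.0 * x   # loss and gradient of sum(x ** 2)
step = polyak_capped_step(value, grad)
x_new = x - step * grad
print(step, np.sum(x_new ** 2))  # 0.25 3.5
```

Here the adaptive term (14 / 56 = 0.25) is smaller than learning_rate, so it determines the step. A large learning_rate acts only as a cap.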

Parameters:
  • learning_rate – User-specified learning rate. Recommended to be chosen rather large, by default 1.0.

  • beta – Momentum coefficient (for EMA).

  • lower_bound – Lower bound of the loss. Zero should be a good choice for many tasks.

  • weight_decay – Weight-decay parameter.

  • adapt_lower_bound – If no good guess for the lower bound is available, set this to true, in order to estimate the lower bound on the fly (see the paper for details).

Returns:

An optax.GradientTransformation object.

Examples

>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  # update returns the updates, which are then applied to the params
...  updates, opt_state = solver.update(
...    grad, opt_state, params, value=value)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: ', f(params))
Objective function:  3.5
Objective function:  ...
Objective function:  ...
Objective function:  ...
Objective function:  ...

References

Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023

Added in version 0.2.3.

class optax.contrib.MomoState(exp_avg: optax.Updates, barf: jax.typing.ArrayLike, gamma: jax.typing.ArrayLike, lb: jax.typing.ArrayLike, count: jax.typing.ArrayLike)[source]#

State of the GradientTransformation returned by momo.

optax.contrib.momo_adam(learning_rate: base.ScalarOrSchedule = 0.01, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, lower_bound: jax.typing.ArrayLike = 0.0, weight_decay: jax.typing.ArrayLike = 0.0, adapt_lower_bound: bool = False) base.GradientTransformationExtraArgs[source]#

Adaptive Learning Rates for Adam(W).

MoMo-Adam typically needs less tuning of learning_rate, because it exploits the fact that a lower bound of the loss (or the optimal value) is known. For most tasks, zero is a lower bound and an accurate estimate of the final loss.

MoMo performs Adam(W) with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.

Note that one needs to pass the latest (batch) loss value to the update function using the keyword argument value.

Parameters:
  • learning_rate – User-specified learning rate. Recommended to be chosen rather large; the default is 0.01.

  • b1 – Exponential decay rate to track the first moment of past gradients.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – eps for the underlying Adam Optimizer.

  • lower_bound – Lower bound of the loss. Zero should be a good choice for many tasks.

  • weight_decay – Weight-decay parameter. Momo-Adam performs weight decay in similar fashion to AdamW.

  • adapt_lower_bound – If no good guess for the lower bound is available, set this to true, in order to estimate the lower bound on the fly (see the paper for details).

Returns:

A GradientTransformation object.

Examples

>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo_adam()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  # update returns the updates, which are then applied to the params
...  updates, opt_state = solver.update(
...    grad, opt_state, params, value=value)
...  params = optax.apply_updates(params, updates)
...  print('Objective function: ', f(params))
Objective function:  13.88...
Objective function:  ...
Objective function:  ...
Objective function:  ...
Objective function:  ...

References

Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023

Added in version 0.2.3.

class optax.contrib.MomoAdamState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, barf: jax.typing.ArrayLike, gamma: jax.typing.ArrayLike, lb: jax.typing.ArrayLike, count: jax.typing.ArrayLike)[source]#

State of the GradientTransformation returned by momo_adam.

Muon#

optax.contrib.muon(learning_rate: base.ScalarOrSchedule, ns_coeffs: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike] | tuple[tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike], ...] | str = (3.4445, -4.775, 2.0315), ns_steps: jax.typing.ArrayLike = 5, beta: jax.typing.ArrayLike = 0.95, eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike = 0.0, weight_decay_mask: Any | Callable[[base.Params], Any] | None = None, mu_dtype: jax.typing.DTypeLike | None = None, *, nesterov: bool = True, adaptive: bool = False, preconditioning: Literal['frobenius', 'spectral', 'aol', 'schatten'] = 'frobenius', adam_b1: jax.typing.ArrayLike = 0.9, adam_b2: jax.typing.ArrayLike = 0.999, adam_eps_root: jax.typing.ArrayLike = 0.0, adam_weight_decay: jax.typing.ArrayLike = 0.0, adam_learning_rate: base.ScalarOrSchedule | None = None, muon_weight_dimension_numbers: WeightDimNumOrFn | None = None, consistent_rms: jax.typing.ArrayLike | None = None) base.GradientTransformation[source]#

Muon: Momentum Orthogonalized by Newton-Schulz.

Muon is a variant of Shampoo that uses the Newton-Schulz method to orthogonalize the momentum accumulated by the optimizer. Mathematically, it does steepest descent under the Schatten-p norm, for some large p. With \(p = \infty\), it is equivalent to Shampoo without accumulation, or steepest descent under the spectral norm.

Note that Muon is currently only defined for 2D parameters, i.e. matrices. This is because the Newton-Schulz iterator expects a matrix as input. The non-2D parameters are instead passed through an AdamW optimizer (using a weight decay of 0 as default).
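
For intuition, a Newton-Schulz iteration of the kind described above can be sketched in NumPy. This is an illustrative quintic iteration using the default ns_coeffs and Frobenius rescaling on a wide matrix; it is not the optax implementation, and the test matrix with singular values 3, 2, 1 and 0.5 is a made-up example.

```python
import numpy as np


def newton_schulz(g, coeffs=(3.4445, -4.775, 2.0315), steps=5, eps=1e-8):
  """Quintic Newton-Schulz iteration on a wide matrix (illustrative sketch)."""
  a, b, c = coeffs
  x = g / (np.linalg.norm(g) + eps)  # Frobenius rescaling before the iterations
  for _ in range(steps):
    s = x @ x.T
    x = a * x + (b * s + c * (s @ s)) @ x
  return x


rng = np.random.default_rng(0)
# Build a 4 x 6 matrix with known singular values 3, 2, 1 and 0.5.
u, _ = np.linalg.qr(rng.normal(size=(4, 4)))
v, _ = np.linalg.qr(rng.normal(size=(6, 4)))
g = u @ np.diag([3.0, 2.0, 1.0, 0.5]) @ v.T

x = newton_schulz(g)
print(np.linalg.svd(x, compute_uv=False))  # singular values pushed toward 1
```

With these coefficients the singular values end up near 1 rather than exactly at 1, so the result is only approximately orthogonal.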

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • ns_coeffs – Coefficients for the Newton-Schulz method (can be a string indicator for a preset). Existing presets: muon, dion.

  • ns_steps – Number of Newton-Schulz iterations. Ignored if ns_coeffs is a tuple of tuples.

  • beta – Decay rate for the exponentially weighted average of grads.

  • eps – Term added to the denominator to improve numerical stability.

  • weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • weight_decay_mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip.

  • mu_dtype – Data type of the momentum accumulator.

  • nesterov – Whether to use Nesterov momentum.

  • adaptive – Whether to scale the updates by the dual norm of the original updates. See <https://arxiv.org/abs/2409.20325>

  • preconditioning – What type of preconditioning to use before NS iterations. Available options are:

    • ‘frobenius’ (default): Use Frobenius rescaling before NS: safe, standard, but degrades orthogonalization quality when using less than 5 NS steps.

    • ‘spectral’: Use Spectral norm rescaling before NS: much more computationally intensive, but better orthogonalization quality.

    • ‘aol’: Use AOL rescalings to improve orthogonality with little to no overhead, usually allows the user to remove one iterative NS step. See <https://arxiv.org/abs/2512.04632>.

    • ‘schatten’: Use the Schatten-4 norm for rescaling, allows for better performance with little to no extra cost. See <https://arxiv.org/abs/2506.10935>.

  • adam_b1 – Exponential decay rate for Adam’s first moment estimates.

  • adam_b2 – Exponential decay rate for Adam’s second moment estimates.

  • adam_eps_root – Epsilon to stabilize division in Adam, square root version.

  • adam_weight_decay – Weight decay factor for Adam.

  • adam_learning_rate – Auxiliary learning rate for the Adam optimizer. If None, the learning rate for Adam defaults to the same as Muon.

  • muon_weight_dimension_numbers – An optional tree of MuonDimensionNumbers, specifying how to reshape the parameters for orthogonalization; otherwise muon parameters are assumed to be 2D matrices. A None value indicates that the parameter is not a muon parameter and will be optimized with Adam. A callable takes as input the params and returns a possibly masked pytree of specs, similar to weight_decay_mask. If not provided, muon is applied to all 2D parameters.

  • consistent_rms – An optional float to activate consistent RMS scaling. Scales updates by sqrt(max(fan_in, fan_out)) * consistent_rms to make root mean square (RMS) shape-independent, like AdamW. 0.2 is recommended to match AdamW’s empirical RMS. See <https://arxiv.org/abs/2502.16982>. If None, uses width scaling sqrt(max(1, fan_out / fan_in)).
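
As a worked example of the two scaling rules in the consistent_rms description, the sketch below evaluates both formulas for a hypothetical layer with fan_in = 1024 and fan_out = 4096. The helper muon_update_scale is an illustration, not an optax function.

```python
import math


def muon_update_scale(fan_in, fan_out, consistent_rms=None):
  if consistent_rms is not None:
    # Consistent RMS scaling: sqrt(max(fan_in, fan_out)) * consistent_rms.
    return math.sqrt(max(fan_in, fan_out)) * consistent_rms
  # Width scaling used when consistent_rms is None.
  return math.sqrt(max(1.0, fan_out / fan_in))


print(muon_update_scale(1024, 4096, consistent_rms=0.2))  # 12.8
print(muon_update_scale(1024, 4096))                      # 2.0
```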

Returns:

The corresponding GradientTransformation.

References

Jordan, modded-nanogpt: Speedrunning the NanoGPT baseline, 2024

Bernstein et al., Old Optimizer, New Norm: An Anthology, 2024

Liu et al., Muon is Scalable for LLM Training, <https://arxiv.org/abs/2502.16982>, 2025

Boissin et al., Turbo-Muon: Accelerating Orthogonality-Based Optimization with Pre-Conditioning, <https://arxiv.org/abs/2512.04632>, 2025

Ahn et al., Dion: Distributed Orthonormalized Updates, <https://arxiv.org/abs/2504.05295>, 2025

Grishina et al., Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials, <https://arxiv.org/abs/2506.10935>, 2025

Amsel et al., The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm, <https://arxiv.org/pdf/2505.16932>, 2025

optax.contrib.scale_by_muon(ns_coeffs: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike] | tuple[tuple[jax.typing.ArrayLike, jax.typing.ArrayLike, jax.typing.ArrayLike], ...] = (3.4445, -4.775, 2.0315), ns_steps: jax.typing.ArrayLike = 5, beta: jax.typing.ArrayLike = 0.95, eps: jax.typing.ArrayLike = 1e-08, mu_dtype: jax.typing.DTypeLike | None = None, *, nesterov: bool = True, adaptive: bool = False, preconditioning: Literal['frobenius', 'spectral', 'aol', 'schatten'] = 'frobenius', weight_dimension_numbers: WeightDimNumOrFn | None = None) base.GradientTransformation[source]#

Rescale updates according to the Muon algorithm.

Muon is a variant of Shampoo that uses the Newton-Schulz method to orthogonalize the momentum accumulated by the optimizer. Mathematically, it does steepest descent under the Schatten-p norm, for some large p. With \(p = \infty\), it is equivalent to Shampoo without accumulation, or steepest descent under the spectral norm.

Parameters:
  • ns_coeffs – Coefficients for the Newton-Schulz method.

  • ns_steps – Number of Newton-Schulz iterations. Ignored if ns_coeffs is a tuple of tuples.

  • beta – Decay rate for the exponentially weighted average of grads.

  • eps – Term added to denominators to improve numerical stability.

  • mu_dtype – Data type of the momentum accumulator.

  • nesterov – Whether to use Nesterov momentum.

  • adaptive – Whether to scale the updates by the dual norm of the original updates. See <https://arxiv.org/abs/2409.20325>

  • preconditioning – What type of preconditioning to use before NS iterations. Available options are:

    • ‘frobenius’ (default): Use Frobenius rescaling before NS.

    • ‘spectral’: Use Spectral norm rescaling before NS.

    • ‘aol’: Use AOL rescaling to improve orthogonality.

    • ‘schatten’: Use the Schatten-4 norm for rescaling.

  • weight_dimension_numbers – An optional tree, with the same structure as the params, of MuonDimensionNumbers specifying how to reshape the parameters before and after the orthogonalization, OR a callable returning such a tree. None implies that all parameters are 2D matrices.

Returns:

A GradientTransformation object.

References

Jordan, modded-nanogpt: Speedrunning the NanoGPT baseline, 2024

Bernstein et al., Old Optimizer, New Norm: An Anthology, 2024

Liu et al., Muon is Scalable for LLM Training, <https://arxiv.org/abs/2502.16982>, 2025

Boissin et al., Turbo-Muon: Accelerating Orthogonality-Based Optimization with Pre-Conditioning, <https://arxiv.org/abs/2512.04632>, 2025

Ahn et al., Dion: Distributed Orthonormalized Updates, <https://arxiv.org/abs/2504.05295>, 2025

Grishina et al., Accelerating Newton-Schulz Iteration for Orthogonalization via Chebyshev-type Polynomials, <https://arxiv.org/abs/2506.10935>, 2025

Amsel et al., The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm, <https://arxiv.org/pdf/2505.16932>, 2025

class optax.contrib.MuonState(count: jax.typing.ArrayLike, mu: optax.Updates, ns_coeffs: jax.typing.ArrayLike)[source]#

State for the Muon algorithm.

Prodigy#

optax.contrib.prodigy(learning_rate: base.ScalarOrSchedule = 1.0, betas: tuple[jax.typing.ArrayLike, jax.typing.ArrayLike] = (0.9, 0.999), beta3: jax.typing.ArrayLike | None = None, eps: jax.typing.ArrayLike = 1e-08, estim_lr0: jax.typing.ArrayLike = 1e-06, estim_lr_coef: jax.typing.ArrayLike = 1.0, weight_decay: jax.typing.ArrayLike = 0.0, safeguard_warmup: bool = False) base.GradientTransformationExtraArgs[source]#

Learning rate free AdamW with Prodigy.

Implementation of the Prodigy method from “Prodigy: An Expeditiously Adaptive Parameter-Free Learner”, a version of D-Adapt AdamW that adapts the baseline learning rate faster by using a weighting of the gradients that places higher weights on more recent gradients. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.

Parameters:
  • learning_rate – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.

  • betas – Betas for the underlying AdamW Optimizer.

  • beta3 – Optional momentum parameter for estimation of D.

  • eps – eps for the underlying AdamW Optimizer.

  • estim_lr0 – Initial (under-)estimate of the learning rate.

  • estim_lr_coef – LR estimates are multiplied by this parameter.

  • weight_decay – AdamW style weight-decay. To use Regular Adam decay, chain with add_decayed_weights.

  • safeguard_warmup – Remove lr from the denominator of D estimate to avoid issues during warm-up stage. Off by default.

Returns:

An optax.GradientTransformation object.

References

Mishchenko et al, Prodigy: An Expeditiously Adaptive Parameter-Free Learner, 2023

class optax.contrib.ProdigyState(exp_avg: optax.Updates, exp_avg_sq: optax.Updates, grad_sum: optax.Updates, params0: optax.Updates, estim_lr: jax.typing.ArrayLike, numerator_weighted: jax.typing.ArrayLike, count: jax.typing.ArrayLike)[source]#

State of the GradientTransformation returned by prodigy.

Schedule-Free#

optax.contrib.schedule_free(base_optimizer: base.GradientTransformation, learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.9, weight_lr_power: jax.typing.ArrayLike = 2.0, state_dtype: jax.typing.DTypeLike | None = None) base.GradientTransformationExtraArgs[source]#

Turn base_optimizer schedule_free.

Accumulates updates returned by the base_optimizer without momentum and replaces the momentum of an underlying optimizer with a combination of interpolation and averaging. In the case of gradient descent the update is

\[\begin{align*} y_{t} & = (1-\beta_1)z_{t} + \beta_1 x_{t},\\ z_{t+1} & =z_{t}-\gamma\nabla f(y_{t}),\\ x_{t+1} & =\left(1-\frac{1}{t}\right)x_{t}+\frac{1}{t}z_{t+1}, \end{align*}\]

Here \(x\) is the sequence that evaluations of test/val loss should occur at, which differs from the primary iterates \(z\) and the gradient evaluation locations \(y\). The updates to \(z\) correspond to the underlying optimizer, in this case a simple gradient step. Note that, \(\beta_1\) corresponds to b1 in the code.
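
The recursion above can be transcribed directly into NumPy to see how the three sequences interact. This is a sketch of the displayed equations only, with a constant step size \(\gamma\) and a toy quadratic; it does not reproduce the weighting used by the optax implementation (see weight_lr_power).

```python
import numpy as np


def schedule_free_gd(grad_fn, x0, num_steps, gamma=0.1, beta1=0.9):
  """Direct transcription of the three-sequence recursion (illustrative)."""
  x = x0.copy()  # averaged iterate, used for evaluation
  z = x0.copy()  # primary iterate, updated by the gradient step
  for t in range(1, num_steps + 1):
    y = (1.0 - beta1) * z + beta1 * x         # gradient evaluation point
    z = z - gamma * grad_fn(y)                # base optimizer step on z
    x = (1.0 - 1.0 / t) * x + (1.0 / t) * z   # running average of z
  return x, z


x0 = np.array([1.0, 2.0, 3.0])
x, z = schedule_free_gd(lambda v: v, x0, num_steps=200)  # gradient of 0.5 * ||v||^2
print(0.5 * np.sum(x0 ** 2), 0.5 * np.sum(x ** 2))
```

Only z and x are carried between iterations; y is recomputed from them at each step, which matches the remark that two stored sequences suffice.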

As the name suggests, Schedule-Free learning does not require a decreasing learning rate schedule, yet typically out-performs, or at worst matches, SOTA schedules such as cosine-decay and linear decay. Only two sequences need to be stored at a time (the third can be computed from the other two on the fly) so this method has the same memory requirements as the base optimizer (parameter buffer + momentum).

In practice, authors recommend tuning \(\beta_1\), warmup_steps and peak_lr for each problem separately. Default for \(\beta_1\) is 0.9 but 0.95 and 0.98 may also work well. Schedule-Free can be wrapped on top of any optax optimizer. At test time, the parameters should be evaluated using optax.contrib.schedule_free_eval_params() as presented below.

For example, change this:

learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=b1)

To:

learning_rate_fn = optax.warmup_constant_schedule(peak_value=retuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=0.)
optimizer = optax.contrib.schedule_free(optimizer, learning_rate_fn, b1=b1)
...
params_for_eval = optax.contrib.schedule_free_eval_params(state, params)

Note in particular that it is important to switch off the momentum of the base optimizer. As of April 2024, schedule_free is tested with SGD and Adam.

Parameters:
  • base_optimizer – Base optimizer to compute updates from.

  • learning_rate – Learning rate schedule without decay but with warmup.

  • b1 – beta_1 parameter in the y update.

  • weight_lr_power – Power used to down-weight the averaging weights. This is especially helpful in early iterations during warmup.

  • state_dtype – dtype for z sequence in the schedule free method.
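As a rough illustration of what weight_lr_power may do, the sketch below assumes that the averaging weight of step t is the largest learning rate seen so far raised to weight_lr_power, and that the averaging coefficient is that weight divided by the running sum of weights. This weighting rule and the function name averaging_coefficients are assumptions based on the learning-rate-weighted averaging described in the referenced paper; they are not confirmed by this page.

```python
def averaging_coefficients(learning_rates, weight_lr_power=2.0):
    """Hypothetical sketch: c_t = w_t / sum(w_1..w_t), w_t = max_lr_t ** power."""
    coefficients = []
    max_lr = 0.0
    weight_sum = 0.0
    for lr in learning_rates:
        max_lr = max(max_lr, lr)             # largest learning rate so far
        weight = max_lr ** weight_lr_power   # assumed averaging weight
        weight_sum += weight
        coefficients.append(weight / weight_sum if weight_sum > 0 else 0.0)
    return coefficients


warmup = [0.25, 0.5, 1.0, 1.0]  # toy linear warmup followed by a constant rate
uniform = averaging_coefficients(warmup, weight_lr_power=0.0)
weighted = averaging_coefficients(warmup, weight_lr_power=2.0)
print(uniform)
print(weighted)
```

With a power of 0 the coefficients reduce to the uniform 1/t average. With a positive power, the later coefficients stay larger, so iterates taken during warmup contribute less to the final average.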

Returns:

An optax.GradientTransformationExtraArgs.

References

Defazio et al., The Road Less Scheduled, 2024

Defazio et al., Schedule-Free Learning - A New Way to Train, 2024

Warning

The current implementation requires the parameter b1 to be strictly positive.

optax.contrib.schedule_free_adamw(learning_rate: jax.typing.ArrayLike = 0.0025, warmup_steps: int | None = None, b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike = 0.0, weight_lr_power: jax.typing.ArrayLike = 2.0, state_dtype: jax.typing.DTypeLike | None = None) base.GradientTransformationExtraArgs[source]#

Schedule-Free wrapper for AdamW.

Shortcut example for using schedule_free with AdamW, which is a common use case. Note that this is just an example, and other use cases are possible, e.g. using a weight decay mask, Nesterov momentum, etc. Note also that the EMA parameter of the schedule-free method (b1) must be strictly positive.

Parameters:
  • learning_rate – AdamW learning rate.

  • warmup_steps – positive integer, the length of the linear warmup.

  • b1 – beta_1 parameter in the y update.

  • b2 – Exponential decay rate to track the second moment of past gradients.

  • eps – A small constant applied to denominator outside of the square root (as in the Adam paper) to avoid dividing by zero when rescaling.

  • weight_decay – Strength of the weight decay regularization.

  • weight_lr_power – Power used to down-weight the averaging weights. This is especially helpful in early iterations during warmup.

  • state_dtype – dtype for z sequence in the schedule free method.

Returns:

An optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.contrib.schedule_free_adamw(1.0)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  eval_params = optax.contrib.schedule_free_eval_params(
...      opt_state, params)
...  print('Objective function: {:.2E}'.format(f(eval_params)))
Objective function: 5.00E+00
Objective function: 3.05E+00
Objective function: 1.73E+00
Objective function: 8.94E-01
Objective function: 4.13E-01

Note

Note that optax.scale_by_adam() with b1=0 stores in its state an unused first moment always equal to zero. To avoid this waste of memory, we replace optax.scale_by_adam() with b1=0 by the equivalent optax.scale_by_rms() with eps_in_sqrt=False, bias_correction=True.
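The equivalence mentioned in this note can be checked by hand. The sketch below uses textbook scalar versions of Adam scaling and of bias-corrected RMS scaling with eps outside the square root; the helper names adam_direction and rms_direction are illustrative and are not optax functions.

```python
import math


def adam_direction(grads, b1, b2=0.999, eps=1e-8):
    """Textbook Adam scaling (no learning rate) for a scalar gradient stream."""
    m = v = 0.0
    out = []
    for t, g in enumerate(grads, start=1):
        m = b1 * m + (1 - b1) * g        # first moment
        v = b2 * v + (1 - b2) * g * g    # second moment
        m_hat = m / (1 - b1 ** t)        # bias-corrected first moment
        v_hat = v / (1 - b2 ** t)        # bias-corrected second moment
        out.append(m_hat / (math.sqrt(v_hat) + eps))
    return out


def rms_direction(grads, b2=0.999, eps=1e-8):
    """RMS scaling with bias correction and eps outside the square root."""
    v = 0.0
    out = []
    for t, g in enumerate(grads, start=1):
        v = b2 * v + (1 - b2) * g * g
        v_hat = v / (1 - b2 ** t)
        out.append(g / (math.sqrt(v_hat) + eps))
    return out


grads = [0.5, -1.0, 2.0]
adam_out = adam_direction(grads, b1=0.0)
rms_out = rms_direction(grads)
print(adam_out)
print(rms_out)
```

With b1 = 0 the first moment equals the current gradient and its bias correction is a division by one, so only the second-moment buffer influences the result.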

optax.contrib.schedule_free_eval_params(state: optax.OptState, params: optax.Params)[source]#

Params for evaluation of optax.contrib.schedule_free().

optax.contrib.schedule_free_sgd(learning_rate: jax.typing.ArrayLike = 1.0, warmup_steps: int | None = None, b1: jax.typing.ArrayLike = 0.9, weight_decay: jax.typing.ArrayLike | None = None, weight_lr_power: jax.typing.ArrayLike = 2.0, state_dtype: jax.typing.DTypeLike | None = None) base.GradientTransformationExtraArgs[source]#

Schedule-Free wrapper for SGD.

Shortcut example for using schedule_free with SGD, which is a common use case. Note that this is just an example, and other use cases are possible, e.g. using a weight decay mask. Note also that the EMA parameter of the schedule free method (b1) must be strictly positive.

Parameters:
  • learning_rate – SGD learning rate.

  • warmup_steps – positive integer, the length of the linear warmup.

  • b1 – beta_1 parameter in the y update.

  • weight_decay – Strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, 2019) where the weight decay is only multiplied with the “schedule multiplier”, but not the base learning rate.

  • weight_lr_power – Power used to down-weight the averaging weights. This is especially helpful in early iterations during warmup.

  • state_dtype – dtype for z sequence in the schedule free method.

Returns:

An optax.GradientTransformationExtraArgs.

Examples

>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = optax.contrib.schedule_free_sgd()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  grad = jax.grad(f)(params)
...  updates, opt_state = solver.update(grad, opt_state, params)
...  params = optax.apply_updates(params, updates)
...  eval_params = optax.contrib.schedule_free_eval_params(
...      opt_state, params)
...  print('Objective function: {:.2E}'.format(f(eval_params)))
Objective function: 1.40E+01
Objective function: 1.75E-14
Objective function: 9.96E-01
Objective function: 8.06E-01
Objective function: 2.41E-01

class optax.contrib.ScheduleFreeState(b1: jax.typing.ArrayLike, weight_sum: jax.typing.ArrayLike, step_count: jax.typing.ArrayLike, max_lr: jax.typing.ArrayLike, base_optimizer_state: optax.OptState, z: optax.Params)[source]#

State for schedule_free.

Sharpness aware minimization#

optax.contrib.sam(optimizer: base.GradientTransformation, adv_optimizer: base.GradientTransformation, sync_period: int = 2, reset_state: bool = True, opaque_mode: bool = False, batch_axis_name: str | None = None) base.GradientTransformationExtraArgs[source]#

Implementation of SAM (Sharpness Aware Minimization).

Performs steps with the inner adversarial optimizer and periodically updates an outer set of true parameters. By default, resets the state of the adversarial optimizer after synchronization. For example:

>>> import optax
>>> rho = 0.1
>>> opt = optax.sgd(learning_rate=0.01)
>>> adv_opt = optax.chain(optax.contrib.normalize(), optax.sgd(rho))
>>> sam_opt = optax.contrib.sam(opt, adv_opt, sync_period=2)

This would implement the simple drop-in version of SAM from the paper, which uses normalized SGD as the inner adversarial optimizer for a single step.
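For intuition, one step of this drop-in variant can be written out by hand. The following is a minimal pure-Python sketch following the description in Foret et al., not the optax implementation; the function sam_step, the toy quadratic objective, and the list-based parameters are illustrative assumptions.

```python
import math


def sam_step(params, grad_fn, learning_rate=0.01, rho=0.1):
    """One SAM step: perturb along the normalized gradient, then descend."""
    grads = grad_fn(params)
    norm = math.sqrt(sum(g * g for g in grads))
    scale = rho / norm if norm > 0 else 0.0
    # Adversarial (ascent) point at distance rho from the current parameters.
    adv_params = [p + scale * g for p, g in zip(params, grads)]
    # The gradient at the perturbed point updates the true parameters.
    adv_grads = grad_fn(adv_params)
    return [p - learning_rate * g for p, g in zip(params, adv_grads)]


def grad_quadratic(params):
    """Gradient of the toy objective f(w) = sum(w_i ** 2)."""
    return [2.0 * p for p in params]


new_params = sam_step([3.0, 4.0], grad_quadratic)
print(new_params)
```

Each outer update costs two gradient evaluations here, which matches the default sync_period of 2: one adversarial step followed by one update of the true parameters.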

Parameters:
  • optimizer – the outer optimizer.

  • adv_optimizer – the inner adversarial optimizer.

  • sync_period – how often to run the outer optimizer, defaults to 2, or every other step.

  • reset_state – whether to reset the state of the inner optimizer after every sync period, defaults to True.

  • opaque_mode – bool. If True, the outer optimizer and the adversarial optimizer are run in an internal loop at each call to update, so that adversarial updates are opaque to the rest of the system. If False, one optimizer is (effectively) evaluated per call to update, meaning that adversarial updates are visible to the rest of the system. Setting opaque_mode to True is necessary if the training system using SAM has side effects from each call to update besides the changes to the model’s parameters. The most common example would be if the model uses BatchNorm statistics – those statistics would be updated on both adversarial and non-adversarial update steps, causing them to get out of sync with the model’s parameters (which are effectively only updated on non-adversarial steps). See the NOTE section for more details on opaque_mode=True.

  • batch_axis_name – str or None. Only used if opaque_mode=True. When running in a pmapped setting, it is necessary to take a jax.lax.pmean of the adversarial updates internally before passing them to the outer optimizer. You only need to specify this if you have to use jax.lax.pmean in your training loop.

Returns:

The corresponding optax.GradientTransformationExtraArgs implementation of SAM.

References

Foret et al., Sharpness-Aware Minimization for Efficiently Improving Generalization, 2021

Note

When opaque_mode=True, the update function must be called with a gradient function that takes two arguments (the params and the current adversarial step) and returns the gradients of the loss. This looks like the following:

opt = sam(outer_opt, adv_opt, opaque_mode=True)
...
# In the training loop:
grad_fn = jax.grad(
  lambda params, _: loss(params, batch, and_other_args))
updates, state = opt.update(updates, state, params, grad_fn=grad_fn)
params = optax.apply_updates(params, updates)

On every call to opt.update, grad_fn will be called sync_period - 1 times, once for each adversarial update. It is usually ok to use the same minibatch in each of those updates, as in the example above, but you can use the second argument to select different batches at each adversarial step:

grad_fn = jax.grad(lambda params, i: loss(params, batches[i]))

class optax.contrib.SAMState(steps_since_sync: jax.Array, opt_state: base.OptState, adv_state: base.OptState, cache: base.Params | None)[source]#

State of GradientTransformation returned by sam.

steps_since_sync#

Number of adversarial steps taken since the last outer update.

Type:

jax.Array

opt_state#

State of the outer optimizer.

Type:

base.OptState

adv_state#

State of the inner adversarial optimizer.

Type:

base.OptState

cache#

A place to store the last outer updates.

Type:

Optional[base.Params]

Sophia#

optax.contrib.hutchinson_estimator_diag_hessian(random_seed: Array | None = None)[source]#

Returns a GradientTransformationExtraArgs computing the Hessian diagonal.

The Hessian diagonal is estimated using Hutchinson’s estimator, which is unbiased but has high variance.

Parameters:

random_seed – key used to generate random vectors.

Returns:

GradientTransformationExtraArgs
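To make the estimator concrete: for random vectors v with independent ±1 entries, the expectation of v ⊙ (Hv) equals the diagonal of H, because the off-diagonal contributions average out. The sketch below is a pure-Python illustration with an explicit matrix; the names hutchinson_diag and hvp and the matrix H are made up for the example, and in practice the Hessian-vector product would come from automatic differentiation of the objective instead.

```python
import random


def hutchinson_diag(hvp_fn, dim, num_samples, seed=0):
    """Estimate diag(H) as the mean of v * (H v) over random sign vectors v."""
    rng = random.Random(seed)
    estimate = [0.0] * dim
    for _ in range(num_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        hv = hvp_fn(v)
        for i in range(dim):
            estimate[i] += v[i] * hv[i] / num_samples
    return estimate


# Illustrative symmetric "Hessian" with diagonal (2, 3, 4).
H = [[2.0, 1.0, 0.0],
     [1.0, 3.0, -1.0],
     [0.0, -1.0, 4.0]]


def hvp(v):
    """Hessian-vector product for the fixed matrix H."""
    return [sum(h_ij * v_j for h_ij, v_j in zip(row, v)) for row in H]


diag_estimate = hutchinson_diag(hvp, dim=3, num_samples=4000)
print(diag_estimate)
```

A single sample is exact only when H is diagonal; otherwise the off-diagonal entries add noise that shrinks as more samples are averaged, which is the high variance mentioned above.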

class optax.contrib.HutchinsonState(key)[source]#

optax.contrib.sophia(learning_rate: base.ScalarOrSchedule, b1: jax.typing.ArrayLike = 0.965, b2: jax.typing.ArrayLike = 0.99, eps: jax.typing.ArrayLike = 1e-08, weight_decay: jax.typing.ArrayLike = 0.0001, weight_decay_mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, gamma: jax.typing.ArrayLike = 0.01, clip_threshold: Optional[jax.typing.ArrayLike] = 1.0, update_interval: jax.typing.ArrayLike = 10, hessian_diagonal_fn: Union[base.GradientTransformation, base.GradientTransformationExtraArgs] = hutchinson_estimator_diag_hessian(), mu_dtype: Optional[Any] = None, verbose: bool = False, print_win_rate_every_n_steps: jax.typing.ArrayLike = 0) base.GradientTransformationExtraArgs[source]#

Sophia optimizer.

A separate GradientTransformation is required through the argument hessian_diagonal_fn to compute the diagonal of the Hessian. Any extra arguments required by the hessian_diagonal_fn’s update function can be passed through sophia’s update function as trailing keyword arguments (**kwargs). The default hessian_diagonal_fn is Hutchinson’s estimator and needs the objective function as an extra argument, obj_fn. obj_fn must accept params as its only argument and return only a scalar (the loss).

For example, if your experiment’s loss function has the signature loss_fn(params, batch) -> (loss, aux), i.e. it takes multiple arguments and returns multiple outputs, it must be wrapped into an objective of the form obj_fn(params) -> loss:

obj_fn = lambda params: loss_fn(params, batch)[0]

where batch is the current step’s batch.

Then it can be passed to sophia’s update function (which will pass it to the hessian_diagonal_fn’s update function):

updates, state = sophia.update(updates, state, params, obj_fn=obj_fn)

Optionally, you can write your own GradientTransformation to compute the Hessian diagonal. Use the hutchinson_estimator_diag_hessian function as an example. If you are using more than one device, be sure the Hessian diagonal function properly averages the Hessian diagonal across devices. The default hessian_diagonal_fn does not do this, and would cause params to diverge from each other across devices if using pmap, for example.

Parameters:
  • learning_rate – A global scaling factor, either fixed or evolving along iterations with a scheduler, see optax.scale_by_learning_rate().

  • b1 – Exponential decay rate for the first moment estimates.

  • b2 – Exponential decay rate for the Hessian diagonal estimates. Keep in mind that the effective b2 is 1 - (1 - b2) / update_interval; e.g. the default b2 of 0.99 is effectively 0.999 because the default update_interval is 10.

  • eps – Small constant to avoid division by zero.

  • weight_decay – Rate at which to decay weights.

  • weight_decay_mask – A tree with same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the transformation to, and False for those you want to skip.

  • gamma – Normalizing constant for the Hessian diagonal.

  • clip_threshold – Threshold for clipping updates.

  • update_interval – Interval for updating the Hessian diagonal.

  • hessian_diagonal_fn – GradientTransformation that computes the diagonal of the Hessian. Default is Hutchinson’s estimator (Sophia-H). If using more than one device, be sure this function properly averages the Hessian diagonal across devices.

  • mu_dtype – dtype of the first moment estimates.

  • verbose – If True, print win rate every n steps.

  • print_win_rate_every_n_steps – Print the Sophia win rate every n steps for diagnostic purposes. The authors state this value should stay between 0.1 and 0.5 during training. If the win rate is too low, try increasing gamma. Set to 0 to turn off.
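The effective b2 quoted in the b2 entry above can be reproduced with a few lines of arithmetic. The helper name effective_b2 is illustrative. The comparison with the per-step rate whose update_interval-th power equals b2 is one possible interpretation of the formula as a first-order approximation, not a statement from the library.

```python
def effective_b2(b2, update_interval):
    """Effective decay rate using the formula 1 - (1 - b2) / update_interval."""
    return 1 - (1 - b2) / update_interval


approx = effective_b2(0.99, 10)  # formula from the parameter description
per_step = 0.99 ** (1 / 10)      # rate that compounds to 0.99 over 10 steps
print(approx, per_step)
```

Both values are close to 0.999, which is the sense in which a b2 of 0.99 applied every 10 steps behaves like 0.999 applied at every step.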

Returns:

optax.GradientTransformationExtraArgs

References

Liu et al., Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training, 2023

Levanter

Note

We use a Rademacher vector to estimate the diagonal of the Hessian, unlike the original implementation, which uses a normal random vector.

class optax.contrib.SophiaState(count: jax.Array, mu: base.Updates, nu: base.Updates, hessian_fn_state: Any)[source]#

State for Sophia Optimizer.