🔧 Contrib#

Experimental features and algorithms that don’t meet the Inclusion Criteria.

cocob([learning_rate, alpha, eps, ...]) – Rescale updates according to the COntinuous COin Betting algorithm.
COCOBState(init_particles, ...) – State for COntinuous COin Betting.
dadapt_adamw([learning_rate, betas, eps, ...]) – Learning rate free AdamW by D-Adaptation.
DAdaptAdamWState(exp_avg, exp_avg_sq, ...) – State of the GradientTransformation returned by dadapt_adamw.
differentially_private_aggregate(...) – Aggregates gradients based on the DPSGD algorithm.
DifferentiallyPrivateAggregateState(rng_key) – State containing PRNGKey for differentially_private_aggregate.
dog([learning_rate, reps_rel, eps, ...]) – Distance over Gradients optimizer.
DoGState(first_step, init_params, ...) – State for DoG optimizer.
dowg([learning_rate, init_estim_sq_dist, ...]) – Distance over weighted Gradients optimizer.
DoWGState(init_params, ...) – State for DoWG optimizer.
dpsgd(learning_rate, l2_norm_clip, ...[, ...]) – The DPSGD optimizer.
mechanize(base_optimizer[, weight_decay, ...]) – Mechanic, a black box learning rate tuner/optimizer.
MechanicState(base_optimizer_state, count, ...) – State of the GradientTransformation returned by mechanize.
momo([learning_rate, beta, lower_bound, ...]) – Adaptive Learning Rates for SGD with momentum.
MomoState(exp_avg, barf, gamma, lb, count) – State of the GradientTransformation returned by momo.
momo_adam([learning_rate, b1, b2, eps, ...]) – Adaptive Learning Rates for Adam(W).
MomoAdamState(exp_avg, exp_avg_sq, barf, ...) – State of the GradientTransformation returned by momo_adam.
prodigy([learning_rate, betas, beta3, eps, ...]) – Learning rate free AdamW with Prodigy.
ProdigyState(exp_avg, exp_avg_sq, grad_sum, ...) – State of the GradientTransformation returned by prodigy.
sam(optimizer, adv_optimizer[, sync_period, ...]) – Implementation of SAM (Sharpness Aware Minimization).
SAMState(steps_since_sync, opt_state, ...) – State of GradientTransformation returned by sam.
schedule_free(base_optimizer, learning_rate) – Turn base_optimizer into a schedule-free optimizer.
schedule_free_eval_params(state, params) – Params for evaluation of optax.contrib.schedule_free().
ScheduleFreeState(b1, weight_sum, ...) – State for schedule_free.
split_real_and_imaginary(inner) – Splits the real and imaginary components of complex updates into two.
SplitRealAndImaginaryState(inner_state) – Maintains the inner transformation state for split_real_and_imaginary.

Complex-valued Optimization#

optax.contrib.split_real_and_imaginary(inner)[source]#

Splits the real and imaginary components of complex updates into two.

The inner transformation processes real parameters and updates, and the pairs of transformed real updates are merged into complex updates.

Parameters and updates that are real before splitting are passed through unmodified.

Parameters:

inner (GradientTransformation) – The inner transformation.

Return type:

GradientTransformation

Returns:

An optax.GradientTransformation.
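
For illustration, a minimal sketch of wrapping Adam so it operates on the real and imaginary parts of complex parameters (the parameter values and the choice of Adam are assumptions made for this example, not part of the API):

import jax.numpy as jnp
import optax

# Complex parameters; the inner Adam only ever sees their real and imaginary parts.
params = {'w': jnp.array([1.0 + 2.0j, 3.0 - 1.0j])}
opt = optax.contrib.split_real_and_imaginary(optax.adam(learning_rate=1e-2))
state = opt.init(params)
grads = {'w': jnp.array([0.1 + 0.1j, -0.2 + 0.3j])}
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)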

class optax.contrib.SplitRealAndImaginaryState(inner_state: base.OptState)[source]#

Maintains the inner transformation state for split_real_and_imaginary.

inner_state: base.OptState#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Continuous coin betting#

optax.contrib.cocob(learning_rate=1.0, alpha=100, eps=1e-08, weight_decay=0, mask=None)[source]#

Rescale updates according to the COntinuous COin Betting algorithm.

Algorithm for stochastic subgradient descent. Uses a gambling algorithm to find the minimizer of a non-smooth objective function by accessing its subgradients. All we need is a good gambling strategy. See Algorithm 2 of:

References

[Orabona & Tommasi, 2017](https://arxiv.org/pdf/1705.07795.pdf)

Parameters:
  • learning_rate (base.ScalarOrSchedule) – optional learning rate, e.g. to inject a learning rate schedule

  • alpha (float) – fraction to bet parameter of the COCOB optimizer

  • eps (float) – jitter term to avoid dividing by 0

  • weight_decay (float) – L2 penalty

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – mask for weight decay

Return type:

optax.GradientTransformation

Returns:

A GradientTransformation object.
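
As an illustration, a minimal sketch in the style of the other examples on this page (the quadratic objective and constants are illustrative assumptions):

import jax
import jax.numpy as jnp
import optax

def f(x): return jnp.sum(x ** 2)  # illustrative quadratic objective
opt = optax.contrib.cocob()
params = jnp.array([1., 2., 3.])
state = opt.init(params)
for _ in range(5):
  grads = jax.grad(f)(params)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)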

class optax.contrib.COCOBState(init_particles: base.Updates, cumulative_gradients: base.Updates, scale: base.Updates, subgradients: base.Updates, reward: base.Updates)[source]#

State for COntinuous COin Betting.

init_particles: base.Updates#

Alias for field number 0

cumulative_gradients: base.Updates#

Alias for field number 1

scale: base.Updates#

Alias for field number 2

subgradients: base.Updates#

Alias for field number 3

reward: base.Updates#

Alias for field number 4

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

D-adaptation#

optax.contrib.dadapt_adamw(learning_rate=1.0, betas=(0.9, 0.999), eps=1e-08, estim_lr0=1e-06, weight_decay=0.0)[source]#

Learning rate free AdamW by D-Adaptation.

Adapts the baseline learning rate of AdamW automatically by estimating the initial distance to solution in the infinity norm. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.

References

[Defazio & Mishchenko, 2023](https://arxiv.org/abs/2301.07733.pdf)

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.

  • betas (tuple[float, float]) – Betas for the underlying AdamW Optimizer.

  • eps (float) – eps for the underlying AdamW Optimizer.

  • estim_lr0 (float) – Initial (under-)estimate of the learning rate.

  • weight_decay (float) – AdamW style weight-decay. To use Regular Adam decay, chain with add_decayed_weights.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
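
A minimal runnable sketch of the recommended setup, using linear_schedule with init_value=1.0 (the objective, training horizon and parameter values are illustrative assumptions):

import jax
import jax.numpy as jnp
import optax

def f(x): return jnp.sum(x ** 2)  # illustrative quadratic objective
total_steps = 100  # illustrative training horizon
schedule = optax.linear_schedule(init_value=1.0, end_value=0.0,
                                 transition_steps=total_steps)
opt = optax.contrib.dadapt_adamw(learning_rate=schedule)
params = jnp.array([1., 2., 3.])
state = opt.init(params)
for _ in range(total_steps):
  grads = jax.grad(f)(params)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)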

class optax.contrib.DAdaptAdamWState(exp_avg: base.Updates, exp_avg_sq: base.Updates, grad_sum: base.Updates, estim_lr: chex.Array, numerator_weighted: chex.Array, count: chex.Array)[source]#

State of the GradientTransformation returned by dadapt_adamw.

exp_avg: base.Updates#

Alias for field number 0

exp_avg_sq: base.Updates#

Alias for field number 1

grad_sum: base.Updates#

Alias for field number 2

estim_lr: chex.Array#

Alias for field number 3

numerator_weighted: chex.Array#

Alias for field number 4

count: chex.Array#

Alias for field number 5

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Privacy-Sensitive Optax Methods#

DifferentiallyPrivateAggregateState(rng_key) – State containing PRNGKey for differentially_private_aggregate.
differentially_private_aggregate(...) – Aggregates gradients based on the DPSGD algorithm.

Differentially Private Aggregate#

optax.contrib.differentially_private_aggregate(l2_norm_clip, noise_multiplier, seed)[source]#

Aggregates gradients based on the DPSGD algorithm.

WARNING: Unlike other transforms, differentially_private_aggregate expects the input updates to have a batch dimension in the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap). It can still be composed with other transformations as long as it is the first in the chain.

References

[Abadi et al, 2016](https://arxiv.org/abs/1607.00133.pdf)

Parameters:
  • l2_norm_clip (float) – maximum L2 norm of the per-example gradients.

  • noise_multiplier (float) – ratio of standard deviation to the clipping norm.

  • seed (int) – initial seed used for the jax.random.PRNGKey

Return type:

GradientTransformation

Returns:

A GradientTransformation.
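
A minimal sketch of the intended usage (the loss, batch shapes and hyperparameters are illustrative assumptions): per-example gradients are obtained with jax.vmap and the aggregate is placed first in a chain:

import jax
import jax.numpy as jnp
import optax

def loss(params, x, y):  # hypothetical per-example squared error
  return (x @ params - y) ** 2

params = jnp.zeros(3)
x, y = jnp.ones((8, 3)), jnp.ones(8)  # a batch of 8 examples
# Per-example gradients with a leading batch dimension, as required.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(params, x, y)
opt = optax.chain(
    optax.contrib.differentially_private_aggregate(
        l2_norm_clip=1.0, noise_multiplier=1.1, seed=0),
    optax.sgd(learning_rate=0.1),
)
state = opt.init(params)
updates, state = opt.update(per_example_grads, state, params)
params = optax.apply_updates(params, updates)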

class optax.contrib.DifferentiallyPrivateAggregateState(rng_key: Any)[source]#

State containing PRNGKey for differentially_private_aggregate.

rng_key: Any#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

optax.contrib.dpsgd(learning_rate, l2_norm_clip, noise_multiplier, seed, momentum=None, nesterov=False)[source]#

The DPSGD optimizer.

Differential privacy is a standard for privacy guarantees of algorithms learning from aggregate databases including potentially sensitive information. DPSGD offers protection against a strong adversary with full knowledge of the training mechanism and access to the model’s parameters.

WARNING: This GradientTransformation expects input updates to have a batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap).

References

Abadi et al, 2016: https://arxiv.org/abs/1607.00133

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A fixed global scaling factor.

  • l2_norm_clip (float) – Maximum L2 norm of the per-example gradients.

  • noise_multiplier (float) – Ratio of standard deviation to the clipping norm.

  • seed (int) – Initial seed used for the jax.random.PRNGKey

  • momentum (Optional[float]) – Decay rate used by the momentum term. When set to None, momentum is not used at all.

  • nesterov (bool) – Whether Nesterov momentum is used.

Return type:

GradientTransformation

Returns:

A GradientTransformation.
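
A brief sketch of construction (the hyperparameters are illustrative assumptions); as with differentially_private_aggregate above, update expects per-example gradients with a leading batch dimension:

import optax

opt = optax.contrib.dpsgd(learning_rate=0.1, l2_norm_clip=1.0,
                          noise_multiplier=1.1, seed=0, momentum=0.9)
# state = opt.init(params)
# updates, state = opt.update(per_example_grads, state, params)  # batch dim on axis 0
# params = optax.apply_updates(params, updates)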

Distance over Gradients#

optax.contrib.dog(learning_rate=1.0, reps_rel=1e-06, eps=1e-08, init_learning_rate=None, weight_decay=None, mask=None)[source]#

Distance over Gradients optimizer.

DoG updates parameters \(x_t\) with stochastic gradients \(g_t\) according to the update rule:

\[\begin{align*} \eta_t &= \frac{\max_{i\le t}{\|x_i-x_0\|}}{ \sqrt{\sum_{i\le t}{\|g_i\|^2+eps}}}\\ x_{t+1} & = x_{t} - \eta_t\, g_t, \end{align*}\]

Examples

>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.dog()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  params, opt_state = solver.update(grad, opt_state, params, value=value)
...  print('Objective function: ', f(params))
Objective function:  2.2483316e-11
Objective function:  2.2483375e-11
Objective function:  2.2483435e-11
Objective function:  2.2483494e-11
Objective function:  2.2483555e-11

References

Ivgi et al., DoG is SGD’s Best Friend: A Parameter-Free Dynamic Step Size Schedule, 2023.

Parameters:
  • learning_rate (base.ScalarOrSchedule) – optional learning rate (potentially varying according to some predetermined scheduler).

  • reps_rel (float) – value to use to compute the initial distance (r_epsilon in the paper). Namely, the first step size is given by: (reps_rel * (1+|x_0|)) / (|g_0|^2 + eps)^{1/2} where x_0 are the initial weights of the model (or the parameter group), and g_0 is the gradient of the first step. As discussed in the paper, this value should be small enough to ensure that the first update step will be small enough to not cause the model to diverge. Suggested value is 1e-6, unless the model uses batch-normalization, in which case the suggested value is 1e-4.

  • eps (float) – epsilon used for numerical stability - added to the sum of squared norm of gradients.

  • init_learning_rate (Optional[float]) – if specified, this value will be used as the initial learning rate (i.e. first step size) instead of the rule described above with reps_rel.

  • weight_decay (Optional[float]) – Strength of the weight decay regularization.

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the gradient transformation is applied to all parameters.

Returns:

The corresponding optax.GradientTransformation with associated init and update functions.

class optax.contrib.DoGState(first_step: jax.Array, init_params: chex.ArrayTree, estim_dist: jax.Array, sum_sq_norm_grads: jax.Array)[source]#

State for DoG optimizer.

optax.contrib.dowg(learning_rate=1.0, init_estim_sq_dist=None, eps=0.0001, weight_decay=None, mask=None)[source]#

Distance over weighted Gradients optimizer.

Examples

>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.dowg()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  params, opt_state = solver.update(grad, opt_state, params, value=value)
...  print('Objective function: ', f(params))
Objective function:  9.973327e-05
Objective function:  7.0334883
Objective function:  14.074293
Objective function:  49.897446
Objective function:  42.62062

References

Khaled et al., DoWG Unleashed: An Efficient Universal Parameter-Free Gradient Descent Method, 2023.

Parameters:
  • learning_rate (base.ScalarOrSchedule) – optional learning rate (potentially varying according to some predetermined scheduler).

  • init_estim_sq_dist (Optional[float]) – initial guess of the squared distance to solution.

  • eps (float) – small value to prevent division by zero in the denominator defining the learning rate; also used as the initial guess for the distance to solution if init_estim_sq_dist is None.

  • weight_decay (Optional[float]) – Strength of the weight decay regularization.

  • mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans, True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the gradient transformation is applied to all parameters.

Returns:

The corresponding optax.GradientTransformation with associated init and update functions.

class optax.contrib.DoWGState(init_params: chex.ArrayTree, weighted_sq_norm_grads: jax.Array, estim_sq_dist: jax.Array)[source]#

State for DoWG optimizer.

Mechanize#

optax.contrib.mechanize(base_optimizer, weight_decay=0.01, eps=1e-08, s_init=1e-06, num_betas=6)[source]#

Mechanic - a black box learning rate tuner/optimizer.

Accumulates updates returned by the base_optimizer and learns the scale of the updates (also known as learning rate or step size) to apply on a per-iteration basis.

Note that Mechanic does NOT eschew the need for a learning rate schedule; you are free to apply a learning rate schedule with the base learning rate set to 1.0 (or any other constant), and Mechanic will learn the right scale factor automatically.

For example, change this:

learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
optimizer = optax.adam(learning_rate_fn)

To:

learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=1.0)
optimizer = optax.adam(learning_rate_fn)
optimizer = optax.contrib.mechanize(optimizer)

As of June 2023, Mechanic is tested with SGD, Momentum, Adam and Lion as inner optimizers, but we expect it to work with almost any first-order optimizer (except for normalized gradient optimizers such as LARS or LAMB).

References

[Cutkosky et al, 2023](https://arxiv.org/pdf/2306.00144.pdf)

Parameters:
  • base_optimizer (GradientTransformation) – Base optimizer to compute updates from.

  • weight_decay (float) – A scalar weight decay rate. Note that this weight decay is not the same as the weight decay one would use for the base_optimizer. In addition to sometimes helping converge faster, this helps Mechanic reduce the variance between training runs using different seeds. You likely would not need to tune this; the default should work in most cases.

  • eps (float) – epsilon for mechanic.

  • s_init (float) – initial scale factor. Default should work almost all the time.

  • num_betas (int) – unlike traditional exp accumulators (like 1st or 2nd moment of adam), where one has to choose an explicit beta, mechanic has a clever way to automatically learn the right beta for all accumulators. We only provide the range of possible betas, and not the tuned value. For instance, if you set num_betas to 3, it will use betas = [0.9, 0.99, 0.999].

Return type:

GradientTransformation

Returns:

A GradientTransformation with init and update functions.
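
A minimal runnable sketch with Adam as the inner optimizer and a constant base rate of 1.0 (the quadratic objective and step count are illustrative assumptions):

import jax
import jax.numpy as jnp
import optax

def f(x): return jnp.sum(x ** 2)  # illustrative quadratic objective
opt = optax.contrib.mechanize(optax.adam(learning_rate=1.0))
params = jnp.array([1., 2., 3.])
state = opt.init(params)
for _ in range(5):
  grads = jax.grad(f)(params)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)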

class optax.contrib.MechanicState(base_optimizer_state: base.OptState, count: chex.Array, r: chex.Array, m: chex.Array, v: chex.Array, s: chex.Array, x0: base.Updates)[source]#

State of the GradientTransformation returned by mechanize.

base_optimizer_state: base.OptState#

Alias for field number 0

count: chex.Array#

Alias for field number 1

r: chex.Array#

Alias for field number 2

m: chex.Array#

Alias for field number 3

v: chex.Array#

Alias for field number 4

s: chex.Array#

Alias for field number 5

x0: base.Updates#

Alias for field number 6

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Momo#

optax.contrib.momo(learning_rate=1.0, beta=0.9, lower_bound=0.0, weight_decay=0.0, adapt_lower_bound=False)[source]#

Adaptive Learning Rates for SGD with momentum.

MoMo typically needs less tuning of the learning_rate value, by exploiting the fact that a lower bound of the loss (or the optimal value) is known. For most tasks, zero is a lower bound and an accurate estimate of the final loss.

MoMo performs SGD with momentum with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.

Note that one needs to pass the latest (batch) loss value to the update function using the keyword argument value.

Examples

>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  params, opt_state = solver.update(grad, opt_state, params, value=value)
...  print('Objective function: ', f(params))
Objective function:  3.5
Objective function:  0.0
Objective function:  0.0
Objective function:  0.0
Objective function:  0.0

References

Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – User-specified learning rate. Recommended to be chosen rather large, by default 1.0.

  • beta (float) – Momentum coefficient (for EMA).

  • lower_bound (float) – Lower bound of the loss. Zero should be a good choice for many tasks.

  • weight_decay (float) – Weight-decay parameter.

  • adapt_lower_bound (bool) – If no good guess for the lower bound is available, set this to true, in order to estimate the lower bound on the fly (see the paper for details).

Return type:

GradientTransformationExtraArgs

Returns:

A GradientTransformation object.

Added in version 0.2.3.

class optax.contrib.MomoState(exp_avg: base.Updates, barf: chex.Array, gamma: chex.Array, lb: chex.Array, count: chex.Array)[source]#

State of the GradientTransformation returned by momo.

optax.contrib.momo_adam(learning_rate=0.01, b1=0.9, b2=0.999, eps=1e-08, lower_bound=0.0, weight_decay=0.0, adapt_lower_bound=False)[source]#

Adaptive Learning Rates for Adam(W).

MoMo-Adam typically needs less tuning of the learning_rate value, by exploiting the fact that a lower bound of the loss (or the optimal value) is known. For most tasks, zero is a lower bound and an accurate estimate of the final loss.

MoMo performs Adam(W) with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.

Note that one needs to pass the latest (batch) loss value to the update function using the keyword argument value.

Examples

>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo_adam()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...  value, grad = jax.value_and_grad(f)(params)
...  params, opt_state = solver.update(grad, opt_state, params, value=value)
...  print('Objective function: ', f(params))
Objective function:  0.00029999594
Objective function:  0.0
Objective function:  0.0
Objective function:  0.0
Objective function:  0.0

References

Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – User-specified learning rate. Recommended to be chosen rather large; the default here is 0.01.

  • b1 (float) – Exponential decay rate to track the first moment of past gradients.

  • b2 (float) – Exponential decay rate to track the second moment of past gradients.

  • eps (float) – eps for the underlying Adam Optimizer.

  • lower_bound (float) – Lower bound of the loss. Zero should be a good choice for many tasks.

  • weight_decay (float) – Weight-decay parameter. Momo-Adam performs weight decay in similar fashion to AdamW.

  • adapt_lower_bound (bool) – If no good guess for the lower bound is available, set this to true, in order to estimate the lower bound on the fly (see the paper for details).

Return type:

GradientTransformationExtraArgs

Returns:

A GradientTransformation object.

Added in version 0.2.3.

class optax.contrib.MomoAdamState(exp_avg: base.Updates, exp_avg_sq: base.Updates, barf: chex.Array, gamma: chex.Array, lb: chex.Array, count: chex.Array)[source]#

State of the GradientTransformation returned by momo_adam.

Prodigy#

optax.contrib.prodigy(learning_rate=1.0, betas=(0.9, 0.999), beta3=None, eps=1e-08, estim_lr0=1e-06, estim_lr_coef=1.0, weight_decay=0.0)[source]#

Learning rate free AdamW with Prodigy.

Implementation of the Prodigy method from “Prodigy: An Expeditiously Adaptive Parameter-Free Learner”, a version of D-Adapt AdamW that adapts the baseline learning rate faster by using a weighting of the gradients that places higher weights on more recent gradients. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.

References

[Mishchenko & Defazio, 2023](https://arxiv.org/abs/2306.06101.pdf)

Parameters:
  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.

  • betas (tuple[float, float]) – Betas for the underlying AdamW Optimizer.

  • beta3 (Optional[float]) – Optional momentum parameter for estimation of D.

  • eps (float) – eps for the underlying AdamW Optimizer.

  • estim_lr0 (float) – Initial (under-)estimate of the learning rate.

  • estim_lr_coef (float) – LR estimates are multiplied by this parameter.

  • weight_decay (float) – AdamW style weight-decay. To use Regular Adam decay, chain with add_decayed_weights.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
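
A minimal runnable sketch following the recommendation above, combining a short warmup with a linear decay to 0 (the objective, warmup fraction and step counts are illustrative assumptions):

import jax
import jax.numpy as jnp
import optax

def f(x): return jnp.sum(x ** 2)  # illustrative quadratic objective
warmup_steps, decay_steps = 20, 80  # illustrative ~20% warmup
schedule = optax.join_schedules(
    [optax.linear_schedule(0.0, 1.0, warmup_steps),
     optax.linear_schedule(1.0, 0.0, decay_steps)],
    boundaries=[warmup_steps])
opt = optax.contrib.prodigy(learning_rate=schedule)
params = jnp.array([1., 2., 3.])
state = opt.init(params)
for _ in range(warmup_steps + decay_steps):
  grads = jax.grad(f)(params)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)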

class optax.contrib.ProdigyState(exp_avg: base.Updates, exp_avg_sq: base.Updates, grad_sum: base.Updates, params0: base.Updates, estim_lr: chex.Array, numerator_weighted: chex.Array, count: chex.Array)[source]#

State of the GradientTransformation returned by prodigy.

exp_avg: base.Updates#

Alias for field number 0

exp_avg_sq: base.Updates#

Alias for field number 1

grad_sum: base.Updates#

Alias for field number 2

params0: base.Updates#

Alias for field number 3

estim_lr: chex.Array#

Alias for field number 4

numerator_weighted: chex.Array#

Alias for field number 5

count: chex.Array#

Alias for field number 6

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Schedule-Free#

optax.contrib.schedule_free(base_optimizer, learning_rate, b1=0.9, weight_lr_power=2.0, state_dtype=jnp.float32)[source]#

Turn base_optimizer into a schedule-free optimizer.

Accumulates updates returned by the base_optimizer without momentum and replaces the momentum of an underlying optimizer with a combination of interpolation and averaging. In the case of gradient descent the update is

\[\begin{align*} y_{t} & = (1-\beta_1)z_{t} + \beta_1 x_{t},\\ z_{t+1} & =z_{t}-\gamma\nabla f(y_{t}),\\ x_{t+1} & =\left(1-\frac{1}{t}\right)x_{t}+\frac{1}{t}z_{t+1}, \end{align*}\]

Here \(x\) is the sequence that evaluations of test/val loss should occur at, which differs from the primary iterates \(z\) and the gradient evaluation locations \(y\). The updates to \(z\) correspond to the underlying optimizer, in this case a simple gradient step. Note that \(\beta_1\) corresponds to b1 in the code.

As the name suggests, Schedule-Free learning does not require a decreasing learning rate schedule, yet typically outperforms, or at worst matches, SOTA schedules such as cosine-decay and linear decay. Only two sequences need to be stored at a time (the third can be computed from the other two on the fly), so this method has the same memory requirements as the base optimizer (parameter buffer + momentum).

In practice, authors recommend tuning \(\beta_1\), warmup_steps and peak_lr for each problem separately. The default for \(\beta_1\) is 0.9, but 0.95 and 0.98 may also work well. Schedule-Free can be wrapped on top of any optax optimizer. At test time, the parameters should be evaluated using optax.contrib.schedule_free_eval_params() as presented below.

For example, change this:

learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=b1)

To:

learning_rate_fn = optax.warmup_constant_schedule(peak_value=retuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=0.)
optimizer = optax.contrib.schedule_free(optimizer, learning_rate_fn, b1=b1)
...
params_for_eval = optax.contrib.schedule_free_eval_params(state, params)

In particular, note that it is important to switch off the momentum of the base optimizer. As of April 2024, schedule_free is tested with SGD and Adam.

References

Defazio et al, Schedule-Free Learning - A New Way to Train, 2024

Parameters:
  • base_optimizer (GradientTransformation) – Base optimizer to compute updates from.

  • learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – learning rate schedule without decay but with warmup.

  • b1 (float) – beta_1 parameter in the y update.

  • weight_lr_power (float) – we downweight the weight of averaging using this. This is especially helpful in early iterations during warmup.

  • state_dtype – dtype for z sequence.

Return type:

GradientTransformationExtraArgs

Returns:

A GradientTransformationExtraArgs with init and update functions.

optax.contrib.schedule_free_eval_params(state, params)[source]#

Params for evaluation of optax.contrib.schedule_free().
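
A minimal runnable sketch combining schedule_free with a momentum-free SGD base optimizer and evaluating via schedule_free_eval_params (the objective, learning rate and step count are illustrative assumptions):

import jax
import jax.numpy as jnp
import optax

def f(x): return jnp.sum(x ** 2)  # illustrative quadratic objective
learning_rate = 0.1  # illustrative constant schedule (no decay)
base = optax.sgd(learning_rate)  # momentum switched off, as recommended
opt = optax.contrib.schedule_free(base, learning_rate, b1=0.9)
params = jnp.array([1., 2., 3.])
state = opt.init(params)
for _ in range(10):
  grads = jax.grad(f)(params)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)
eval_params = optax.contrib.schedule_free_eval_params(state, params)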

class optax.contrib.ScheduleFreeState(b1: chex.Array, weight_sum: chex.Array, step_count: chex.Array, max_lr: chex.Array, base_optimizer_state: base.OptState, z: base.Params)[source]#

State for schedule_free.

b1: chex.Array#

Alias for field number 0

weight_sum: chex.Array#

Alias for field number 1

step_count: chex.Array#

Alias for field number 2

max_lr: chex.Array#

Alias for field number 3

base_optimizer_state: base.OptState#

Alias for field number 4

z: base.Params#

Alias for field number 5

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Sharpness aware minimization#

optax.contrib.sam(optimizer, adv_optimizer, sync_period=2, reset_state=True, opaque_mode=False, batch_axis_name=None)[source]#

Implementation of SAM (Sharpness Aware Minimization).

Performs steps with the inner adversarial optimizer and periodically updates an outer set of true parameters. By default, resets the state of the adversarial optimizer after synchronization. For example:

>>> import optax
>>> rho = 0.1
>>> opt = optax.sgd(learning_rate=0.01)
>>> adv_opt = optax.chain(optax.contrib.normalize(), optax.sgd(rho))
>>> sam_opt = optax.contrib.sam(opt, adv_opt, sync_period=2)

This implements the simple drop-in SAM version from the paper, which uses normalized SGD for one step as the inner adversarial optimizer.
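
For completeness, a runnable sketch of a training loop with this wrapper in the default (non-opaque) mode, where adversarial and outer steps alternate across calls to update (the quadratic objective and step count are illustrative assumptions):

import jax
import jax.numpy as jnp
import optax

def f(x): return jnp.sum(x ** 2)  # illustrative quadratic objective
opt = optax.sgd(learning_rate=0.01)
adv_opt = optax.chain(optax.contrib.normalize(), optax.sgd(0.1))
sam_opt = optax.contrib.sam(opt, adv_opt, sync_period=2)
params = jnp.array([1., 2., 3.])
state = sam_opt.init(params)
for _ in range(4):  # alternates adversarial and outer steps
  grads = jax.grad(f)(params)
  updates, state = sam_opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)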

Note

When opaque_mode=True, the update function must be called with a gradient function that takes two arguments (the params and the current adversarial step) and returns the gradients of the loss. This looks like the following:

opt = sam(outer_opt, adv_opt, opaque_mode=True)
...
# In the training loop:
grad_fn = jax.grad(
  lambda params, _: loss(params, batch, and_other_args))
updates, state = opt.update(updates, state, params, grad_fn=grad_fn)
params = optax.apply_updates(params, updates)

On every call to opt.update, grad_fn will be called sync_period - 1 times, once for each adversarial update. It is usually ok to use the same minibatch in each of those updates, as in the example above, but you can use the second argument to select different batches at each adversarial step:

grad_fn = jax.grad(lambda params, i: loss(params, batches[i]))

References

Foret et al., Sharpness-Aware Minimization for Efficiently Improving Generalization, 2021

Parameters:
  • optimizer (GradientTransformation) – the outer optimizer.

  • adv_optimizer (GradientTransformation) – the inner adversarial optimizer.

  • sync_period (int) – how often to run the outer optimizer, defaults to 2, or every other step.

  • reset_state (bool) – whether to reset the state of the inner optimizer after every sync period, defaults to True.

  • opaque_mode (bool) – If True, the outer optimizer and the adversarial optimizer are run in an internal loop at each call to update, so that adversarial updates are opaque to the rest of the system. If False, one optimizer is (effectively) evaluated per call to update, meaning that adversarial updates are visible to the rest of the system. Setting opaque_mode to True is necessary if the training system using SAM has side effects from each call to update besides the changes to the model’s parameters. The most common example would be if the model uses BatchNorm statistics – those statistics would be updated on both adversarial and non-adversarial update steps, causing them to get out of sync with the model’s parameters (which are effectively only updated on non-adversarial steps). See the Note above for more details on opaque_mode=True.

  • batch_axis_name (Optional[str]) – Only used if opaque_mode=True. When running in a pmapped setting, it is necessary to take a jax.lax.pmean of the adversarial updates internally before passing them to the outer optimizer. You only need to specify this if you have to use jax.lax.pmean in your training loop.

Returns:

a GradientTransformationExtraArgs implementation of SAM.

Return type:

sam_optimizer

class optax.contrib.SAMState(steps_since_sync, opt_state, adv_state, cache)[source]#

State of GradientTransformation returned by sam.

steps_since_sync#

Number of adversarial steps taken since the last outer update.

opt_state#

State of the outer optimizer.

adv_state#

State of the inner adversarial optimizer.

cache#

a place to store the last outer updates.

__eq__(other)#

Return self==value.

__getstate__()[source]#

Helper for pickle.

__hash__ = None#

__init__(steps_since_sync, opt_state, adv_state, cache)#

items()[source]#

A set-like object providing a view on D's items.

keys()[source]#

A set-like object providing a view on D's keys.

values()[source]#

An object providing a view on D's values.