🔧 Contrib#

Experimental features and algorithms that don’t meet the Inclusion Criteria.

• cocob([alpha, eps]) – Rescale updates according to the COntinuous COin Betting algorithm.

• COCOBState(init_particles, ...) – State for COntinuous COin Betting.

• dadapt_adamw([learning_rate, betas, eps, ...]) – Learning rate free AdamW by D-Adaptation.

• DAdaptAdamWState(exp_avg, exp_avg_sq, ...) – State of the GradientTransformation returned by dadapt_adamw.

• differentially_private_aggregate(...) – Aggregates gradients based on the DPSGD algorithm.

• DifferentiallyPrivateAggregateState(rng_key) – State containing PRNGKey for differentially_private_aggregate.

• dpsgd(learning_rate, l2_norm_clip, ...[, ...]) – The DPSGD optimizer.

• mechanize(base_optimizer[, weight_decay, ...]) – Mechanic: a black-box learning rate tuner/optimizer.

• MechanicState(base_optimizer_state, count, ...) – State of the GradientTransformation returned by mechanize.

• prodigy([learning_rate, betas, beta3, eps, ...]) – Learning rate free AdamW with Prodigy.

• ProdigyState(exp_avg, exp_avg_sq, grad_sum, ...) – State of the GradientTransformation returned by prodigy.

• sam(optimizer, adv_optimizer[, sync_period, ...]) – Implementation of SAM (Sharpness Aware Minimization).

• SAMState(steps_since_sync, opt_state, ...) – State of GradientTransformation returned by sam.

• split_real_and_imaginary(inner) – Splits the real and imaginary components of complex updates into two.

• SplitRealAndImaginaryState(inner_state) – Maintains the inner transformation state for split_real_and_imaginary.

Complex-valued Optimization#

optax.contrib.split_real_and_imaginary(inner)[source]#

Splits the real and imaginary components of complex updates into two.

The inner transformation processes the real and imaginary parts as separate real parameters and updates; each transformed pair of real updates is then merged back into a complex update.

Parameters and updates that are real before splitting are passed through unmodified.

Parameters:

inner (GradientTransformation) – The inner transformation.

Return type:

GradientTransformation

Returns:

An optax.GradientTransformation.
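A minimal usage sketch (parameter and gradient values below are illustrative): wrapping optax.adam so it can optimize complex-valued parameters.

import jax.numpy as jnp
import optax

# Complex parameters and gradients (values are hypothetical).
params = {'w': jnp.array([1.0 + 2.0j, 3.0 - 1.0j])}
grads = {'w': jnp.array([0.5 - 0.5j, -1.0 + 0.2j])}

# Adam sees two real leaves per complex leaf; the transformed
# real updates are merged back into complex updates.
opt = optax.contrib.split_real_and_imaginary(optax.adam(1e-2))
state = opt.init(params)
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)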

class optax.contrib.SplitRealAndImaginaryState(inner_state: base.OptState)[source]#

Maintains the inner transformation state for split_real_and_imaginary.

inner_state: base.OptState#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Continuous coin betting#

optax.contrib.cocob(alpha=100, eps=1e-08)[source]#

Rescale updates according to the COntinuous COin Betting algorithm.

An algorithm for stochastic subgradient descent. It uses a coin-betting (gambling) strategy to find the minimizer of a non-smooth objective function by accessing only its subgradients; all that is needed is a good gambling strategy. See Algorithm 2 of the reference below.

References

[Orabona & Tommasi, 2017](https://arxiv.org/pdf/1705.07795.pdf)

Parameters:
  • alpha (float) – fraction to bet parameter of the COCOB optimizer

  • eps (float) – jitter term to avoid dividing by 0

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
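A minimal training-loop sketch with cocob, assuming a toy non-smooth objective (the function and step counts below are illustrative):

import jax
import jax.numpy as jnp
import optax

def loss(w):
    # Non-smooth objective; jax.grad returns a valid subgradient.
    return jnp.sum(jnp.abs(w - 1.0))

opt = optax.contrib.cocob()
params = jnp.zeros(3)
state = opt.init(params)
for _ in range(100):
    grads = jax.grad(loss)(params)
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)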

class optax.contrib.COCOBState(init_particles: base.Updates, cumulative_gradients: base.Updates, scale: base.Updates, subgradients: base.Updates, reward: base.Updates)[source]#

State for COntinuous COin Betting.

init_particles: base.Updates#

Alias for field number 0

cumulative_gradients: base.Updates#

Alias for field number 1

scale: base.Updates#

Alias for field number 2

subgradients: base.Updates#

Alias for field number 3

reward: base.Updates#

Alias for field number 4

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

D-adaptation#

optax.contrib.dadapt_adamw(learning_rate=1.0, betas=(0.9, 0.999), eps=1e-08, estim_lr0=1e-06, weight_decay=0.0)[source]#

Learning rate free AdamW by D-Adaptation.

Adapts the baseline learning rate of AdamW automatically by estimating the initial distance to solution in the infinity norm. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.

References

[Defazio & Mishchenko, 2023](https://arxiv.org/abs/2301.07733)

Parameters:
  • learning_rate (ScalarOrSchedule) – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.

  • betas (tuple[float, float]) – Betas for the underlying AdamW Optimizer.

  • eps (float) – eps for the underlying AdamW Optimizer.

  • estim_lr0 (float) – Initial (under-)estimate of the learning rate.

  • weight_decay (float) – AdamW-style weight decay. To use regular Adam weight decay, chain with add_decayed_weights.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
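A sketch of the recommended setup: a linear decay schedule from 1.0 to 0 preceded by a warmup (the step counts below are hypothetical):

import optax

num_steps = 10_000
warmup_steps = num_steps // 10  # 10% warmup, within the recommended 0-20%
schedule = optax.join_schedules(
    [optax.linear_schedule(0.0, 1.0, warmup_steps),
     optax.linear_schedule(1.0, 0.0, num_steps - warmup_steps)],
    boundaries=[warmup_steps])
opt = optax.contrib.dadapt_adamw(learning_rate=schedule)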

class optax.contrib.DAdaptAdamWState(exp_avg: base.Updates, exp_avg_sq: base.Updates, grad_sum: base.Updates, estim_lr: chex.Array, numerator_weighted: chex.Array, count: chex.Array)[source]#

State of the GradientTransformation returned by dadapt_adamw.

exp_avg: base.Updates#

Alias for field number 0

exp_avg_sq: base.Updates#

Alias for field number 1

grad_sum: base.Updates#

Alias for field number 2

estim_lr: chex.Array#

Alias for field number 3

numerator_weighted: chex.Array#

Alias for field number 4

count: chex.Array#

Alias for field number 5

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Privacy-Sensitive Optax Methods#

• DifferentiallyPrivateAggregateState(rng_key) – State containing PRNGKey for differentially_private_aggregate.

• differentially_private_aggregate(...) – Aggregates gradients based on the DPSGD algorithm.

Differentially Private Aggregate#

optax.contrib.differentially_private_aggregate(l2_norm_clip, noise_multiplier, seed)[source]#

Aggregates gradients based on the DPSGD algorithm.

WARNING: Unlike other transforms, differentially_private_aggregate expects the input updates to have a batch dimension in the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap). It can still be composed with other transformations as long as it is the first in the chain.

References

[Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

Parameters:
  • l2_norm_clip (float) – maximum L2 norm of the per-example gradients.

  • noise_multiplier (float) – ratio of standard deviation to the clipping norm.

  • seed (int) – initial seed used for the jax.random.PRNGKey

Return type:

GradientTransformation

Returns:

A GradientTransformation.
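A sketch showing how to obtain per-example gradients with jax.vmap and chain the aggregator with an optimizer (the loss function, data, and hyperparameter values are hypothetical):

import jax
import jax.numpy as jnp
import optax

def loss(params, x, y):
    return (x @ params - y) ** 2  # per-example squared error

# vmap over the batch axis of (x, y) yields gradients with a leading
# batch dimension, as this transform requires.
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))

# The aggregator must come first in the chain.
opt = optax.chain(
    optax.contrib.differentially_private_aggregate(
        l2_norm_clip=1.0, noise_multiplier=1.1, seed=0),
    optax.sgd(learning_rate=0.1))

params = jnp.zeros(4)
state = opt.init(params)
x, y = jnp.ones((8, 4)), jnp.zeros(8)    # toy batch of 8 examples
grads = per_example_grads(params, x, y)  # leading axis is the batch
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)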

class optax.contrib.DifferentiallyPrivateAggregateState(rng_key: Any)[source]#

State containing PRNGKey for differentially_private_aggregate.

rng_key: Any#

Alias for field number 0

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

optax.contrib.dpsgd(learning_rate, l2_norm_clip, noise_multiplier, seed, momentum=None, nesterov=False)[source]#

The DPSGD optimizer.

Differential privacy is a standard for privacy guarantees of algorithms learning from aggregate databases including potentially sensitive information. DPSGD offers protection against a strong adversary with full knowledge of the training mechanism and access to the model’s parameters.

WARNING: This GradientTransformation expects input updates to have a batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap).

References

[Abadi et al, 2016](https://arxiv.org/abs/1607.00133)

Parameters:
  • learning_rate (ScalarOrSchedule) – A fixed global scaling factor.

  • l2_norm_clip (float) – Maximum L2 norm of the per-example gradients.

  • noise_multiplier (float) – Ratio of standard deviation to the clipping norm.

  • seed (int) – Initial seed used for the jax.random.PRNGKey

  • momentum (Optional[float]) – Decay rate used by the momentum term. If set to None, momentum is not used at all.

  • nesterov (bool) – Whether Nesterov momentum is used.

Return type:

GradientTransformation

Returns:

A GradientTransformation.
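A minimal construction sketch (hyperparameter values are illustrative). As with differentially_private_aggregate above, the gradients passed to update must carry a leading batch dimension, e.g. produced with jax.vmap:

import optax

opt = optax.contrib.dpsgd(
    learning_rate=0.1, l2_norm_clip=1.0,
    noise_multiplier=1.1, seed=0, momentum=0.9)
# opt.init / opt.update are then used as in the aggregation example
# above, with per-example gradients as the updates argument.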

Mechanize#

optax.contrib.mechanize(base_optimizer, weight_decay=0.01, eps=1e-08, s_init=1e-06, num_betas=6)[source]#

Mechanic: a black-box learning rate tuner/optimizer.

Accumulates updates returned by the base_optimizer and learns the scale of the updates (also known as the learning rate or step size) to apply on a per-iteration basis.

Note that Mechanic does NOT remove the need for a learning rate schedule: you are free to apply a schedule with its base learning rate set to 1.0 (or any other constant), and Mechanic will learn the right scale factor automatically.

For example, change this:

learning_rate_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=tuned_lr,
    warmup_steps=500, decay_steps=10_000)  # schedule lengths are illustrative
optimizer = optax.adam(learning_rate_fn)

To:

learning_rate_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1.0,
    warmup_steps=500, decay_steps=10_000)  # schedule lengths are illustrative
optimizer = optax.adam(learning_rate_fn)
optimizer = optax.contrib.mechanize(optimizer)

As of June 2023, Mechanic is tested with SGD, Momentum, Adam, and Lion as inner optimizers, but we expect it to work with almost any first-order optimizer (except for normalized-gradient optimizers like LARS or LAMB).

References

[Cutkosky et al, 2023](https://arxiv.org/pdf/2306.00144.pdf)

Parameters:
  • base_optimizer (GradientTransformation) – Base optimizer to compute updates from.

  • weight_decay (float) – A scalar weight decay rate. Note that this weight decay is not the same as the weight decay one would use for the base_optimizer. In addition to sometimes helping converge faster, this helps Mechanic reduce the variance between training runs using different seeds. You likely will not need to tune this; the default should work in most cases.

  • eps (float) – epsilon for mechanic.

  • s_init (float) – initial scale factor. Default should work almost all the time.

  • num_betas (int) – Unlike traditional exponential moving-average accumulators (such as the first and second moments in Adam), where one must choose an explicit beta, Mechanic automatically learns the right beta for all of its accumulators. You only specify the number of candidate betas, not their tuned values; for instance, num_betas=3 uses betas = [0.9, 0.99, 0.999].

Return type:

GradientTransformation

Returns:

A GradientTransformation with init and update functions.
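A minimal training-loop sketch, assuming a toy quadratic objective (the base optimizer choice and constants are illustrative):

import jax
import jax.numpy as jnp
import optax

def loss(w):
    return jnp.sum(w ** 2)

# Base learning rate of 1.0; Mechanic learns the actual scale.
opt = optax.contrib.mechanize(optax.adam(learning_rate=1.0))
params = jnp.ones(3)
state = opt.init(params)
for _ in range(10):
    grads = jax.grad(loss)(params)
    updates, state = opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)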

class optax.contrib.MechanicState(base_optimizer_state: base.OptState, count: chex.Array, r: chex.Array, m: chex.Array, v: chex.Array, s: chex.Array, x0: base.Updates)[source]#

State of the GradientTransformation returned by mechanize.

base_optimizer_state: base.OptState#

Alias for field number 0

count: chex.Array#

Alias for field number 1

r: chex.Array#

Alias for field number 2

m: chex.Array#

Alias for field number 3

v: chex.Array#

Alias for field number 4

s: chex.Array#

Alias for field number 5

x0: base.Updates#

Alias for field number 6

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Prodigy#

optax.contrib.prodigy(learning_rate=1.0, betas=(0.9, 0.999), beta3=None, eps=1e-08, estim_lr0=1e-06, estim_lr_coef=1.0, weight_decay=0.0)[source]#

Learning rate free AdamW with Prodigy.

Implementation of the Prodigy method from “Prodigy: An Expeditiously Adaptive Parameter-Free Learner”, a version of D-Adapt AdamW that adapts the baseline learning rate faster by using a weighting of the gradients that places higher weights on more recent gradients. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.

References

[Mishchenko & Defazio, 2023](https://arxiv.org/abs/2306.06101)

Parameters:
  • learning_rate (ScalarOrSchedule) – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.

  • betas (tuple[float, float]) – Betas for the underlying AdamW Optimizer.

  • beta3 (Optional[float]) – Optional momentum parameter for estimation of D.

  • eps (float) – eps for the underlying AdamW Optimizer.

  • estim_lr0 (float) – Initial (under-)estimate of the learning rate.

  • estim_lr_coef (float) – LR estimates are multiplied by this parameter.

  • weight_decay (float) – AdamW-style weight decay. To use regular Adam weight decay, chain with add_decayed_weights.

Return type:

GradientTransformation

Returns:

A GradientTransformation object.
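A construction sketch following the recommended schedule (the step count below is hypothetical):

import optax

schedule = optax.linear_schedule(
    init_value=1.0, end_value=0.0, transition_steps=10_000)
opt = optax.contrib.prodigy(learning_rate=schedule)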

class optax.contrib.ProdigyState(exp_avg: base.Updates, exp_avg_sq: base.Updates, grad_sum: base.Updates, params0: base.Updates, estim_lr: chex.Array, numerator_weighted: chex.Array, count: chex.Array)[source]#

State of the GradientTransformation returned by prodigy.

exp_avg: base.Updates#

Alias for field number 0

exp_avg_sq: base.Updates#

Alias for field number 1

grad_sum: base.Updates#

Alias for field number 2

params0: base.Updates#

Alias for field number 3

estim_lr: chex.Array#

Alias for field number 4

numerator_weighted: chex.Array#

Alias for field number 5

count: chex.Array#

Alias for field number 6

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Sharpness aware minimization#

optax.contrib.sam(optimizer, adv_optimizer, sync_period=2, reset_state=True, opaque_mode=False, batch_axis_name=None)[source]#

Implementation of SAM (Sharpness Aware Minimization).

Performs steps with the inner adversarial optimizer and periodically updates an outer set of true parameters. By default, resets the state of the adversarial optimizer after synchronization. For example:

>>> import optax
>>> rho = 0.1
>>> opt = optax.sgd(learning_rate=0.01)
>>> adv_opt = optax.chain(optax.contrib.normalize(), optax.sgd(rho))
>>> sam_opt = optax.contrib.sam(opt, adv_opt, sync_period=2)

This implements the simple drop-in SAM version from the paper, which uses normalized SGD as the inner adversarial optimizer for a single step.

Note

When opaque_mode=True, the update function must be called with a gradient function that takes two arguments (the params and the current adversarial step) and returns the gradients of the loss. This looks like the following:

opt = optax.contrib.sam(outer_opt, adv_opt, opaque_mode=True)
...
# In the training loop:
grad_fn = jax.grad(
    lambda params, _: loss(params, batch, and_other_args))
updates, state = opt.update(updates, state, params, grad_fn=grad_fn)
params = optax.apply_updates(params, updates)

On every call to opt.update, grad_fn will be called sync_period - 1 times, once for each adversarial update. It is usually ok to use the same minibatch in each of those updates, as in the example above, but you can use the second argument to select different batches at each adversarial step:

grad_fn = jax.grad(lambda params, i: loss(params, batches[i]))

References

[Foret et al., Sharpness-Aware Minimization for Efficiently Improving Generalization, 2021](https://arxiv.org/abs/2010.01412)

Parameters:
  • optimizer (GradientTransformation) – the outer optimizer.

  • adv_optimizer (GradientTransformation) – the inner adversarial optimizer.

  • sync_period (int) – how often to run the outer optimizer, defaults to 2, or every other step.

  • reset_state (bool) – whether to reset the state of the inner optimizer after every sync period, defaults to True.

  • opaque_mode (bool) – If True, the outer optimizer and the adversarial optimizer are run in an internal loop at each call to update, so that adversarial updates are opaque to the rest of the system. If False, one optimizer is (effectively) evaluated per call to update, meaning that adversarial updates are visible to the rest of the system. Setting opaque_mode to True is necessary if the training system using SAM has side effects from each call to update besides the changes to the model's parameters. The most common example is a model that uses BatchNorm statistics: those statistics would be updated on both adversarial and non-adversarial update steps, causing them to get out of sync with the model's parameters (which are effectively only updated on non-adversarial steps). See the note above for more details on opaque_mode=True.

  • batch_axis_name (Optional[str]) – Only used if opaque_mode=True. When running in a pmapped setting, it is necessary to take a jax.lax.pmean of the adversarial updates internally before passing them to the outer optimizer. You only need to specify this if you use jax.lax.pmean in your training loop.

Return type:

GradientTransformationExtraArgs

Returns:

A GradientTransformationExtraArgs implementation of SAM.
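A minimal training-loop sketch in the default (non-opaque) mode, continuing the drop-in example above (the loss function and data are illustrative). With sync_period=2, calls to update alternate between adversarial and outer steps:

import jax
import jax.numpy as jnp
import optax

def loss(params, batch):
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

opt = optax.sgd(learning_rate=0.01)
adv_opt = optax.chain(optax.contrib.normalize(), optax.sgd(0.1))
sam_opt = optax.contrib.sam(opt, adv_opt, sync_period=2)

params = jnp.zeros(4)
state = sam_opt.init(params)
batch = (jnp.ones((8, 4)), jnp.zeros(8))
for _ in range(10):  # alternates adversarial and outer updates
    grads = jax.grad(loss)(params, batch)
    updates, state = sam_opt.update(grads, state, params)
    params = optax.apply_updates(params, updates)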

class optax.contrib.SAMState(steps_since_sync, opt_state, adv_state, cache)[source]#

State of GradientTransformation returned by sam.

steps_since_sync#

Number of adversarial steps taken since the last outer update.

opt_state#

State of the outer optimizer.

adv_state#

State of the inner adversarial optimizer.

cache#

A place to store the last outer updates.

__eq__(other)#

Return self==value.

__getstate__()[source]#

Helper for pickle.

__hash__ = None#

__init__(steps_since_sync, opt_state, adv_state, cache)#

items() → a set-like object providing a view on D's items[source]#

keys() → a set-like object providing a view on D's keys[source]#

values() → an object providing a view on D's values[source]#