🔧 Contrib
Experimental features and algorithms that don’t meet the Inclusion Criteria.
- cocob: Rescale updates according to the COntinuous COin Betting algorithm.
- COCOBState: State for COntinuous COin Betting.
- dadapt_adamw: Learning rate free AdamW by D-Adaptation.
- DAdaptAdamWState: State of the GradientTransformation returned by dadapt_adamw.
- differentially_private_aggregate: Aggregates gradients based on the DPSGD algorithm.
- DifferentiallyPrivateAggregateState: State containing PRNGKey for differentially_private_aggregate.
- dpsgd: The DPSGD optimizer.
- mechanize: Mechanic - a black box learning rate tuner/optimizer.
- MechanicState: State of the GradientTransformation returned by mechanize.
- prodigy: Learning rate free AdamW with Prodigy.
- ProdigyState: State of the GradientTransformation returned by prodigy.
- sam: Implementation of SAM (Sharpness Aware Minimization).
- SAMState: State of GradientTransformation returned by sam.
- split_real_and_imaginary: Splits the real and imaginary components of complex updates into two.
- SplitRealAndImaginaryState: Maintains the inner transformation state for split_real_and_imaginary.
Complex-valued Optimization
- optax.contrib.split_real_and_imaginary(inner)
Splits the real and imaginary components of complex updates into two.
The inner transformation processes real parameters and updates, and the pairs of transformed real updates are merged into complex updates.
Parameters and updates that are real before splitting are passed through unmodified.
- Parameters:
inner (GradientTransformation) – The inner transformation.
- Return type:
GradientTransformation
- Returns:
An optax.GradientTransformation.
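To make the split-and-merge behavior concrete, here is a minimal pure-Python sketch. It is not the optax implementation: it works on a flat list of Python numbers instead of a pytree of arrays, and the names `split_real_and_imaginary_sketch` and `scale_by_minus_lr` are hypothetical, introduced only for illustration.

```python
def split_real_and_imaginary_sketch(inner_update):
    """Wrap a real-valued update rule so that it also accepts complex updates.

    Hypothetical helper for illustration; it operates on a flat list of
    Python numbers rather than on a pytree of arrays.
    """

    def update(updates):
        flat = []        # real-valued view handed to the inner rule
        is_complex = []  # remembers which entries were split in two
        for u in updates:
            if isinstance(u, complex):
                flat.extend([u.real, u.imag])
                is_complex.append(True)
            else:
                # Real entries are handed to the inner rule without splitting.
                flat.append(u)
                is_complex.append(False)

        transformed = iter(inner_update(flat))

        merged = []
        for was_complex in is_complex:
            if was_complex:
                # Merge the pair of transformed real updates back together.
                real_part = next(transformed)
                imag_part = next(transformed)
                merged.append(complex(real_part, imag_part))
            else:
                merged.append(next(transformed))
        return merged

    return update


def scale_by_minus_lr(updates, learning_rate=0.1):
    """A stand-in inner rule that only understands real numbers."""
    return [-learning_rate * u for u in updates]


update = split_real_and_imaginary_sketch(scale_by_minus_lr)
new_updates = update([1.0 + 2.0j, 3.0])
```

The inner rule never sees a complex number: the complex entry reaches it as two separate reals, and the real entry reaches it as-is.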
Continuous coin betting
- optax.contrib.cocob(alpha=100, eps=1e-08)
Rescale updates according to the COntinuous COin Betting algorithm.
An algorithm for stochastic subgradient descent. It uses a gambling (coin-betting) strategy to find the minimizer of a non-smooth objective function by accessing its subgradients. See Algorithm 2 of the reference below.
References
[Orabona & Tommasi, 2017](https://arxiv.org/pdf/1705.07795.pdf)
- Parameters:
alpha (float) – Fraction-to-bet parameter of the COCOB optimizer.
eps (float) – Jitter term to avoid dividing by 0.
- Return type:
GradientTransformation
- Returns:
A GradientTransformation object.
- class optax.contrib.COCOBState(init_particles: base.Updates, cumulative_gradients: base.Updates, scale: base.Updates, subgradients: base.Updates, reward: base.Updates)
State for COntinuous COin Betting.
- init_particles: base.Updates
Alias for field number 0
- cumulative_gradients: base.Updates
Alias for field number 1
- scale: base.Updates
Alias for field number 2
- subgradients: base.Updates
Alias for field number 3
- reward: base.Updates
Alias for field number 4
D-adaptation
- optax.contrib.dadapt_adamw(learning_rate=1.0, betas=(0.9, 0.999), eps=1e-08, estim_lr0=1e-06, weight_decay=0.0)
Learning rate free AdamW by D-Adaptation.
Adapts the baseline learning rate of AdamW automatically by estimating the initial distance to solution in the infinity norm. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.
References
[Defazio & Mishchenko, 2023](https://arxiv.org/abs/2301.07733.pdf)
- Parameters:
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.
betas (tuple[float, float]) – Betas for the underlying AdamW optimizer.
eps (float) – eps for the underlying AdamW optimizer.
estim_lr0 (float) – Initial (under-)estimate of the learning rate.
weight_decay (float) – AdamW-style weight decay. To use regular Adam decay, chain with add_decayed_weights.
- Return type:
GradientTransformation
- Returns:
A GradientTransformation object.
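The recommended schedule shape for learning_rate can be sketched in plain Python. This is only an illustration of the shape, not an optax API: the function name `warmup_then_linear_decay`, its arguments, and the choice to place the warmup before the linear decay are assumptions made for this example.

```python
def warmup_then_linear_decay(step, total_steps, warmup_fraction=0.1):
    """Multiplier that ramps from 0 to 1.0, then decays linearly to 0.

    Hypothetical helper: `warmup_fraction` is the share of training spent
    warming up (the text above suggests somewhere between 0 and 0.2).
    """
    warmup_steps = int(total_steps * warmup_fraction)
    if step < warmup_steps:
        # Linear warmup towards the base value of 1.0.
        return step / warmup_steps
    # Linear decay from 1.0 (init_value) to 0 (end_value).
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(step - warmup_steps, decay_steps) / decay_steps
    return 1.0 - progress


# Sample the multiplier at a few points of a 100-step run.
multipliers = [
    warmup_then_linear_decay(s, total_steps=100) for s in (0, 5, 10, 55, 100)
]
```

Because the multiplier peaks at 1.0, the scale of the step is left entirely to the estimate that D-Adaptation maintains; the schedule only shapes it over time.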
- class optax.contrib.DAdaptAdamWState(exp_avg: base.Updates, exp_avg_sq: base.Updates, grad_sum: base.Updates, estim_lr: chex.Array, numerator_weighted: chex.Array, count: chex.Array)
State of the GradientTransformation returned by dadapt_adamw.
- exp_avg: base.Updates
Alias for field number 0
- exp_avg_sq: base.Updates
Alias for field number 1
- grad_sum: base.Updates
Alias for field number 2
- estim_lr: chex.Array
Alias for field number 3
- numerator_weighted: chex.Array
Alias for field number 4
- count: chex.Array
Alias for field number 5
Privacy-Sensitive Optax Methods
- DifferentiallyPrivateAggregateState: State containing PRNGKey for differentially_private_aggregate.
- differentially_private_aggregate: Aggregates gradients based on the DPSGD algorithm.
Differentially Private Aggregate
- optax.contrib.differentially_private_aggregate(l2_norm_clip, noise_multiplier, seed)
Aggregates gradients based on the DPSGD algorithm.
WARNING: Unlike other transforms, differentially_private_aggregate expects the input updates to have a batch dimension in the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap). It can still be composed with other transformations as long as it is the first in the chain.
References
[Abadi et al, 2016](https://arxiv.org/abs/1607.00133.pdf)
- Parameters:
l2_norm_clip (float) – Maximum L2 norm of the per-example gradients.
noise_multiplier (float) – Ratio of standard deviation to the clipping norm.
seed (int) – Initial seed used for the jax.random.PRNGKey.
- Return type:
GradientTransformation
- Returns:
A GradientTransformation.
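The roles of the batch dimension, l2_norm_clip and noise_multiplier can be illustrated with a small standard-library sketch of the usual DPSGD aggregation step: clip each example's gradient, sum, add Gaussian noise, and average. It is a simplified stand-in, not the optax code: gradients are plain lists of floats, randomness comes from Python's `random` module rather than a jax.random.PRNGKey, and the name `dp_aggregate_sketch` is hypothetical.

```python
import math
import random


def dp_aggregate_sketch(per_example_grads, l2_norm_clip, noise_multiplier, seed):
    """Clip per-example gradients, sum them, add noise, and average.

    `per_example_grads` has the batch on the 0th axis: one gradient
    vector (a list of floats) per example.
    """
    rng = random.Random(seed)
    batch_size = len(per_example_grads)
    summed = [0.0] * len(per_example_grads[0])

    for grad in per_example_grads:
        norm = math.sqrt(sum(x * x for x in grad))
        # Shrink the gradient only if its L2 norm exceeds the clip value.
        factor = min(1.0, l2_norm_clip / norm) if norm > 0 else 1.0
        for i, x in enumerate(grad):
            summed[i] += factor * x

    # noise_multiplier is the ratio of the noise std to the clipping norm.
    noise_std = l2_norm_clip * noise_multiplier
    return [(s + rng.gauss(0.0, noise_std)) / batch_size for s in summed]


grads = [[3.0, 4.0], [0.3, 0.4]]  # two examples, two parameters each
no_noise = dp_aggregate_sketch(grads, l2_norm_clip=1.0, noise_multiplier=0.0, seed=0)
noisy = dp_aggregate_sketch(grads, l2_norm_clip=1.0, noise_multiplier=1.1, seed=0)
```

With noise_multiplier=0.0 the result is just the mean of the clipped gradients, which makes the effect of clipping easy to inspect. Clipping requires one gradient per example, which is why the transform must see the updates before anything in the chain has reduced over the batch.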
- class optax.contrib.DifferentiallyPrivateAggregateState(rng_key: Any)
State containing PRNGKey for differentially_private_aggregate.
- rng_key: Any
Alias for field number 0
- optax.contrib.dpsgd(learning_rate, l2_norm_clip, noise_multiplier, seed, momentum=None, nesterov=False)
The DPSGD optimizer.
Differential privacy is a standard for privacy guarantees of algorithms learning from aggregate databases including potentially sensitive information. DPSGD offers protection against a strong adversary with full knowledge of the training mechanism and access to the model’s parameters.
WARNING: This GradientTransformation expects input updates to have a batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap).
References
[Abadi et al, 2016](https://arxiv.org/abs/1607.00133)
- Parameters:
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A fixed global scaling factor.
l2_norm_clip (float) – Maximum L2 norm of the per-example gradients.
noise_multiplier (float) – Ratio of standard deviation to the clipping norm.
seed (int) – Initial seed used for the jax.random.PRNGKey.
momentum (Optional[float]) – Decay rate used by the momentum term. When set to None, momentum is not used at all.
nesterov (bool) – Whether Nesterov momentum is used.
- Return type:
GradientTransformation
- Returns:
A GradientTransformation.
Mechanize
- optax.contrib.mechanize(base_optimizer, weight_decay=0.01, eps=1e-08, s_init=1e-06, num_betas=6)
Mechanic - a black box learning rate tuner/optimizer.
Accumulates updates returned by the base_optimizer and learns the scale of the updates (also known as learning rate or step size) to apply on a per-iteration basis.
Note that Mechanic does NOT eschew the need for a learning rate schedule. You are free to apply a learning rate schedule with the base learning rate set to 1.0 (or any other constant), and Mechanic will learn the right scale factor automatically.
For example, change this:
    learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
    optimizer = optax.adam(learning_rate_fn)
To:
    learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=1.0)
    optimizer = optax.adam(learning_rate_fn)
    optimizer = optax.contrib.mechanize(optimizer)
As of June 2023, Mechanic is tested with SGD, Momentum, Adam and Lion as inner optimizers, but we expect it to work with almost any first-order optimizer (except for normalized gradient optimizers like LARS or LAMB).
References
[Cutkosky et al, 2023](https://arxiv.org/pdf/2306.00144.pdf)
- Parameters:
base_optimizer (GradientTransformation) – Base optimizer to compute updates from.
weight_decay (float) – A scalar weight decay rate. Note that this weight decay is not the same as the weight decay one would use for the base_optimizer. In addition to sometimes helping converge faster, this helps Mechanic reduce the variance between training runs using different seeds. You likely will not need to tune this; the default should work in most cases.
eps (float) – Epsilon for Mechanic.
s_init (float) – Initial scale factor. The default should work almost all the time.
num_betas (int) – Number of candidate betas. Unlike traditional exponential accumulators (such as the 1st or 2nd moment of Adam), where one has to choose an explicit beta, Mechanic automatically learns the right beta for all accumulators. You provide only the range of possible betas, not a tuned value. For instance, if you set num_betas to 3, it will use betas = [0.9, 0.99, 0.999].
- Return type:
GradientTransformation
- Returns:
A GradientTransformation with init and update functions.
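The num_betas example above (3 gives [0.9, 0.99, 0.999]) is consistent with a grid of the form 1 - 10^(-k) for k = 1, ..., num_betas. The sketch below builds such a grid. The pattern is inferred from that single example, and the helper name `candidate_betas` is hypothetical, so treat it as an illustration rather than a description of the library's internals.

```python
def candidate_betas(num_betas):
    """Grid of betas 1 - 10**(-k) for k = 1..num_betas (assumed pattern)."""
    return [1.0 - 0.1 ** k for k in range(1, num_betas + 1)]


betas_small = candidate_betas(3)    # approximately [0.9, 0.99, 0.999]
betas_default = candidate_betas(6)  # grid for the default num_betas=6
```

Under this assumption, each extra beta adds one more "nine", that is, an accumulator with a roughly ten times longer averaging horizon.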
- class optax.contrib.MechanicState(base_optimizer_state: base.OptState, count: chex.Array, r: chex.Array, m: chex.Array, v: chex.Array, s: chex.Array, x0: base.Updates)
State of the GradientTransformation returned by mechanize.
- base_optimizer_state: base.OptState
Alias for field number 0
- count: chex.Array
Alias for field number 1
- r: chex.Array
Alias for field number 2
- m: chex.Array
Alias for field number 3
- v: chex.Array
Alias for field number 4
- s: chex.Array
Alias for field number 5
- x0: base.Updates
Alias for field number 6
Prodigy
- optax.contrib.prodigy(learning_rate=1.0, betas=(0.9, 0.999), beta3=None, eps=1e-08, estim_lr0=1e-06, estim_lr_coef=1.0, weight_decay=0.0)
Learning rate free AdamW with Prodigy.
Implementation of the Prodigy method from “Prodigy: An Expeditiously Adaptive Parameter-Free Learner”, a version of D-Adapt AdamW that adapts the baseline learning rate faster by using a weighting of the gradients that places higher weights on more recent gradients. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.
References
[Mishchenko & Defazio, 2023](https://arxiv.org/abs/2306.06101.pdf)
- Parameters:
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.
betas (tuple[float, float]) – Betas for the underlying AdamW optimizer.
beta3 (Optional[float]) – Optional momentum parameter for estimation of D.
eps (float) – eps for the underlying AdamW optimizer.
estim_lr0 (float) – Initial (under-)estimate of the learning rate.
estim_lr_coef (float) – LR estimates are multiplied by this parameter.
weight_decay (float) – AdamW-style weight decay. To use regular Adam decay, chain with add_decayed_weights.
- Return type:
GradientTransformation
- Returns:
A GradientTransformation object.
- class optax.contrib.ProdigyState(exp_avg: base.Updates, exp_avg_sq: base.Updates, grad_sum: base.Updates, params0: base.Updates, estim_lr: chex.Array, numerator_weighted: chex.Array, count: chex.Array)
State of the GradientTransformation returned by prodigy.
- exp_avg: base.Updates
Alias for field number 0
- exp_avg_sq: base.Updates
Alias for field number 1
- grad_sum: base.Updates
Alias for field number 2
- params0: base.Updates
Alias for field number 3
- estim_lr: chex.Array
Alias for field number 4
- numerator_weighted: chex.Array
Alias for field number 5
- count: chex.Array
Alias for field number 6