🔧 Contrib
Experimental features and algorithms that don’t meet the Inclusion Criteria.
| Name | Description |
|---|---|
| cocob | Rescale updates according to the COntinuous COin Betting algorithm. |
| COCOBState | State for COntinuous COin Betting. |
| dadapt_adamw | Learning rate free AdamW by D-Adaptation. |
| DAdaptAdamWState | State of the GradientTransformation returned by dadapt_adamw. |
| differentially_private_aggregate | Aggregates gradients based on the DPSGD algorithm. |
| DifferentiallyPrivateAggregateState | State containing PRNGKey for differentially_private_aggregate. |
| dog | Distance over Gradients optimizer. |
| DoGState | State for DoG optimizer. |
| dowg | Distance over weighted Gradients optimizer. |
| DoWGState | State for DoWG optimizer. |
| dpsgd | The DPSGD optimizer. |
| mechanize | Mechanic - a black box learning rate tuner/optimizer. |
| MechanicState | State of the GradientTransformation returned by mechanize. |
| momo | Adaptive Learning Rates for SGD with momentum. |
| MomoState | State of the GradientTransformation returned by momo. |
| momo_adam | Adaptive Learning Rates for Adam(W). |
| MomoAdamState | State of the GradientTransformation returned by momo_adam. |
| prodigy | Learning rate free AdamW with Prodigy. |
| ProdigyState | State of the GradientTransformation returned by prodigy. |
| sam | Implementation of SAM (Sharpness Aware Minimization). |
| SAMState | State of GradientTransformation returned by sam. |
| schedule_free | Turn base_optimizer schedule_free. |
| schedule_free_eval_params | Params for evaluation of schedule_free(). |
| ScheduleFreeState | State for schedule_free. |
| split_real_and_imaginary | Splits the real and imaginary components of complex updates into two. |
| SplitRealAndImaginaryState | Maintains the inner transformation state for split_real_and_imaginary. |
Complex-valued Optimization
- optax.contrib.split_real_and_imaginary(inner)
Splits the real and imaginary components of complex updates into two.
The inner transformation processes real parameters and updates, and the pairs of transformed real updates are merged into complex updates.
Parameters and updates that are real before splitting are passed through unmodified.
- Parameters:
inner (GradientTransformation) – The inner transformation.
- Return type:
GradientTransformation
- Returns:
An optax.GradientTransformation.
Continuous coin betting
- optax.contrib.cocob(learning_rate=1.0, alpha=100, eps=1e-08, weight_decay=0, mask=None)
Rescale updates according to the COntinuous COin Betting algorithm.
Algorithm for stochastic subgradient descent. It uses a gambling algorithm to find the minimizer of a non-smooth objective function by accessing its subgradients; all that is needed is a good gambling strategy. See Algorithm 2 of the reference below.
References
[Orabona & Tommasi, 2017](https://arxiv.org/pdf/1705.07795.pdf)
- Parameters:
learning_rate (base.ScalarOrSchedule) – optional learning rate, e.g. to inject a schedule
alpha (float) – fraction to bet parameter of the COCOB optimizer
eps (float) – jitter term to avoid dividing by 0
weight_decay (float) – L2 penalty
mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – mask for weight decay
- Return type:
optax.GradientTransformation
- Returns:
A GradientTransformation object.
- class optax.contrib.COCOBState(init_particles: base.Updates, cumulative_gradients: base.Updates, scale: base.Updates, subgradients: base.Updates, reward: base.Updates)
State for COntinuous COin Betting.
- init_particles: base.Updates
Alias for field number 0
- cumulative_gradients: base.Updates
Alias for field number 1
- scale: base.Updates
Alias for field number 2
- subgradients: base.Updates
Alias for field number 3
- reward: base.Updates
Alias for field number 4
D-adaptation
- optax.contrib.dadapt_adamw(learning_rate=1.0, betas=(0.9, 0.999), eps=1e-08, estim_lr0=1e-06, weight_decay=0.0)
Learning rate free AdamW by D-Adaptation.
Adapts the baseline learning rate of AdamW automatically by estimating the initial distance to solution in the infinity norm. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.
References
[Defazio & Mishchenko, 2023](https://arxiv.org/abs/2301.07733.pdf)
- Parameters:
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.
betas (tuple[float, float]) – Betas for the underlying AdamW optimizer.
eps (float) – eps for the underlying AdamW optimizer.
estim_lr0 (float) – Initial (under-)estimate of the learning rate.
weight_decay (float) – AdamW-style weight decay. To use regular Adam decay, chain with add_decayed_weights.
- Return type:
GradientTransformation
- Returns:
A GradientTransformation object.
- class optax.contrib.DAdaptAdamWState(exp_avg: base.Updates, exp_avg_sq: base.Updates, grad_sum: base.Updates, estim_lr: chex.Array, numerator_weighted: chex.Array, count: chex.Array)
State of the GradientTransformation returned by dadapt_adamw.
- exp_avg: base.Updates
Alias for field number 0
- exp_avg_sq: base.Updates
Alias for field number 1
- grad_sum: base.Updates
Alias for field number 2
- estim_lr: chex.Array
Alias for field number 3
- numerator_weighted: chex.Array
Alias for field number 4
- count: chex.Array
Alias for field number 5
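Privacy-Sensitive Optax Methods
| Name | Description |
|---|---|
| DifferentiallyPrivateAggregateState | State containing PRNGKey for differentially_private_aggregate. |
| differentially_private_aggregate | Aggregates gradients based on the DPSGD algorithm. |
Differentially Private Aggregate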
- optax.contrib.differentially_private_aggregate(l2_norm_clip, noise_multiplier, seed)
Aggregates gradients based on the DPSGD algorithm.
WARNING: Unlike other transforms, differentially_private_aggregate expects the input updates to have a batch dimension in the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap). It can still be composed with other transformations as long as it is the first in the chain.
References
[Abadi et al, 2016](https://arxiv.org/abs/1607.00133.pdf)
- Parameters:
l2_norm_clip (float) – Maximum L2 norm of the per-example gradients.
noise_multiplier (float) – Ratio of standard deviation to the clipping norm.
seed (int) – Initial seed used for the jax.random.PRNGKey.
- Return type:
GradientTransformation
- Returns:
A GradientTransformation.
- class optax.contrib.DifferentiallyPrivateAggregateState(rng_key: Any)
State containing PRNGKey for differentially_private_aggregate.
- rng_key: Any
Alias for field number 0
- optax.contrib.dpsgd(learning_rate, l2_norm_clip, noise_multiplier, seed, momentum=None, nesterov=False)
The DPSGD optimizer.
Differential privacy is a standard for privacy guarantees of algorithms learning from aggregate databases including potentially sensitive information. DPSGD offers protection against a strong adversary with full knowledge of the training mechanism and access to the model’s parameters.
WARNING: This GradientTransformation expects input updates to have a batch dimension on the 0th axis. That is, this function expects per-example gradients as input (which are easy to obtain in JAX using jax.vmap).
References
[Abadi et al, 2016](https://arxiv.org/abs/1607.00133)
- Parameters:
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A fixed global scaling factor.
l2_norm_clip (float) – Maximum L2 norm of the per-example gradients.
noise_multiplier (float) – Ratio of standard deviation to the clipping norm.
seed (int) – Initial seed used for the jax.random.PRNGKey.
momentum (Optional[float]) – Decay rate used by the momentum term; if None, momentum is not used at all.
nesterov (bool) – Whether Nesterov momentum is used.
- Return type:
GradientTransformation
- Returns:
A GradientTransformation.
Distance over Gradients
- optax.contrib.dog(learning_rate=1.0, reps_rel=1e-06, eps=1e-08, init_learning_rate=None, weight_decay=None, mask=None)
Distance over Gradients optimizer.
DoG updates parameters \(x_t\) with stochastic gradients \(g_t\) according to the update rule:
\[\begin{align*} \eta_t &= \frac{\max_{i\le t}\|x_i-x_0\|}{\sqrt{\sum_{i\le t}\|g_i\|^2+\mathrm{eps}}},\\ x_{t+1} &= x_t - \eta_t\, g_t. \end{align*}\]
Examples
>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.dog()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
References
Ivgi et al., DoG is SGD’s Best Friend: A Parameter-Free Dynamic Step Size Schedule, 2023.
- Parameters:
learning_rate (base.ScalarOrSchedule) – optional learning rate (potentially varying according to some predetermined scheduler).
reps_rel (float) – value used to compute the initial distance (r_epsilon in the paper). Namely, the first step size is given by (reps_rel * (1+|x_0|)) / (|g_0|^2 + eps)^{1/2}, where x_0 are the initial weights of the model (or the parameter group) and g_0 is the gradient of the first step. As discussed in the paper, this value should be small enough that the first update step does not cause the model to diverge. The suggested value is 1e-6, unless the model uses batch normalization, in which case the suggested value is 1e-4.
eps (float) – epsilon used for numerical stability, added to the sum of squared norms of gradients.
init_learning_rate (Optional[float]) – if specified, this value is used as the initial learning rate (i.e. the first step size) instead of the rule described above with reps_rel.
weight_decay (Optional[float]) – Strength of the weight decay regularization.
mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the gradient transformation itself is applied to all parameters.
- Returns:
The corresponding optax.GradientTransformation with associated init and update functions.
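To make the update rule concrete, here is a plain-JAX sketch that applies it to a single array, with the first-step distance floored at reps_rel * (1 + |x_0|) as described for `reps_rel`. The helper name `dog_sketch` is hypothetical. This is a simplified illustration, not the optax implementation: it omits `learning_rate`, `init_learning_rate`, weight decay and pytree support.

```python
import jax
import jax.numpy as jnp

def dog_sketch(grad_fn, x0, num_steps=200, reps_rel=1e-6, eps=1e-8):
    """Hypothetical helper: the DoG step-size rule on a single array."""
    x = x0
    # r_epsilon floors the distance so the first step is tiny but nonzero.
    max_dist = reps_rel * (1.0 + jnp.linalg.norm(x0))
    sum_sq_norm_grads = 0.0
    for _ in range(num_steps):
        g = grad_fn(x)
        # Numerator: largest distance travelled from x_0 so far.
        max_dist = jnp.maximum(max_dist, jnp.linalg.norm(x - x0))
        # Denominator: running sum of squared gradient norms.
        sum_sq_norm_grads = sum_sq_norm_grads + jnp.sum(g ** 2)
        eta = max_dist / jnp.sqrt(sum_sq_norm_grads + eps)
        x = x - eta * g
    return x

f = lambda x: jnp.sum(x ** 2)
x_final = dog_sketch(jax.grad(f), jnp.array([1.0, 2.0, 3.0]))
```

The step size starts near reps_rel and grows geometrically while the iterates move away from x_0, which is how DoG avoids a hand-tuned learning rate.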
- class optax.contrib.DoGState(first_step: jax.Array, init_params: chex.ArrayTree, estim_dist: jax.Array, sum_sq_norm_grads: jax.Array)
State for DoG optimizer.
- optax.contrib.dowg(learning_rate=1.0, init_estim_sq_dist=None, eps=0.0001, weight_decay=None, mask=None)
Distance over weighted Gradients optimizer.
Examples
>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.dowg()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
References
Khaled et al., DoWG Unleashed: An Efficient Universal Parameter-Free Gradient Descent Method, 2023.
- Parameters:
learning_rate (base.ScalarOrSchedule) – optional learning rate (potentially varying according to some predetermined scheduler).
init_estim_sq_dist (Optional[float]) – initial guess of the squared distance to solution.
eps (float) – small value to prevent division by zero in the denominator defining the learning rate; also used as the initial guess for the distance to solution if init_estim_sq_dist is None.
weight_decay (Optional[float]) – Strength of the weight decay regularization.
mask (Optional[Union[Any, Callable[[optax.Params], Any]]]) – A tree with the same structure as (or a prefix of) the params PyTree, or a Callable that returns such a pytree given the params/updates. The leaves should be booleans: True for leaves/subtrees you want to apply the weight decay to, and False for those you want to skip. Note that the gradient transformation itself is applied to all parameters.
- Returns:
The corresponding optax.GradientTransformation with associated init and update functions.
Mechanize
- optax.contrib.mechanize(base_optimizer, weight_decay=0.01, eps=1e-08, s_init=1e-06, num_betas=6)
Mechanic - a black box learning rate tuner/optimizer.
Accumulates updates returned by the base_optimizer and learns the scale of the updates (also known as learning rate or step size) to apply on a per-iteration basis.
Note that Mechanic does NOT eliminate the need for a learning rate schedule: you are free to apply a learning rate schedule with the base learning rate set to 1.0 (or any other constant), and Mechanic will learn the right scale factor automatically.
For example, change this:
learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
optimizer = optax.adam(learning_rate_fn)
To:
learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=1.0)
optimizer = optax.adam(learning_rate_fn)
optimizer = optax.contrib.mechanize(optimizer)
As of June 2023, Mechanic is tested with SGD, Momentum, Adam and Lion as inner optimizers, but we expect it to work with almost any first-order optimizer (except for normalized-gradient optimizers such as LARS or LAMB).
References
[Cutkosky et al, 2023](https://arxiv.org/pdf/2306.00144.pdf)
- Parameters:
base_optimizer (GradientTransformation) – Base optimizer to compute updates from.
weight_decay (float) – A scalar weight decay rate. Note that this weight decay is not the same as the weight decay one would use for the base_optimizer. In addition to sometimes helping convergence speed, it helps Mechanic reduce the variance between training runs using different seeds. You likely do not need to tune this; the default should work in most cases.
eps (float) – Epsilon for Mechanic.
s_init (float) – Initial scale factor. The default should work almost all the time.
num_betas (int) – Unlike traditional exponential accumulators (such as Adam's first or second moment), where one has to choose an explicit beta, Mechanic automatically learns the right beta for all accumulators. We only provide the range of possible betas, not the tuned value. For instance, if you set num_betas to 3, it will use betas = [0.9, 0.99, 0.999].
- Return type:
GradientTransformation
- Returns:
A GradientTransformation with init and update functions.
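The num_betas example suggests the candidate decay rates follow the pattern 1 - 10^(-i) for i = 1, ..., num_betas. A tiny sketch of that pattern. The helper name is hypothetical, and the general formula is our reading of the documented example, not a documented guarantee.

```python
def mechanic_candidate_betas(num_betas):
    """Hypothetical helper: candidate EMA decay rates 1 - 10**-i."""
    return [1.0 - 0.1 ** i for i in range(1, num_betas + 1)]

betas = mechanic_candidate_betas(3)  # close to [0.9, 0.99, 0.999]
```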
- class optax.contrib.MechanicState(base_optimizer_state: base.OptState, count: chex.Array, r: chex.Array, m: chex.Array, v: chex.Array, s: chex.Array, x0: base.Updates)
State of the GradientTransformation returned by mechanize.
- base_optimizer_state: base.OptState
Alias for field number 0
- count: chex.Array
Alias for field number 1
- r: chex.Array
Alias for field number 2
- m: chex.Array
Alias for field number 3
- v: chex.Array
Alias for field number 4
- s: chex.Array
Alias for field number 5
- x0: base.Updates
Alias for field number 6
Momo
- optax.contrib.momo(learning_rate=1.0, beta=0.9, lower_bound=0.0, weight_decay=0.0, adapt_lower_bound=False)
Adaptive Learning Rates for SGD with momentum.
MoMo typically needs less tuning of learning_rate, by exploiting the fact that a lower bound of the loss (or the optimal value) is known. For most tasks, zero is a lower bound and an accurate estimate of the final loss.
MoMo performs SGD with momentum with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.
Note that one needs to pass the latest (batch) loss value to the update function using the keyword argument value.
Examples
>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
References
Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023
- Parameters:
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – User-specified learning rate. Recommended to be chosen rather large; defaults to 1.0.
beta (float) – Momentum coefficient (for the EMA).
lower_bound (float) – Lower bound of the loss. Zero should be a good choice for many tasks.
weight_decay (float) – Weight-decay parameter.
adapt_lower_bound (bool) – If no good guess for the lower bound is available, set this to True to estimate the lower bound on the fly (see the paper for details).
- Return type:
GradientTransformationExtraArgs
- Returns:
A GradientTransformation object.
Added in version 0.2.3.
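To illustrate the Polyak-type step size min(learning_rate, adaptive term), here is a sketch of the classical, momentum-free Polyak step, (f(x) - lower_bound) / ||g||^2 capped at learning_rate. MoMo builds on this idea with momentum-averaged gradients and loss models; the sketch below is not the MoMo update itself. The helper name `polyak_step` is hypothetical, and it assumes f(x) >= lower_bound and a nonzero gradient (no guard for g = 0).

```python
import jax
import jax.numpy as jnp

def polyak_step(f, x, learning_rate=1.0, lower_bound=0.0):
    """Hypothetical helper: one classical Polyak step (no momentum)."""
    value, g = jax.value_and_grad(f)(x)
    # Adaptive term: suboptimality gap divided by the squared gradient norm.
    adaptive = jnp.maximum(value - lower_bound, 0.0) / jnp.sum(g ** 2)
    # Effective step size: the user learning rate caps the adaptive term.
    step_size = jnp.minimum(learning_rate, adaptive)
    return x - step_size * g

f = lambda x: jnp.sum(x ** 2)
x1 = polyak_step(f, jnp.array([1.0, 2.0, 3.0]))  # [0.5, 1.0, 1.5]
```

On this quadratic, f = 14 and ||g||^2 = 56, so the adaptive term is 0.25 (below the cap of 1.0) and the step halves x, dropping the objective from 14.0 to 3.5.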
- class optax.contrib.MomoState(exp_avg: base.Updates, barf: chex.Array, gamma: chex.Array, lb: chex.Array, count: chex.Array)
State of the GradientTransformation returned by momo.
- optax.contrib.momo_adam(learning_rate=0.01, b1=0.9, b2=0.999, eps=1e-08, lower_bound=0.0, weight_decay=0.0, adapt_lower_bound=False)
Adaptive Learning Rates for Adam(W).
MoMo-Adam typically needs less tuning of learning_rate, by exploiting the fact that a lower bound of the loss (or the optimal value) is known. For most tasks, zero is a lower bound and an accurate estimate of the final loss.
MoMo-Adam performs Adam(W) with a Polyak-type learning rate. The effective step size is min(learning_rate, <adaptive term>), where the adaptive term is computed on the fly.
Note that one needs to pass the latest (batch) loss value to the update function using the keyword argument value.
Examples
>>> import optax
>>> from optax import contrib
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
>>> solver = contrib.momo_adam()
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function:  14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
...   value, grad = jax.value_and_grad(f)(params)
...   updates, opt_state = solver.update(grad, opt_state, params, value=value)
...   params = optax.apply_updates(params, updates)
References
Schaipp et al., MoMo: Momentum Models for Adaptive Learning Rates, 2023
- Parameters:
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – User-specified learning rate. Recommended to be chosen rather large; defaults to 0.01.
b1 (float) – Exponential decay rate to track the first moment of past gradients.
b2 (float) – Exponential decay rate to track the second moment of past gradients.
eps (float) – eps for the underlying Adam optimizer.
lower_bound (float) – Lower bound of the loss. Zero should be a good choice for many tasks.
weight_decay (float) – Weight-decay parameter. MoMo-Adam performs weight decay in a similar fashion to AdamW.
adapt_lower_bound (bool) – If no good guess for the lower bound is available, set this to True to estimate the lower bound on the fly (see the paper for details).
- Return type:
GradientTransformationExtraArgs
- Returns:
A GradientTransformation object.
Added in version 0.2.3.
Prodigy
- optax.contrib.prodigy(learning_rate=1.0, betas=(0.9, 0.999), beta3=None, eps=1e-08, estim_lr0=1e-06, estim_lr_coef=1.0, weight_decay=0.0)
Learning rate free AdamW with Prodigy.
Implementation of the Prodigy method from “Prodigy: An Expeditiously Adaptive Parameter-Free Learner”, a version of D-Adapt AdamW that adapts the baseline learning rate faster by using a weighting of the gradients that places higher weights on more recent gradients. This method works best when combined with a learning rate schedule that treats 1.0 as the base (usually max) value.
References
[Mishchenko & Defazio, 2023](https://arxiv.org/abs/2306.06101.pdf)
- Parameters:
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – Learning rate scheduling parameter. The recommended schedule is a linear_schedule with init_value=1.0 and end_value=0, combined with a 0-20% learning rate warmup.
betas (tuple[float, float]) – Betas for the underlying AdamW optimizer.
beta3 (Optional[float]) – Optional momentum parameter for estimation of D.
eps (float) – eps for the underlying AdamW optimizer.
estim_lr0 (float) – Initial (under-)estimate of the learning rate.
estim_lr_coef (float) – LR estimates are multiplied by this parameter.
weight_decay (float) – AdamW-style weight decay. To use regular Adam decay, chain with add_decayed_weights.
- Return type:
GradientTransformation
- Returns:
A GradientTransformation object.
- class optax.contrib.ProdigyState(exp_avg: base.Updates, exp_avg_sq: base.Updates, grad_sum: base.Updates, params0: base.Updates, estim_lr: chex.Array, numerator_weighted: chex.Array, count: chex.Array)
State of the GradientTransformation returned by prodigy.
- exp_avg: base.Updates
Alias for field number 0
- exp_avg_sq: base.Updates
Alias for field number 1
- grad_sum: base.Updates
Alias for field number 2
- params0: base.Updates
Alias for field number 3
- estim_lr: chex.Array
Alias for field number 4
- numerator_weighted: chex.Array
Alias for field number 5
- count: chex.Array
Alias for field number 6
Schedule-Free
- optax.contrib.schedule_free(base_optimizer, learning_rate, b1=0.9, weight_lr_power=2.0, state_dtype=<class 'jax.numpy.float32'>)
Turn base_optimizer schedule_free.
Accumulates updates returned by the base_optimizer (run without momentum) and replaces the momentum of the underlying optimizer with a combination of interpolation and averaging. In the case of gradient descent, the update is
\[\begin{align*} y_{t} &= (1-\beta_1)z_{t} + \beta_1 x_{t},\\ z_{t+1} &= z_{t}-\gamma\nabla f(y_{t}),\\ x_{t+1} &= \left(1-\frac{1}{t}\right)x_{t}+\frac{1}{t}z_{t+1}. \end{align*}\]
Here \(x\) is the sequence at which test/validation loss should be evaluated; it differs from the primary iterates \(z\) and the gradient evaluation locations \(y\). The updates to \(z\) correspond to the underlying optimizer, in this case a simple gradient step with learning rate \(\gamma\). Note that \(\beta_1\) corresponds to b1 in the code.
As the name suggests, Schedule-Free learning does not require a decreasing learning rate schedule, yet typically out-performs, or at worst matches, SOTA schedules such as cosine-decay and linear decay. Only two sequences need to be stored at a time (the third can be computed from the other two on the fly) so this method has the same memory requirements as the base optimizer (parameter buffer + momentum).
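A plain-JAX sketch of the three-line recursion for gradient descent, with a constant step size \(\gamma\) = `lr` and the averaging weight 1/t exactly as written (t starting at 1, with x_1 = z_1). The helper name is hypothetical. This is a simplified illustration, not the optax implementation: there is no warmup and no `weight_lr_power` weighting.

```python
import jax
import jax.numpy as jnp

def schedule_free_sgd_sketch(grad_fn, x1, lr=0.1, b1=0.9, num_steps=300):
    """Hypothetical helper: schedule-free SGD with a constant step size."""
    z = x1  # primary iterates (the plain SGD sequence)
    x = x1  # averaged iterates, used for evaluation
    for t in range(1, num_steps + 1):
        y = (1.0 - b1) * z + b1 * x               # gradient evaluation point
        z = z - lr * grad_fn(y)                   # SGD step on z, taken at y
        x = (1.0 - 1.0 / t) * x + (1.0 / t) * z   # uniform running average
    return x

f = lambda v: jnp.sum(v ** 2)
x_eval = schedule_free_sgd_sketch(jax.grad(f), jnp.array([1.0, 2.0, 3.0]))
```

Only z and x are carried between iterations; y is recomputed from them on the fly, which matches the memory claim above.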
In practice, the authors recommend tuning \(\beta_1\), warmup_steps and peak_lr for each problem separately. The default for \(\beta_1\) is 0.9, but 0.95 and 0.98 may also work well. Schedule-Free can be wrapped on top of any optax optimizer. At test time, the parameters should be evaluated using optax.contrib.schedule_free_eval_params(), as shown below.
For example, change this:
learning_rate_fn = optax.warmup_cosine_decay_schedule(peak_value=tuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=b1)
To:
learning_rate_fn = optax.warmup_constant_schedule(peak_value=retuned_lr)
optimizer = optax.adam(learning_rate_fn, b1=0.)
optimizer = optax.contrib.schedule_free(optimizer, learning_rate_fn, b1=b1)
...
params_for_eval = optax.contrib.schedule_free_eval_params(state, params)
Note especially that it is important to switch off momentum in the base optimizer. As of April 2024, schedule_free is tested with SGD and Adam.
References
Defazio et al, Schedule-Free Learning - A New Way to Train, 2024
- Parameters:
base_optimizer (GradientTransformation) – Base optimizer to compute updates from.
learning_rate (Union[float, Array, Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – Learning rate schedule without decay but with warmup.
b1 (float) – beta_1 parameter in the y update.
weight_lr_power (float) – Power used to downweight iterates in the averaging. This is especially helpful in early iterations during warmup.
state_dtype – dtype for the z sequence.
- Return type:
GradientTransformationExtraArgs
- Returns:
A GradientTransformationExtraArgs with init and update functions.
- optax.contrib.schedule_free_eval_params(state, params)
Params for evaluation of optax.contrib.schedule_free().
- class optax.contrib.ScheduleFreeState(b1: chex.Array, weight_sum: chex.Array, step_count: chex.Array, max_lr: chex.Array, base_optimizer_state: base.OptState, z: base.Params)
State for schedule_free.
- b1: chex.Array
Alias for field number 0
- weight_sum: chex.Array
Alias for field number 1
- step_count: chex.Array
Alias for field number 2
- max_lr: chex.Array
Alias for field number 3
- base_optimizer_state: base.OptState
Alias for field number 4
- z: base.Params
Alias for field number 5