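Optimizer Schedules
optax.constant_schedule | Constructs a constant schedule.
optax.cosine_decay_schedule | Returns a function which implements cosine learning rate decay.
optax.cosine_onecycle_schedule | Returns a function which implements the onecycle learning rate schedule.
optax.exponential_decay | Constructs a schedule with either continuous or discrete exponential decay.
optax.join_schedules | Sequentially apply multiple schedules.
optax.linear_onecycle_schedule | Returns a function which implements the onecycle learning rate schedule.
optax.linear_schedule | Schedule with linear transition from init_value to end_value.
optax.piecewise_constant_schedule | Returns a function which implements a piecewise constant schedule.
optax.piecewise_interpolate_schedule | Returns a function which implements a piecewise interpolated schedule.
optax.polynomial_schedule | Constructs a schedule with polynomial transition from init to end value.
optax.sgdr_schedule | SGD with warm restarts.
optax.warmup_cosine_decay_schedule | Linear warmup followed by cosine decay.
optax.warmup_exponential_decay_schedule | Linear warmup followed by exponential decay.
optax.Schedule | alias of Callable (from a numeric step count to a numeric value)
optax.InjectHyperparamsState | Deprecated class kept for backwards compatibility.
optax.inject_hyperparams | Wrapper that injects stateful hyperparameters into GradientTransformations.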
- optax.Schedule
alias of Callable[[Array | ndarray | bool_ | number | float | int], Array | ndarray | bool_ | number | float | int]
Constant schedule
Cosine decay schedule
- optax.cosine_decay_schedule(init_value, decay_steps, alpha=0.0, exponent=1.0)
Returns a function which implements cosine learning rate decay.
This schedule smoothly decreases the learning rate over a specified number of steps (decay_steps). The decay follows a cosine function, with an optional exponent to modify the decay curve. A nonzero minimum multiplier (alpha) keeps the learning rate from dropping entirely to zero.
More precisely, the learning rate at iteration \(t\) is given by:
\[I \left( (1 - \alpha) \left( \frac{1 + \cos(\pi\,\frac{t}{T})}{2} \right)^p + \alpha \right)\,,\]
where \(T\) is the number of decay steps (decay_steps), \(p\) is the exponent, \(\alpha\) is the minimum multiplier (alpha), and \(I\) is the initial value (init_value).
References
Loshchilov et al., SGDR: Stochastic Gradient Descent with Warm Restarts, 2017
- Parameters:
init_value (float): An initial value for the learning rate.
decay_steps (int): Positive integer, the number of steps over which to apply the decay.
alpha (float): The minimum value of the multiplier used to adjust the learning rate. Defaults to 0.0.
exponent (float): The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is the decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
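As an illustration of the decay computation, here is a minimal plain-Python sketch that follows the parameter descriptions (alpha as the floor of the multiplier, exponent applied to the whole cosine factor). The name `cosine_decay_sketch` is hypothetical, and clamping the step at decay_steps is an assumption; this is not the library's implementation.

```python
import math


def cosine_decay_sketch(init_value, decay_steps, alpha=0.0, exponent=1.0):
    """Plain-Python illustration of cosine decay (hypothetical helper)."""

    def schedule(step):
        # Assumption: steps past decay_steps are clamped, so the value
        # stays at its floor once the decay is complete.
        t = min(step, decay_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
        return init_value * ((1.0 - alpha) * cosine ** exponent + alpha)

    return schedule


schedule = cosine_decay_sketch(init_value=1.0, decay_steps=100, alpha=0.1)
start, middle, end = schedule(0), schedule(50), schedule(100)
```

With these arguments the value starts at init_value, passes through the midpoint between init_value and the floor, and ends at alpha * init_value.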
- optax.cosine_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, div_factor=25.0, final_div_factor=10000.0)
Returns a function which implements the onecycle learning rate schedule.
This schedule increases the learning rate and then decreases it in a cosine-like manner. The number of steps over which the learning rate increases is determined by the pct_start argument. The maximum value of the learning rate is determined by the peak_value argument, the initial value is determined through the formula init_value = peak_value / div_factor, and the final value is determined by the final_div_factor argument.
References
Smith et al., Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates, 2017
- Parameters:
transition_steps (int): Number of steps over which annealing takes place.
peak_value (float): Maximum value attained by the schedule, reached after a fraction pct_start of the cycle (in number of steps).
pct_start (float): The fraction of the cycle (in number of steps) spent increasing the learning rate.
div_factor (float): Determines the initial value via init_value = peak_value / div_factor.
final_div_factor (float): Determines the final value via final_value = init_value / final_div_factor.
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
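The relationship between the arguments can be worked through in a few lines. The helper `onecycle_values` below is hypothetical and only reproduces the two formulas given in the parameter list; taking the integer part of pct_start * transition_steps as the peak step is an assumption.

```python
def onecycle_values(transition_steps, peak_value, pct_start=0.3,
                    div_factor=25.0, final_div_factor=1e4):
    """Derive the key values of a onecycle schedule from its arguments."""
    init_value = peak_value / div_factor
    final_value = init_value / final_div_factor
    # Assumption: the peak is reached at the integer part of this product.
    peak_step = int(pct_start * transition_steps)
    return init_value, final_value, peak_step


init_value, final_value, peak_step = onecycle_values(
    transition_steps=1000, peak_value=0.5)
```

For peak_value=0.5 and the default factors, the schedule would start at 0.02, peak at step 300, and finish four orders of magnitude below the initial value.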
Exponential decay schedule
- optax.exponential_decay(init_value, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None)
Constructs a schedule with either continuous or discrete exponential decay.
This function applies an exponential decay function to a provided initial value. When count >= transition_begin the function returns the decayed value as:
rate_factor = ((count - transition_begin) / transition_steps)
decayed_value = init_value * (decay_rate ** rate_factor)
If the argument staircase is True, then (count - transition_begin) / transition_steps is an integer division and the decayed value follows a staircase function.
- Parameters:
init_value (float): the initial learning rate.
transition_steps (int): must be positive. See the decay computation above.
decay_rate (float): must not be zero. The decay rate.
transition_begin (int): must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).
staircase (bool): if True, decay the values at discrete intervals.
end_value (Optional[float]): the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
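The decay computation and the staircase option can be sketched in plain Python as follows. `exponential_decay_sketch` is a hypothetical helper written from the description above, not the library's code; in particular, applying end_value as a simple clip is an assumption based on the parameter description.

```python
import math


def exponential_decay_sketch(init_value, transition_steps, decay_rate,
                             transition_begin=0, staircase=False,
                             end_value=None):
    """Plain-Python illustration of continuous and staircase decay."""

    def schedule(count):
        if count < transition_begin:
            return init_value  # held fixed before the transition begins
        rate_factor = (count - transition_begin) / transition_steps
        if staircase:
            rate_factor = math.floor(rate_factor)  # integer division
        decayed = init_value * decay_rate ** rate_factor
        if end_value is not None:
            # Lower bound when decaying, upper bound when growing.
            if decay_rate < 1:
                decayed = max(decayed, end_value)
            else:
                decayed = min(decayed, end_value)
        return decayed

    return schedule


smooth = exponential_decay_sketch(1.0, transition_steps=10, decay_rate=0.5)
stairs = exponential_decay_sketch(1.0, transition_steps=10, decay_rate=0.5,
                                  staircase=True)
```

The continuous variant halves the value smoothly over every 10 steps, while the staircase variant holds each value for 10 steps and then drops.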
Join schedules
- optax.join_schedules(schedules, boundaries)
Sequentially apply multiple schedules.
- Parameters:
schedules (Sequence[Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]): A list of callables (expected to be optax schedules). Each schedule will receive a step count indicating the number of steps since the previous boundary transition.
boundaries (Sequence[int]): A list of integers (of length one less than schedules) that indicate when to transition between schedules.
- Returns:
A function that maps step counts to values.
- Return type:
schedule
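To make the step-count convention concrete, the sketch below joins plain Python callables and passes each one the number of steps since the previous boundary. `join_schedules_sketch` and the two lambdas are hypothetical illustrations, not the library's implementation.

```python
def join_schedules_sketch(schedules, boundaries):
    """Dispatch to the active schedule with a boundary-relative step count."""

    def schedule(step):
        start = 0
        for fn, boundary in zip(schedules, boundaries):
            if step < boundary:
                return fn(step - start)
            start = boundary
        # Past the last boundary: the final schedule stays active.
        return schedules[-1](step - start)

    return schedule


warmup = lambda s: s / 10            # ramps from 0.0 to 1.0 over 10 steps
decay = lambda s: 1.0 - 0.01 * s     # starts again from its own step 0
joined = join_schedules_sketch([warmup, decay], boundaries=[10])
```

Here `joined(10)` evaluates `decay(0)`, because the second schedule counts steps from the boundary rather than from the start of training.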
Inject hyperparameters
- optax.inject_hyperparams(inner_factory, static_args=(), hyperparam_dtype=None)
Wrapper that injects stateful hyperparameters into GradientTransformations.
This wrapper allows you to pass schedules (i.e. functions that return a numeric value given a step count) instead of constants for hyperparameters. You may only schedule numeric hyperparameters (i.e. boolean flags cannot be scheduled).
This function supports both simple schedules that are functions exclusively of the step count and stateful schedules that rely on a complex internal state. The state updating can rely on additional information fed to gradient transformations via extra_args.
For example, to use optax.scale_by_adam() with a piecewise linear schedule for beta_1 and a constant for beta_2:
>>> import optax
>>> import jax.numpy as jnp
>>> # create a piecewise linear schedule with initial value 0.1; at step 100
... # the value is init_value multiplied by the scale 1.
>>> linear_schedule = optax.piecewise_interpolate_schedule(
...     'linear', init_value=0.1, boundaries_and_scales={100: 1.})
>>> scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)(
...     b1=linear_schedule, b2=0.99)
You may manually change numeric hyperparameters that were not scheduled through the hyperparams dict in the InjectHyperparamState:
>>> params, grads = jnp.array(0.), jnp.array(0.)
>>> state = scheduled_adam.init(params)
>>> updates, state = scheduled_adam.update(grads, state)
>>> state.hyperparams['b2'] = 0.95
>>> updates, state = scheduled_adam.update(updates, state)  # uses b2 = 0.95
Manually overriding scheduled hyperparameters will have no effect (e.g. in the code sample above, you cannot manually adjust b1).
- Parameters:
inner_factory (Callable[..., GradientTransformation]): a function that returns the inner optax.GradientTransformation with dynamic hyperparameters.
static_args (Union[str, Iterable[str]]): a string or iterable of strings specifying which callable parameters are not schedules. inject_hyperparams treats all callables as schedules by default, so if a hyperparameter is a non-schedule callable, you must specify that using this argument.
hyperparam_dtype (Optional[dtype]): Optional datatype override. If specified, all float hyperparameters will be cast to this type.
- Return type:
Callable[..., GradientTransformationExtraArgs]
- Returns:
A callable that returns an optax.GradientTransformationExtraArgs. This callable accepts the same arguments as inner_factory, except you may provide schedules in place of the constant arguments.
Changed in version 0.1.9: New parameter hyperparam_dtype; the returned callable outputs a GradientTransformationExtraArgs instead of a GradientTransformation.
- class optax.InjectHyperparamsState(count: jnp.ndarray, hyperparams: dict[str, chex.Numeric], inner_state: base.OptState)
Deprecated class kept for backwards compatibility.
Deprecated since version 0.1.9: Use InjectStatefulHyperparamsState instead.
- count: jnp.ndarray
Alias for field number 0
- hyperparams: dict[str, chex.Numeric]
Alias for field number 1
- inner_state: base.OptState
Alias for field number 2
Linear schedules
- optax.linear_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, pct_final=0.85, div_factor=25.0, final_div_factor=10000.0)
Returns a function which implements the onecycle learning rate schedule.
This function uses a linear annealing strategy.
References
Smith et al., Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates, 2017
- Parameters:
transition_steps (int): Number of steps over which annealing takes place.
peak_value (float): Maximum value attained by the schedule, reached after a fraction pct_start of the cycle (in number of steps).
pct_start (float): The fraction of the cycle (in number of steps) spent increasing the learning rate.
pct_final (float): The fraction of the cycle (in number of steps) spent increasing to peak_value then decreasing back to init_value.
div_factor (float): Determines the initial value via init_value = peak_value / div_factor.
final_div_factor (float): Determines the final value via final_value = init_value / final_div_factor.
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
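Read together, the parameters describe three linear segments: up to peak_value, back down to init_value, then down to final_value. The sketch below is one plain-Python reading of that description; `linear_onecycle_sketch` is a hypothetical helper, and the integer truncation of the breakpoints is an assumption.

```python
def linear_onecycle_sketch(transition_steps, peak_value, pct_start=0.3,
                           pct_final=0.85, div_factor=25.0,
                           final_div_factor=1e4):
    """Three linear segments: warm up, anneal back, then anneal to the end."""
    init_value = peak_value / div_factor
    final_value = init_value / final_div_factor
    # (step, value) breakpoints; values in between are interpolated linearly.
    points = [
        (0, init_value),
        (int(pct_start * transition_steps), peak_value),
        (int(pct_final * transition_steps), init_value),
        (transition_steps, final_value),
    ]

    def schedule(step):
        step = min(max(step, 0), transition_steps)
        for (s0, v0), (s1, v1) in zip(points, points[1:]):
            if step <= s1:
                return v0 + (v1 - v0) * (step - s0) / (s1 - s0)
        return final_value

    return schedule


schedule = linear_onecycle_sketch(transition_steps=100, peak_value=1.0)
```

The sketch assumes the breakpoints are distinct, which holds whenever transition_steps is large enough for the three phases to each cover at least one step.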
- optax.linear_schedule(init_value, end_value, transition_steps, transition_begin=0)
Schedule with linear transition from init_value to end_value.
More precisely, the learning rate at iteration \(t\) is given by:
\[\begin{cases} I, & \text{if } t < B \\ I + \frac{t - B}{T} (E - I), & \text{if } B \leq t < B + T \\ E, & \text{if } t \geq B + T \end{cases}\]
where \(I\) is the initial value, \(E\) is the end value, \(B\) is the transition begin, and \(T\) is the transition steps.
This schedule is equivalent to optax.polynomial_schedule() with power=1.
Examples
>>> schedule_fn = optax.linear_schedule(
...     init_value=1.0, end_value=0.01, transition_steps=100)
>>> schedule_fn(0)  # learning rate on the first iteration
Array(1., dtype=float32, weak_type=True)
>>> schedule_fn(100)  # learning rate on the last iteration
Array(0.01, dtype=float32, weak_type=True)
- Parameters:
init_value (Union[float, int]): initial value for the scalar to be annealed.
end_value (Union[float, int]): end value of the scalar to be annealed.
transition_steps (int): number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.
transition_begin (int): must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
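The three cases of the piecewise definition translate directly into plain Python. The helper name `linear_schedule_sketch` is hypothetical, and the code is only an illustration of the formula, returning Python floats rather than arrays.

```python
def linear_schedule_sketch(init_value, end_value, transition_steps,
                           transition_begin=0):
    """Plain-Python illustration of the piecewise linear formula."""

    def schedule(t):
        # Annealing is disabled for non-positive transition_steps,
        # and the value is held at init_value before transition_begin.
        if transition_steps <= 0 or t < transition_begin:
            return init_value
        if t >= transition_begin + transition_steps:
            return end_value
        frac = (t - transition_begin) / transition_steps
        return init_value + frac * (end_value - init_value)

    return schedule


schedule = linear_schedule_sketch(1.0, 0.01, transition_steps=100)
```

Halfway through the transition the value is the midpoint of init_value and end_value, and any step at or past the end returns end_value.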
Piecewise schedules
- optax.piecewise_constant_schedule(init_value, boundaries_and_scales=None)
Returns a function which implements a piecewise constant schedule.
- Parameters:
init_value (float): An initial value init_v.
boundaries_and_scales (Optional[dict[int, float]]): A map from boundaries b_i to non-negative scaling factors f_i. For any step count s, the schedule returns init_v scaled by the product of all factors f_i such that b_i < s.
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
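The rule "scale by the product of all factors f_i such that b_i < s" can be written out as a short plain-Python sketch. `piecewise_constant_sketch` is a hypothetical name used only for this illustration.

```python
def piecewise_constant_sketch(init_value, boundaries_and_scales=None):
    """Scale init_value by every factor whose boundary is strictly passed."""
    scales = dict(boundaries_and_scales or {})

    def schedule(step):
        value = init_value
        for boundary, factor in scales.items():
            if boundary < step:  # strict: the factor applies after b_i
                value *= factor
        return value

    return schedule


schedule = piecewise_constant_sketch(1.0, {100: 0.1, 200: 0.5})
```

Because the comparison is strict, the value at step 100 is still init_value, and the first factor takes effect from step 101.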
- optax.piecewise_interpolate_schedule(interpolate_type, init_value, boundaries_and_scales=None)
Returns a function which implements a piecewise interpolated schedule.
- Parameters:
interpolate_type (str): 'linear' or 'cosine', specifying the interpolation strategy.
init_value (float): An initial value init_v.
boundaries_and_scales (Optional[dict[int, float]]): A map from boundaries b_i to non-negative scaling factors f_i. At boundary step b_i, the schedule returns init_v scaled by the product of all factors f_j such that b_j <= b_i. The values in between each boundary will be interpolated as per interpolate_type.
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
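For the 'linear' case, the boundary values and the interpolation between them can be sketched in plain Python. `piecewise_linear_sketch` is a hypothetical helper; it assumes strictly positive boundaries, non-negative steps, and that the first segment starts at step 0 with init_v. The 'cosine' case is not covered.

```python
def piecewise_linear_sketch(init_value, boundaries_and_scales=None):
    """Linear interpolation between cumulatively scaled boundary values."""
    # Build (boundary, value) knots: each value is the previous one
    # multiplied by the factor attached to that boundary.
    knots = [(0, init_value)]
    for boundary, factor in sorted((boundaries_and_scales or {}).items()):
        knots.append((boundary, knots[-1][1] * factor))

    def schedule(step):
        for (b0, v0), (b1, v1) in zip(knots, knots[1:]):
            if step <= b1:
                return v0 + (v1 - v0) * (step - b0) / (b1 - b0)
        return knots[-1][1]  # constant after the last boundary

    return schedule


schedule = piecewise_linear_sketch(0.1, {100: 10.0, 200: 0.5})
```

In this example the value rises from 0.1 to 1.0 over the first 100 steps, then falls to 0.5 by step 200 and stays there.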
Polynomial schedules
- optax.polynomial_schedule(init_value, end_value, power, transition_steps, transition_begin=0)
Constructs a schedule with polynomial transition from init to end value.
- Parameters:
init_value (Union[float, int]): initial value for the scalar to be annealed.
end_value (Union[float, int]): end value of the scalar to be annealed.
power (Union[float, int]): the power of the polynomial used to transition from init to end.
transition_steps (int): number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.
transition_begin (int): must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
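The section does not spell out the polynomial formula, so the sketch below uses one common form, chosen so that power=1 reduces to the linear schedule: the remaining fraction of the transition is raised to the given power. The helper name `polynomial_schedule_sketch` and the exact formula are assumptions for illustration.

```python
def polynomial_schedule_sketch(init_value, end_value, power,
                               transition_steps, transition_begin=0):
    """Polynomial transition from init_value to end_value (assumed form)."""

    def schedule(count):
        if transition_steps <= 0:
            return init_value  # annealing disabled
        # Steps elapsed inside the transition window, clamped to its length.
        elapsed = min(max(count - transition_begin, 0), transition_steps)
        remaining = 1.0 - elapsed / transition_steps
        return (init_value - end_value) * remaining ** power + end_value

    return schedule


quadratic = polynomial_schedule_sketch(1.0, 0.0, power=2, transition_steps=100)
linear = polynomial_schedule_sketch(1.0, 0.0, power=1, transition_steps=100)
```

With power=2 the value drops faster early on: halfway through the transition it is 0.25, compared with 0.5 for power=1.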
Reduce on plateau
- optax.contrib.reduce_on_plateau(factor=0.1, patience=10, rtol=0.0001, atol=0.0, cooldown=0, accumulation_size=1)
Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning rate once learning stagnates. This scheduler reads a metric and, if no improvement is seen for patience iterations, reduces the learning rate.
- Parameters:
factor (float): Factor by which to reduce the learning rate. new_scale = scale * factor.
patience (int): Number of iterations with no improvement after which the learning rate will be reduced.
rtol (float): Relative tolerance for measuring a new optimum.
atol (float): Absolute tolerance for measuring a new optimum.
cooldown (int): Number of iterations to wait before resuming normal operation after the scale has been reduced.
accumulation_size (int): Number of values to aggregate before applying the logic of reduce on plateau. If the value fed to the optimizer is a test value, simply take 1 (default). If the value fed to the optimizer is the loss on the current minibatch, consider using a larger accumulation size.
- Return type:
GradientTransformationExtraArgs
- Returns:
A GradientTransformationExtraArgs object.
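The interplay of patience, factor, and cooldown can be illustrated with a simplified plain-Python loop. Everything in `reduce_on_plateau_sketch` is an assumption made for illustration: the improvement test (beating the best value by the relative and absolute tolerances), the reset rules, and the omission of accumulation_size. It is not the library's transformation.

```python
def reduce_on_plateau_sketch(values, factor=0.1, patience=10,
                             rtol=1e-4, atol=0.0, cooldown=0):
    """Return the scale in effect after each observed metric value."""
    scale, best, plateau, cooling = 1.0, float("inf"), 0, 0
    scales = []
    for value in values:
        # Assumed improvement test: beat the best value by the tolerances.
        improved = value < (1.0 - rtol) * best - atol
        if improved:
            best = value
        if cooling > 0:
            cooling -= 1   # wait out the cooldown without counting plateaus
            plateau = 0
        elif improved:
            plateau = 0
        else:
            plateau += 1
            if plateau >= patience:
                scale *= factor  # new_scale = scale * factor
                plateau, cooling = 0, cooldown
        scales.append(scale)
    return scales


scales = reduce_on_plateau_sketch([1.0, 0.5, 0.5, 0.5, 0.5, 0.4], patience=3)
```

In this run the metric stalls at 0.5 for three consecutive observations, so the scale is multiplied by factor on the third stalled observation and stays reduced afterwards.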
Schedules with warm-up
- optax.warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value=0.0, exponent=1.0)
Linear warmup followed by cosine decay.
- Parameters:
init_value (float): Initial value for the scalar to be annealed.
peak_value (float): Peak value for the scalar to be annealed at the end of warmup.
warmup_steps (int): Positive integer, the length of the linear warmup.
decay_steps (int): Positive integer, the total length of the schedule. Note that this includes the warmup time, so the number of steps during which cosine annealing is applied is decay_steps - warmup_steps.
end_value (float): End value of the scalar to be annealed.
exponent (float): The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
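Since decay_steps includes the warmup, the cosine phase covers decay_steps - warmup_steps steps. The plain-Python sketch below makes that bookkeeping explicit; `warmup_cosine_sketch` is a hypothetical helper, and measuring the cosine phase from the end of the warmup is an assumption drawn from the decay_steps description.

```python
import math


def warmup_cosine_sketch(init_value, peak_value, warmup_steps, decay_steps,
                         end_value=0.0, exponent=1.0):
    """Linear warmup to peak_value, then cosine decay to end_value."""
    cosine_steps = decay_steps - warmup_steps  # decay_steps includes warmup

    def schedule(step):
        if step < warmup_steps:
            # Linear warmup from init_value to peak_value.
            return init_value + (peak_value - init_value) * step / warmup_steps
        # Cosine phase, counted from the end of the warmup and clamped.
        t = min(step - warmup_steps, cosine_steps)
        cosine = (0.5 * (1.0 + math.cos(math.pi * t / cosine_steps))) ** exponent
        return end_value + (peak_value - end_value) * cosine

    return schedule


schedule = warmup_cosine_sketch(0.0, 1.0, warmup_steps=10, decay_steps=110)
```

With these arguments the peak is reached at step 10, the cosine phase lasts 100 steps, and the value settles at end_value from step 110 onward.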
- optax.warmup_exponential_decay_schedule(init_value, peak_value, warmup_steps, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None)
Linear warmup followed by exponential decay.
- Parameters:
init_value (float): Initial value for the scalar to be annealed.
peak_value (float): Peak value for the scalar to be annealed at the end of warmup.
warmup_steps (int): Positive integer, the length of the linear warmup.
transition_steps (int): must be positive. See optax.exponential_decay() for more details.
decay_rate (float): must not be zero. The decay rate.
transition_begin (int): must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at peak_value).
staircase (bool): if True, decay the values at discrete intervals.
end_value (Optional[float]): the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
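A plain-Python sketch of the two phases follows. `warmup_exponential_sketch` is a hypothetical helper covering only the continuous case without end_value; counting the decay phase from the end of the warmup is an assumption.

```python
def warmup_exponential_sketch(init_value, peak_value, warmup_steps,
                              transition_steps, decay_rate,
                              transition_begin=0):
    """Linear warmup to peak_value, then continuous exponential decay."""

    def schedule(step):
        if step < warmup_steps:
            return init_value + (peak_value - init_value) * step / warmup_steps
        # Assumption: the decay phase counts steps from the end of the warmup.
        count = step - warmup_steps
        if count < transition_begin:
            return peak_value  # held at peak_value until the decay begins
        rate_factor = (count - transition_begin) / transition_steps
        return peak_value * decay_rate ** rate_factor

    return schedule


schedule = warmup_exponential_sketch(
    0.0, 1.0, warmup_steps=10, transition_steps=20, decay_rate=0.5)
```

Here the value climbs to 1.0 over 10 steps and then halves every 20 steps.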
Warm restarts
- optax.sgdr_schedule(cosine_kwargs)
SGD with warm restarts.
This learning rate schedule applies multiple joined cosine decay cycles.
References
Loshchilov et al., SGDR: Stochastic Gradient Descent with Warm Restarts, 2017
- Parameters:
cosine_kwargs (Iterable[dict[str, Union[Array, ndarray, bool_, number, float, int]]]): An Iterable of dicts, where each element specifies the arguments to pass to each cosine decay cycle. The decay_steps kwarg will specify how long each cycle lasts for, and therefore when to transition to the next cycle.
- Return type:
Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]
- Returns:
schedule: A function that maps step counts to values.
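The way decay_steps determines the restart points can be shown with a small plain-Python sketch. The helpers `restart_boundaries` and `cycle_for_step` are hypothetical and only track which cycle is active; they rely on nothing but the decay_steps key described above.

```python
from itertools import accumulate


def restart_boundaries(cosine_kwargs):
    """Steps at which a new cycle starts: cumulative decay_steps, minus the last."""
    ends = list(accumulate(kw["decay_steps"] for kw in cosine_kwargs))
    return ends[:-1]


def cycle_for_step(step, boundaries):
    """Index of the active cycle and the step count local to that cycle."""
    index = sum(step >= b for b in boundaries)
    start = boundaries[index - 1] if index > 0 else 0
    return index, step - start


cycles = [{"decay_steps": 100}, {"decay_steps": 200}, {"decay_steps": 400}]
boundaries = restart_boundaries(cycles)
```

With cycle lengths 100, 200, and 400, restarts happen at steps 100 and 300, and each cycle sees a step count that begins again at zero.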