Optimizer Schedules
| constant_schedule | Constructs a constant schedule. |
| cosine_decay_schedule | Returns a function which implements cosine learning rate decay. |
| cosine_onecycle_schedule | Returns a function which implements the onecycle learning rate schedule. |
| exponential_decay | Constructs a schedule with either continuous or discrete exponential decay. |
| join_schedules | Sequentially apply multiple schedules. |
| linear_onecycle_schedule | Returns a learning rate with three linear phases. |
| linear_schedule | Schedule with linear transition from init_value to end_value. |
| piecewise_constant_schedule | Piecewise constant schedule with scaled jumps at specific boundaries. |
| piecewise_interpolate_schedule | Piecewise interpolated schedule with linear or cosine transitions. |
| polynomial_schedule | Constructs a schedule with polynomial transition from init to end value. |
| sgdr_schedule | SGD with warm restarts. |
| warmup_constant_schedule | Linear warmup followed by constant schedule, i.e. no decay. |
| warmup_cosine_decay_schedule | Linear warmup followed by cosine decay. |
| warmup_exponential_decay_schedule | Linear warmup followed by exponential decay. |
| | Deprecated class kept for backwards compatibility. |
| inject_hyperparams | Wrapper that injects stateful hyperparameters into GradientTransformations. |
- optax.schedules.Schedule
alias of
Callable[[Array|ndarray|bool|number|bool|int|float|complex|TypedNdArray],Array|ndarray|bool|number|bool|int|float|complex|TypedNdArray]
Constant schedule
- optax.schedules.constant_schedule(value: jax.typing.ArrayLike) -> base.Schedule
Constructs a constant schedule.
- Parameters:
value – value to be held constant throughout.
- Returns:
- schedule
A function that maps step counts to values.
Examples
>>> schedule_fn = optax.constant_schedule(5)
>>> schedule_fn(0)
5
>>> schedule_fn(100)
5
- optax.schedules.warmup_constant_schedule(init_value: jax.typing.ArrayLike, peak_value: jax.typing.ArrayLike, warmup_steps: int) -> base.Schedule
Linear warmup followed by a constant schedule, i.e. no decay.
- Parameters:
init_value – Initial value for the scalar to be annealed.
peak_value – Peak value for scalar to be annealed at end of warmup.
warmup_steps – Positive integer, the length of the linear warmup.
- Returns:
- schedule
A function that maps step counts to values.
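The warmup-then-hold behaviour described above can be illustrated with a small pure-Python sketch. It is not optax's implementation: the helper name `warmup_constant` is hypothetical, and it assumes the warmup interpolates linearly from `init_value` to `peak_value` over `warmup_steps` steps and then holds `peak_value`.

```python
def warmup_constant(init_value, peak_value, warmup_steps):
    """Hypothetical helper: linear warmup to peak_value, then hold it."""

    def schedule(count):
        if count >= warmup_steps:
            # After the warmup the value no longer changes.
            return peak_value
        # Fraction of the warmup that has elapsed, in [0, 1).
        frac = count / warmup_steps
        return init_value + frac * (peak_value - init_value)

    return schedule


schedule_fn = warmup_constant(init_value=0.0, peak_value=1.0, warmup_steps=10)
values = [schedule_fn(step) for step in (0, 5, 10, 100)]
print(values)  # [0.0, 0.5, 1.0, 1.0]
```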
Cosine decay schedule
- optax.schedules.cosine_decay_schedule(init_value: jax.typing.ArrayLike, decay_steps: int, alpha: jax.typing.ArrayLike = 0.0, exponent: jax.typing.ArrayLike = 1.0) -> base.Schedule
Returns a function which implements cosine learning rate decay.
This schedule smoothly decreases the learning rate over a specified number of steps (decay_steps). The decay follows a cosine function, with an optional exponent to modify the decay curve. A nonzero minimum value (alpha) keeps the learning rate from dropping all the way to zero.
More precisely, the learning rate at iteration \(t\) is given by:
\[\begin{cases} \alpha I + (1 - \alpha) I \left[ \frac{1}{2} \left( 1+ \cos \left( \pi\,\frac{t}{T} \right) \right) \right] ^p & \text{, if } t \leq T \\ \alpha I & \text{, if } t > T \end{cases} \]
where \(T\) is the number of decay steps (decay_steps), \(p\) is the exponent and \(I\) is the initial value (init_value).
- Parameters:
init_value – An initial value for the learning rate.
decay_steps – Positive integer, the number of steps over which to apply the decay.
alpha – The minimum value of the multiplier used to adjust the learning rate. Defaults to 0.0.
exponent – The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is the decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.
- Returns:
- schedule
A function that maps step counts to values.
References
Loshchilov et al., SGDR: Stochastic Gradient Descent with Warm Restarts, 2017
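As a sanity check on the formula above, here is a minimal pure-Python sketch that evaluates it directly. The helper name `cosine_decay` is hypothetical and this is not optax's implementation (which operates on JAX arrays); it only mirrors the piecewise definition with \(I\) = `init_value`, \(T\) = `decay_steps`, \(\alpha\) = `alpha` and \(p\) = `exponent`.

```python
import math


def cosine_decay(init_value, decay_steps, alpha=0.0, exponent=1.0):
    """Hypothetical helper evaluating the cosine decay formula."""

    def schedule(count):
        # Past T = decay_steps the value stays at alpha * init_value.
        t = min(count, decay_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
        return init_value * (alpha + (1.0 - alpha) * cosine**exponent)

    return schedule


schedule_fn = cosine_decay(init_value=1.0, decay_steps=100, alpha=0.1)
values = [round(schedule_fn(step), 4) for step in (0, 50, 100, 200)]
print(values)  # [1.0, 0.55, 0.1, 0.1]
```

With `alpha=0.1` the schedule starts at `init_value`, passes through the midpoint between `init_value` and `alpha * init_value` halfway through, and stays at `alpha * init_value` once `decay_steps` is reached.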
- optax.schedules.cosine_onecycle_schedule(transition_steps: int, peak_value: jax.typing.ArrayLike, pct_start: float = 0.3, div_factor: float = 25.0, final_div_factor: float = 10000.0) -> base.Schedule
Returns a function which implements the onecycle learning rate schedule.
This schedule increases the learning rate and then decreases it in a cosine-like manner. The number of steps over which the learning rate increases is determined by the pct_start argument. The maximum value of the learning rate is determined by the peak_value argument, the initial value of the learning rate is determined through the formula init_value = peak_value / div_factor, and the final value is determined by the final_div_factor argument.
- Parameters:
transition_steps – Number of steps over which annealing takes place.
peak_value – Maximum value attained by schedule at pct_start percent of the cycle (in number of steps).
pct_start – The percentage of the cycle (in number of steps) spent increasing the learning rate.
div_factor – Determines the initial value via init_value = peak_value / div_factor.
final_div_factor – Determines the final value via final_value = init_value / final_div_factor.
- Returns:
- schedule
A function that maps step counts to values
References
Smith et al., Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates, 2017
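The two division factors are easy to mix up, so the following sketch simply evaluates the documented relations `init_value = peak_value / div_factor` and `final_value = init_value / final_div_factor`. The helper name `onecycle_endpoints` is hypothetical; it computes only the endpoints of the cycle, not the schedule itself.

```python
def onecycle_endpoints(peak_value, div_factor=25.0, final_div_factor=1e4):
    """Hypothetical helper: start and end values of a onecycle schedule."""
    # The cycle starts div_factor times below the peak ...
    init_value = peak_value / div_factor
    # ... and ends final_div_factor times below the starting value.
    final_value = init_value / final_div_factor
    return init_value, final_value


init_value, final_value = onecycle_endpoints(peak_value=1.0)
print(init_value, final_value)
```

With the default factors, a peak of 1.0 therefore starts at 1/25 of the peak and finishes four further orders of magnitude below that starting value.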
- optax.schedules.warmup_cosine_decay_schedule(init_value: jax.typing.ArrayLike, peak_value: jax.typing.ArrayLike, warmup_steps: int, decay_steps: int, end_value: jax.typing.ArrayLike = 0.0, exponent: jax.typing.ArrayLike = 1.0) -> base.Schedule
Linear warmup followed by cosine decay.
- Parameters:
init_value – Initial value for the scalar to be annealed.
peak_value – Peak value for scalar to be annealed at end of warmup.
warmup_steps – Positive integer, the length of the linear warmup.
decay_steps – Positive integer, the total length of the schedule. Note that this includes the warmup time, so the number of steps during which cosine annealing is applied is decay_steps - warmup_steps.
end_value – End value of the scalar to be annealed.
exponent – The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.
- Returns:
- schedule
A function that maps step counts to values
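Because `decay_steps` counts the warmup as well, the cosine phase lasts only `decay_steps - warmup_steps` steps. The pure-Python sketch below makes that bookkeeping explicit. The helper name `warmup_cosine_decay` is hypothetical and this is an illustration of the description above under the assumption that the cosine phase interpolates from `peak_value` to `end_value`; it is not optax's code.

```python
import math


def warmup_cosine_decay(init_value, peak_value, warmup_steps, decay_steps,
                        end_value=0.0, exponent=1.0):
    """Hypothetical helper: linear warmup followed by cosine decay."""
    # decay_steps includes the warmup, so this is the cosine phase length.
    cosine_steps = decay_steps - warmup_steps

    def schedule(count):
        if count < warmup_steps:
            return init_value + (peak_value - init_value) * count / warmup_steps
        # Steps elapsed in the cosine phase, clipped at its end.
        t = min(count - warmup_steps, cosine_steps)
        cosine = (0.5 * (1.0 + math.cos(math.pi * t / cosine_steps))) ** exponent
        return end_value + (peak_value - end_value) * cosine

    return schedule


schedule_fn = warmup_cosine_decay(
    init_value=0.0, peak_value=1.0, warmup_steps=10, decay_steps=110,
    end_value=0.1)
values = [round(schedule_fn(step), 4) for step in (0, 10, 60, 110, 500)]
print(values)  # [0.0, 1.0, 0.55, 0.1, 0.1]
```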
Exponential decay schedule
- optax.schedules.exponential_decay(init_value: jax.typing.ArrayLike, transition_steps: int, decay_rate: float, transition_begin: int = 0, staircase: bool = False, end_value: jax.typing.ArrayLike | None = None) -> base.Schedule
Constructs a schedule with either continuous or discrete exponential decay.
This function applies an exponential decay function to a provided initial value. When count >= transition_begin the function returns the decayed value as:
rate_factor = ((count - transition_begin) / transition_steps)
decayed_value = init_value * (decay_rate ** rate_factor)
If the argument staircase is True then the division by transition_steps is an integer division and the decayed value follows a staircase function.
- Parameters:
init_value – the initial learning rate.
transition_steps – must be positive. See the decay computation above.
decay_rate – must not be zero. The decay rate.
transition_begin – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).
staircase – if True, decay the values at discrete intervals.
end_value – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.
- Returns:
- schedule
A function that maps step counts to values.
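The decay computation, the staircase variant and the `end_value` bound can be put side by side in a short pure-Python sketch. The helper below is hypothetical and reuses the name `exponential_decay` only for readability; it follows the formulas stated above and is not optax's implementation.

```python
def exponential_decay(init_value, transition_steps, decay_rate,
                      transition_begin=0, staircase=False, end_value=None):
    """Hypothetical helper following the documented decay formulas."""

    def schedule(count):
        if count < transition_begin:
            # The value is held fixed before the transition begins.
            return init_value
        steps = count - transition_begin
        if staircase:
            rate_factor = steps // transition_steps  # integer division
        else:
            rate_factor = steps / transition_steps
        value = init_value * decay_rate**rate_factor
        if end_value is not None:
            # Lower bound when decaying, upper bound when growing.
            if decay_rate < 1:
                value = max(value, end_value)
            else:
                value = min(value, end_value)
        return value

    return schedule


continuous = exponential_decay(1.0, transition_steps=10, decay_rate=0.5)
stepped = exponential_decay(1.0, transition_steps=10, decay_rate=0.5,
                            staircase=True)
bounded = exponential_decay(1.0, transition_steps=10, decay_rate=0.5,
                            end_value=0.3)
print(continuous(10), stepped(15), bounded(20))  # 0.5 0.5 0.3
```

At step 15 the continuous schedule has already decayed below 0.5, while the staircase variant still reports 0.5 until step 20.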
- optax.schedules.warmup_exponential_decay_schedule(init_value: jax.typing.ArrayLike, peak_value: jax.typing.ArrayLike, warmup_steps: int, transition_steps: int, decay_rate: jax.typing.ArrayLike, transition_begin: int = 0, staircase: bool = False, end_value: jax.typing.ArrayLike | None = None) -> base.Schedule
Linear warmup followed by exponential decay.
- Parameters:
init_value – Initial value for the scalar to be annealed.
peak_value – Peak value for scalar to be annealed at end of warmup.
warmup_steps – Positive integer, the length of the linear warmup.
transition_steps – must be positive. See optax.schedules.exponential_decay() for more details.
decay_rate – must not be zero. The decay rate.
transition_begin – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at peak_value).
staircase – if True, decay the values at discrete intervals.
end_value – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.
- Returns:
- schedule
A function that maps step counts to values
Join schedules
- optax.schedules.join_schedules(schedules: Sequence[base.Schedule], boundaries: Sequence[int]) -> base.Schedule
Sequentially apply multiple schedules.
- Parameters:
schedules – A list of callables (expected to be optax schedules). Each schedule will receive a step count indicating the number of steps since the previous boundary transition.
boundaries – A list of integers (of length one less than schedules) that indicate when to transition between schedules.
- Returns:
A function that maps step counts to values.
- Return type:
schedule
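The key detail in the parameter description is that each schedule sees a step count relative to the previous boundary, not the global step. The pure-Python sketch below illustrates that offsetting. The helper name `join` is hypothetical, and it assumes the schedule following a boundary takes over exactly at the boundary step; it is not optax's implementation.

```python
def join(schedules, boundaries):
    """Hypothetical helper: apply schedules one after another."""

    def schedule(count):
        start = 0  # global step at which the current schedule began
        for fn, boundary in zip(schedules, boundaries):
            if count < boundary:
                # Each schedule receives the steps since its own start.
                return fn(count - start)
            start = boundary
        # Past the last boundary, the final schedule is used.
        return schedules[-1](count - start)

    return schedule


schedule_fn = join(
    schedules=[lambda c: c, lambda c: 10, lambda c: 10 - c],
    boundaries=[10, 20],
)
values = [schedule_fn(step) for step in (0, 5, 10, 15, 20, 25)]
print(values)  # [0, 5, 10, 10, 10, 5]
```

The third schedule returns `10 - c`; at global step 25 it receives `c = 5`, which is why the result is 5 rather than -15.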
Inject hyperparameters
- optax.schedules.inject_hyperparams(inner_factory: Callable[..., base.GradientTransformation], static_args: str | Iterable[str] = (), hyperparam_dtype: jnp.dtype | None = None) -> Callable[..., base.GradientTransformationExtraArgs]
Wrapper that injects stateful hyperparameters into GradientTransformations.
This wrapper allows you to pass schedules (i.e. a function that returns a numeric value given a step count) instead of constants for hyperparameters. You may only schedule numeric hyperparameters (i.e. boolean flags cannot be scheduled).
This function supports both simple schedules that are functions exclusively of the step count and stateful schedules that rely on a complex internal state. The state updating can rely on additional information fed to gradient transformations via extra_args.
For example, to use optax.scale_by_adam() with a piecewise linear schedule for beta_1 and constant for beta_2:
>>> import optax
>>> import jax.numpy as jnp
>>> # create a piecewise linear schedule for b1 that starts at 0.1; the entry
... # {100: 1.} multiplies the value by the scale 1. at step 100
>>> linear_schedule = optax.piecewise_interpolate_schedule(
...     'linear', init_value=0.1, boundaries_and_scales={100: 1.})
>>> scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)(
...     b1=linear_schedule, b2=0.99)
You may manually change numeric hyperparameters that were not scheduled through the hyperparams dict in the InjectHyperparamsState:
>>> params, grads = jnp.array(0.), jnp.array(0.)
>>> state = scheduled_adam.init(params)
>>> updates, state = scheduled_adam.update(grads, state)
>>> state.hyperparams['b2'] = 0.95
>>> updates, state = scheduled_adam.update(updates, state)  # uses b2 = 0.95
Manually overriding scheduled hyperparameters will have no effect (e.g. in the code sample above, you cannot manually adjust b1).
- Parameters:
inner_factory – a function that returns the inner optax.GradientTransformation with dynamic hyperparameters.
static_args – a string or iterable of strings specifying which callable parameters are not schedules. inject_hyperparams treats all callables as schedules by default, so if a hyperparameter is a non-schedule callable, you must specify that using this argument.
hyperparam_dtype – Optional datatype override. If specified, all float hyperparameters will be cast to this type.
- Returns:
A callable that returns an optax.GradientTransformationExtraArgs. This callable accepts the same arguments as inner_factory, except you may provide schedules in place of the constant arguments.
Changed in version 0.1.9: New parameter hyperparam_dtype, the returned callable outputs a GradientTransformationExtraArgs instead of a GradientTransformation.
Linear schedules
- optax.schedules.linear_onecycle_schedule(transition_steps: int, peak_value: jax.typing.ArrayLike, pct_start: float = 0.3, pct_final: float = 0.85, div_factor: float = 25.0, final_div_factor: float = 10000.0) -> base.Schedule
Returns a learning rate with three linear phases.
Phase 1, from iteration 0 to pct_start * transition_steps. The learning rate increases linearly from peak_value / div_factor to peak_value.
Phase 2, from iteration pct_start * transition_steps to pct_final * transition_steps. The learning rate decreases linearly from peak_value back to the initial peak_value / div_factor.
Phase 3: For the remaining steps, the learning rate interpolates between peak_value / div_factor and peak_value / final_div_factor. If final_div_factor is larger than div_factor, this is a decreasing phase.
- Parameters:
transition_steps – Number of steps over which annealing takes place.
peak_value – Maximum value attained by schedule at pct_start percent of the cycle (in number of steps).
pct_start – The percentage of the cycle (in number of steps) spent increasing the learning rate.
pct_final – The percentage of the cycle (in number of steps) spent increasing to peak_value then decreasing back to init_value.
div_factor – Determines the initial value via init_value = peak_value / div_factor.
final_div_factor – Determines the final value via final_value = init_value / final_div_factor.
- Returns:
- schedule
A function that maps step counts to values
References
Smith et al., Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates, 2017
- optax.schedules.linear_schedule(init_value: jax.typing.ArrayLike, end_value: jax.typing.ArrayLike, transition_steps: int, transition_begin: int = 0) -> base.Schedule
Schedule with linear transition from init_value to end_value.
More precisely, the learning rate at iteration \(t\) is given by:
\[\begin{cases} I, & \text{if } t < B \\ I + \frac{t - B}{T} (E - I), & \text{if } B \leq t < B + T \\ E, & \text{if } t \geq B + T \end{cases} \]
where \(I\) is the initial value, \(E\) is the end value, \(B\) is the transition begin, and \(T\) is the transition steps.
This schedule is equivalent to optax.polynomial_schedule() with power=1.
- Parameters:
init_value – initial value for the scalar to be annealed.
end_value – end value of the scalar to be annealed.
transition_steps – number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.
transition_begin – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).
- Returns:
- schedule
A function that maps step counts to values.
Examples
>>> schedule_fn = optax.linear_schedule(
...     init_value=1.0, end_value=0.01, transition_steps=100)
>>> schedule_fn(0)  # learning rate on the first iteration
Array(1., dtype=float32, weak_type=True)
>>> schedule_fn(100)  # learning rate on the last iteration
Array(0.01, dtype=float32, weak_type=True)
Piecewise schedules
- optax.schedules.piecewise_constant_schedule(init_value: jax.typing.ArrayLike, boundaries_and_scales: dict[int, float] | None = None) -> base.Schedule
Piecewise constant schedule with scaled jumps at specific boundaries.
At each step t, this schedule returns init_value scaled by the product of all factors f_i such that t >= b_i, where (b_i, f_i) are the entries in boundaries_and_scales.
- Parameters:
init_value – The starting value of the schedule.
boundaries_and_scales – Dictionary of {step: scale} where scale is multiplied into the schedule value at the given step. All scale values must be non-negative.
- Returns:
A function that maps step index to schedule value.
Example
>>> sched = optax.piecewise_constant_schedule(
...     init_value=1.0, boundaries_and_scales={100: 0.1, 200: 0.01})
>>> print(sched(50))  # before first boundary
1.0
>>> print(sched(150))  # after first boundary
0.1
>>> print(sched(250))  # after second boundary
0.001
- optax.schedules.piecewise_interpolate_schedule(interpolate_type: str, init_value: jax.typing.ArrayLike, boundaries_and_scales: dict[int, float] | None = None) -> base.Schedule
Piecewise interpolated schedule with linear or cosine transitions.
This schedule interpolates smoothly between values scaled at specified step boundaries. The interpolation occurs between the accumulated scaled values, not the raw multiplicative scales themselves.
At each boundary, the scaling factor is multiplied into the current value. Between these boundaries, the schedule applies either linear or cosine interpolation.
- Parameters:
interpolate_type – Either 'linear' or 'cosine', specifying the interpolation method used between boundary segments.
init_value – Starting value for the schedule.
boundaries_and_scales – A dictionary {step: scale} that defines the boundaries and their corresponding multiplicative scaling factors.
- Returns:
A function that maps step counts to interpolated values.
Example
>>> sched = optax.piecewise_interpolate_schedule(
...     interpolate_type='linear',
...     init_value=1.0,
...     boundaries_and_scales={4: 0.1, 8: 0.01}
... )
>>> for step in range(0, 11, 2):
...     print(f"Step {step}: {sched(step):.6f}")
Step 0: 1.000000
Step 2: 0.550000
Step 4: 0.100000
Step 6: 0.050500
Step 8: 0.001000
Step 10: 0.001000
Note
This schedule accumulates scaling factors and then interpolates between those accumulated values. It does not interpolate directly between the provided scales. This behavior can appear counterintuitive but allows for smooth and controlled transitions in learning rates.
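To see where values such as 0.550000 and 0.050500 come from, the accumulation described in the note can be written out in plain Python. The helper name `piecewise_linear` is hypothetical and covers only the 'linear' case; it is a sketch of the described behaviour rather than optax's implementation.

```python
def piecewise_linear(init_value, boundaries_and_scales):
    """Hypothetical helper: interpolate between accumulated scaled values."""
    # Accumulate the scales into (step, value) knots: each boundary
    # multiplies the value reached at the previous boundary.
    knots = [(0, init_value)]
    for step, scale in sorted(boundaries_and_scales.items()):
        knots.append((step, knots[-1][1] * scale))

    def schedule(count):
        for (s0, v0), (s1, v1) in zip(knots, knots[1:]):
            if count < s1:
                frac = (count - s0) / (s1 - s0)
                return v0 + frac * (v1 - v0)
        # Past the last boundary the value stays constant.
        return knots[-1][1]

    return schedule


sched = piecewise_linear(1.0, {4: 0.1, 8: 0.01})
formatted = [f"{sched(step):.6f}" for step in range(0, 11, 2)]
print(formatted)
# ['1.000000', '0.550000', '0.100000', '0.050500', '0.001000', '0.001000']
```

The knots here are 1.0 at step 0, 0.1 at step 4 and 0.001 at step 8, so step 6 lies halfway between 0.1 and 0.001, not halfway between the raw scales 0.1 and 0.01.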
Polynomial schedules
- optax.schedules.polynomial_schedule(init_value: jax.typing.ArrayLike, end_value: jax.typing.ArrayLike, power: jax.typing.ArrayLike, transition_steps: int, transition_begin: int = 0) -> base.Schedule
Constructs a schedule with polynomial transition from init to end value.
This function transitions the learning rate from an initial value (init_value) to a final value (end_value) over a specified number of steps (transition_steps) with a polynomial function of power power. The transition can optionally begin after a specified number of initial steps (transition_begin).
More precisely, the learning rate at iteration \(t\) is given by:
\[\begin{cases} I, & \text{if } t < B \\ (I - E) \left( 1 - \frac{t - B}{T} \right)^{P} + E, & \text{if } B \leq t < B + T \\ E, & \text{if } t \geq B + T \end{cases} \]
where \(I\) is the initial value, \(E\) is the end value, \(B\) is the transition begin, \(T\) is the transition steps, and \(P\) is the power used for the polynomial transition.
- Parameters:
init_value – initial value for the scalar to be annealed.
end_value – end value of the scalar to be annealed.
power – the power of the polynomial used to transition from init to end.
transition_steps – number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.
transition_begin – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).
- Returns:
- schedule
A function that maps step counts to values.
Examples
>>> schedule_fn = optax.polynomial_schedule(
...     init_value=1.0, end_value=0.01, transition_steps=100, power=2)
>>> schedule_fn(0)  # learning rate on the first iteration
Array(1., dtype=float32, weak_type=True)
>>> schedule_fn(100)  # learning rate on the last iteration
Array(0.01, dtype=float32, weak_type=True)
The following example uses a non-zero transition_begin. In this case the learning rate is kept constant for the first transition_begin iterations:
>>> schedule_fn = optax.polynomial_schedule(
...     init_value=1.0,
...     end_value=0.01,
...     transition_steps=100,
...     transition_begin=5,
...     power=2,
... )
>>> counts = [0, 5, 6, 104, 105, 110]
>>> print(
...     *[f'count:{i} value:{schedule_fn(i):.4f}' for i in counts],
...     sep='\n')
count:0 value:1.0000
count:5 value:1.0000
count:6 value:0.9803
count:104 value:0.0101
count:105 value:0.0100
count:110 value:0.0100
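The printed values can be reproduced by evaluating the piecewise formula directly. The sketch below is plain Python with a hypothetical helper name `polynomial`; it mirrors the three cases of the formula with \(I\) = `init_value`, \(E\) = `end_value`, \(B\) = `transition_begin`, \(T\) = `transition_steps` and \(P\) = `power`, and is not optax's implementation.

```python
def polynomial(init_value, end_value, power, transition_steps,
               transition_begin=0):
    """Hypothetical helper evaluating the polynomial schedule formula."""

    def schedule(count):
        if transition_steps <= 0 or count < transition_begin:
            # Annealing disabled, or not started yet.
            return init_value
        if count >= transition_begin + transition_steps:
            return end_value
        frac = 1 - (count - transition_begin) / transition_steps
        return (init_value - end_value) * frac**power + end_value

    return schedule


schedule_fn = polynomial(
    init_value=1.0, end_value=0.01, power=2,
    transition_steps=100, transition_begin=5)
formatted = [f"{schedule_fn(i):.4f}" for i in (0, 5, 6, 104, 105, 110)]
print(formatted)
# ['1.0000', '1.0000', '0.9803', '0.0101', '0.0100', '0.0100']
```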
Reduce on plateau
- optax.contrib.reduce_on_plateau(factor: float = 0.1, patience: jax.typing.ArrayLike = 10, rtol: float = 0.0001, atol: float = 0.0, cooldown: jax.typing.ArrayLike = 0, accumulation_size: jax.typing.ArrayLike = 1, min_scale: jax.typing.ArrayLike = 0.0) -> base.GradientTransformationExtraArgs
Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning rate once learning stagnates. This scheduler reads a metric quantity and if no improvement is seen for a patience number of epochs, the learning rate is reduced.
- Parameters:
factor – Factor by which to reduce the learning rate. new_scale = scale * factor.
patience – Number of iterations with no improvement after which learning rate will be reduced.
rtol – Relative tolerance for measuring new optimum.
atol – Absolute tolerance for measuring new optimum.
cooldown – Number of iterations to wait before resuming normal operation after scale has been reduced.
accumulation_size – Number of values to aggregate before applying the logic of reduce on plateau. If the value fed to the optimizer is a test value, simply take 1 (default). If the value fed to the optimizer is the loss on the current minibatch, consider using a larger accumulation size.
min_scale – Scale at which the learning rate decay stops.
- Returns:
An optax.GradientTransformationExtraArgs object.
Warm restarts
- optax.schedules.sgdr_schedule(cosine_kwargs: Iterable[dict[str, jax.typing.ArrayLike]]) -> base.Schedule
SGD with warm restarts.
This learning rate schedule applies multiple joined cosine decay cycles.
- Parameters:
cosine_kwargs – An Iterable of dicts, where each element specifies the arguments to pass to each cosine decay cycle. The decay_steps kwarg will specify how long each cycle lasts for, and therefore when to transition to the next cycle.
- Returns:
- schedule
A function that maps step counts to values
References
Loshchilov et al., SGDR: Stochastic Gradient Descent with Warm Restarts, 2017
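The idea of joining cosine cycles, with each cycle's `decay_steps` setting when the next one starts, can be sketched in plain Python. Everything here is an illustrative assumption: the helper names `cosine_cycle` and `warm_restarts` are hypothetical, and each cycle is simplified to take only `init_value` and `decay_steps`, which may differ from the keyword arguments the library accepts per cycle.

```python
import math


def cosine_cycle(init_value, decay_steps):
    """Hypothetical helper: one cosine decay from init_value to zero."""

    def schedule(count):
        t = min(count, decay_steps)
        return init_value * 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))

    return schedule


def warm_restarts(cycle_kwargs):
    """Hypothetical helper: run cosine cycles back to back."""
    cycle_kwargs = list(cycle_kwargs)
    cycles = [cosine_cycle(**kw) for kw in cycle_kwargs]
    lengths = [kw["decay_steps"] for kw in cycle_kwargs]

    def schedule(count):
        start = 0  # global step at which the current cycle began
        for fn, length in zip(cycles[:-1], lengths):
            if count < start + length:
                return fn(count - start)
            start += length
        # The last cycle handles every remaining step.
        return cycles[-1](count - start)

    return schedule


schedule_fn = warm_restarts([
    {"init_value": 1.0, "decay_steps": 10},
    {"init_value": 0.5, "decay_steps": 20},
])
values = [round(schedule_fn(step), 4) for step in (0, 5, 10, 20, 30)]
print(values)  # [1.0, 0.5, 0.5, 0.25, 0.0]
```

At step 10 the first cycle has ended and the schedule restarts at the second cycle's `init_value`, which is the warm restart.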