Optimizer Schedules

constant_schedule(value)

Constructs a constant schedule.

cosine_decay_schedule(init_value, decay_steps)

Returns a function which implements cosine learning rate decay.

cosine_onecycle_schedule(transition_steps, ...)

Returns a function which implements the onecycle learning rate schedule.

exponential_decay(init_value, ...[, ...])

Constructs a schedule with either continuous or discrete exponential decay.

join_schedules(schedules, boundaries)

Sequentially apply multiple schedules.

linear_onecycle_schedule(transition_steps, ...)

Returns a learning rate with three linear phases.

linear_schedule(init_value, end_value, ...)

Schedule with linear transition from init_value to end_value.

piecewise_constant_schedule(init_value[, ...])

Piecewise constant schedule with scaled jumps at specific boundaries.

piecewise_interpolate_schedule(...[, ...])

Piecewise interpolated schedule with linear or cosine transitions.

polynomial_schedule(init_value, end_value, ...)

Constructs a schedule with polynomial transition from init to end value.

sgdr_schedule(cosine_kwargs)

SGD with warm restarts.

warmup_constant_schedule(init_value, ...)

Linear warmup followed by a constant schedule, i.e. no decay.

warmup_cosine_decay_schedule(init_value, ...)

Linear warmup followed by cosine decay.

warmup_exponential_decay_schedule(...[, ...])

Linear warmup followed by exponential decay.

Schedule

InjectHyperparamsState(count, hyperparams, ...)

Deprecated class kept for backwards compatibility.

inject_hyperparams(inner_factory[, ...])

Wrapper that injects stateful hyperparameters into GradientTransformations.

optax.schedules.Schedule

alias of Callable[[Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray], Array | ndarray | bool | number | bool | int | float | complex | TypedNdArray]

Constant schedule

optax.schedules.constant_schedule(value: jax.typing.ArrayLike) -> base.Schedule

Constructs a constant schedule.

Parameters:

value – value to be held constant throughout.

Returns:

schedule

A function that maps step counts to values.

Examples

>>> schedule_fn = optax.constant_schedule(5)
>>> schedule_fn(0)
5
>>> schedule_fn(100)
5

optax.schedules.warmup_constant_schedule(init_value: jax.typing.ArrayLike, peak_value: jax.typing.ArrayLike, warmup_steps: int) -> base.Schedule

Linear warmup followed by a constant schedule, i.e. no decay.

Parameters:
  • init_value – Initial value for the scalar to be annealed.

  • peak_value – Peak value of the scalar, reached at the end of warmup.

  • warmup_steps – Positive integer, the length of the linear warmup.

Returns:

schedule

A function that maps step counts to values.
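For checking values by hand, the warmup-then-constant shape can be written as a few lines of plain Python. This is a minimal sketch assuming the warmup interpolates linearly from init_value to peak_value over warmup_steps and then holds peak_value; the helper warmup_constant_value is hypothetical and not part of optax.

```python
def warmup_constant_value(step, init_value, peak_value, warmup_steps):
    """Hypothetical helper: linear warmup, then a constant peak_value."""
    if step >= warmup_steps:
        # After warmup the value is held at peak_value (no decay).
        return peak_value
    # Linear interpolation from init_value to peak_value during warmup.
    frac = max(step, 0) / warmup_steps
    return init_value + frac * (peak_value - init_value)


# Warm up from 0.0 to 1e-3 over 10 steps.
values = [warmup_constant_value(s, 0.0, 1e-3, 10) for s in (0, 5, 10, 100)]
print(values)
```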

Cosine decay schedule

optax.schedules.cosine_decay_schedule(init_value: jax.typing.ArrayLike, decay_steps: int, alpha: jax.typing.ArrayLike = 0.0, exponent: jax.typing.ArrayLike = 1.0) -> base.Schedule

Returns a function which implements cosine learning rate decay.

This schedule smoothly decreases the learning rate over a specified number of steps (decay_steps). The decay follows a cosine function, with an optional exponent to modify the decay curve. A nonzero alpha keeps the learning rate from dropping entirely to zero: after decay_steps it settles at alpha * init_value.

More precisely, the learning rate at iteration \(t\) is given by:

\[\begin{cases} \alpha I + (1 - \alpha) I \left[ \frac{1}{2} \left( 1+ \cos \left( \pi\,\frac{t}{T} \right) \right) \right] ^p & \text{, if } t \leq T \\ \alpha I & \text{, if } t > T \end{cases} \]

where \(T\) is the number of decay steps (decay_steps), \(p\) is the exponent and \(I\) is the initial value (init_value).

Parameters:
  • init_value – An initial value for the learning rate.

  • decay_steps – Positive integer, the number of steps over which to apply the decay.

  • alpha – The minimum value of the multiplier used to adjust the learning rate. Defaults to 0.0.

  • exponent – The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is the decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.

Returns:

schedule

A function that maps step counts to values.
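The formula above can be transcribed into plain Python to check values by hand. This is a sketch of the documented formula only, not optax's implementation; the helper cosine_decay_value is hypothetical.

```python
import math


def cosine_decay_value(t, init_value, decay_steps, alpha=0.0, exponent=1.0):
    """Hypothetical helper evaluating the cosine decay formula at step t."""
    # For t > T the value stays at alpha * I, so clamp t to T.
    t = min(t, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
    return alpha * init_value + (1.0 - alpha) * init_value * cosine ** exponent


start = cosine_decay_value(0, init_value=1.0, decay_steps=100, alpha=0.1)    # 1.0
middle = cosine_decay_value(50, init_value=1.0, decay_steps=100, alpha=0.1)  # ~0.55
end = cosine_decay_value(200, init_value=1.0, decay_steps=100, alpha=0.1)    # 0.1
```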

References

Loshchilov et al., SGDR: Stochastic Gradient Descent with Warm Restarts, 2017

optax.schedules.cosine_onecycle_schedule(transition_steps: int, peak_value: jax.typing.ArrayLike, pct_start: float = 0.3, div_factor: float = 25.0, final_div_factor: float = 10000.0) -> base.Schedule

Returns a function which implements the onecycle learning rate schedule.

This schedule increases the learning rate and then decreases it in a cosine-like manner. The number of steps over which the learning rate increases is determined by the pct_start argument. The maximum value of the learning rate is determined by the peak_value argument, the initial value of the learning rate is determined through the formula init_value = peak_value / div_factor, and the final value is determined by the final_div_factor argument.

Parameters:
  • transition_steps – Number of steps over which annealing takes place.

  • peak_value – Maximum value attained by schedule at pct_start percent of the cycle (in number of steps).

  • pct_start – The percentage of the cycle (in number of steps) spent increasing the learning rate.

  • div_factor – Determines the initial value via init_value = peak_value / div_factor.

  • final_div_factor – Determines the final value via final_value = init_value / final_div_factor.

Returns:

schedule

A function that maps step counts to values.
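As a worked example of the relations in the parameter list above, the start, peak and end points of the cycle follow from simple arithmetic. The helper onecycle_key_points is hypothetical and only computes these points; it does not evaluate the cosine segments between them.

```python
def onecycle_key_points(transition_steps, peak_value, pct_start=0.3,
                        div_factor=25.0, final_div_factor=1e4):
    """Hypothetical helper: (step, value) pairs at the start, peak and end."""
    init_value = peak_value / div_factor          # value at step 0
    final_value = init_value / final_div_factor   # value after transition_steps
    peak_step = pct_start * transition_steps      # where peak_value is reached
    return [(0, init_value), (peak_step, peak_value),
            (transition_steps, final_value)]


points = onecycle_key_points(1000, peak_value=0.1)
# Roughly [(0, 0.004), (300.0, 0.1), (1000, 4e-07)], up to float rounding.
```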

References

Smith et al., Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates, 2017

optax.schedules.warmup_cosine_decay_schedule(init_value: jax.typing.ArrayLike, peak_value: jax.typing.ArrayLike, warmup_steps: int, decay_steps: int, end_value: jax.typing.ArrayLike = 0.0, exponent: jax.typing.ArrayLike = 1.0) -> base.Schedule

Linear warmup followed by cosine decay.

Parameters:
  • init_value – Initial value for the scalar to be annealed.

  • peak_value – Peak value of the scalar, reached at the end of warmup.

  • warmup_steps – Positive integer, the length of the linear warmup.

  • decay_steps – Positive integer, the total length of the schedule. Note that this includes the warmup time, so the number of steps during which cosine annealing is applied is decay_steps - warmup_steps.

  • end_value – End value of the scalar to be annealed.

  • exponent – The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the number of steps since the end of warmup and T = decay_steps - warmup_steps is the length of the cosine phase. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.

Returns:

schedule

A function that maps step counts to values.
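A plain-Python sketch of the two phases can help when choosing warmup_steps and decay_steps. It assumes the cosine phase runs from peak_value down to end_value over decay_steps - warmup_steps steps, with t counted from the end of warmup; the helper warmup_cosine_value is hypothetical, not optax code.

```python
import math


def warmup_cosine_value(step, init_value, peak_value, warmup_steps,
                        decay_steps, end_value=0.0, exponent=1.0):
    """Hypothetical helper: linear warmup followed by cosine decay."""
    if step < warmup_steps:
        # Linear warmup from init_value to peak_value.
        return init_value + (step / warmup_steps) * (peak_value - init_value)
    # Cosine phase: t counts steps since warmup ended; T excludes the warmup.
    T = decay_steps - warmup_steps
    t = min(step - warmup_steps, T)
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / T))
    return end_value + (peak_value - end_value) * cosine ** exponent


# 10 warmup steps, then 100 cosine steps (decay_steps=110 includes the warmup).
print([warmup_cosine_value(s, 0.0, 1.0, 10, 110) for s in (0, 5, 10, 60, 110)])
```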

Exponential decay schedule

optax.schedules.exponential_decay(init_value: jax.typing.ArrayLike, transition_steps: int, decay_rate: float, transition_begin: int = 0, staircase: bool = False, end_value: jax.typing.ArrayLike | None = None) -> base.Schedule

Constructs a schedule with either continuous or discrete exponential decay.

This function applies an exponential decay function to a provided initial value. When count >= transition_begin the function returns the decayed value as:

rate_factor = ((count - transition_begin) / transition_steps)
decayed_value = init_value * (decay_rate ** rate_factor)

If the argument staircase is True, then (count - transition_begin) / transition_steps is an integer division and the decayed value follows a staircase function.

Parameters:
  • init_value – the initial learning rate.

  • transition_steps – must be positive. See the decay computation above.

  • decay_rate – must not be zero. The decay rate.

  • transition_begin – must be non-negative. The number of steps to wait before annealing starts (until then the scalar value is held fixed at init_value).

  • staircase – if True, decay the values at discrete intervals.

  • end_value – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.

Returns:

schedule

A function that maps step counts to values.
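The decay rule, the staircase option and the end_value bound can be mirrored in plain Python for hand-checking. This sketch follows the formula and parameter descriptions above; the helper exponential_decay_value is hypothetical, and it skips the input validation (positive transition_steps, nonzero decay_rate) that the parameters call for.

```python
def exponential_decay_value(count, init_value, transition_steps, decay_rate,
                            transition_begin=0, staircase=False,
                            end_value=None):
    """Hypothetical helper mirroring the exponential decay rule."""
    elapsed = count - transition_begin
    if elapsed <= 0:
        # Before transition_begin the value is held at init_value.
        return init_value
    if staircase:
        # Integer division: the value only changes every transition_steps steps.
        rate_factor = elapsed // transition_steps
    else:
        rate_factor = elapsed / transition_steps
    value = init_value * decay_rate ** rate_factor
    if end_value is not None:
        # end_value is a lower bound when decaying, an upper bound when growing.
        value = max(value, end_value) if decay_rate < 1 else min(value, end_value)
    return value


smooth = exponential_decay_value(15, 1.0, transition_steps=10, decay_rate=0.5)
stairs = exponential_decay_value(15, 1.0, transition_steps=10, decay_rate=0.5,
                                 staircase=True)
floored = exponential_decay_value(100, 1.0, transition_steps=10, decay_rate=0.5,
                                  end_value=0.01)
```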

optax.schedules.warmup_exponential_decay_schedule(init_value: jax.typing.ArrayLike, peak_value: jax.typing.ArrayLike, warmup_steps: int, transition_steps: int, decay_rate: jax.typing.ArrayLike, transition_begin: int = 0, staircase: bool = False, end_value: jax.typing.ArrayLike | None = None) -> base.Schedule

Linear warmup followed by exponential decay.

Parameters:
  • init_value – Initial value for the scalar to be annealed.

  • peak_value – Peak value of the scalar, reached at the end of warmup.

  • warmup_steps – Positive integer, the length of the linear warmup.

  • transition_steps – must be positive. See optax.schedules.exponential_decay() for more details.

  • decay_rate – must not be zero. The decay rate.

  • transition_begin – must be non-negative. The number of steps to wait before annealing starts (until then the scalar value is held fixed at peak_value).

  • staircase – if True, decay the values at discrete intervals.

  • end_value – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.

Returns:

schedule

A function that maps step counts to values.

Join schedules

optax.schedules.join_schedules(schedules: Sequence[base.Schedule], boundaries: Sequence[int]) -> base.Schedule

Sequentially apply multiple schedules.

Parameters:
  • schedules – A list of callables (expected to be optax schedules). Each schedule will receive a step count indicating the number of steps since the previous boundary transition.

  • boundaries – A list of integers (of length one less than schedules) that indicate when to transition between schedules.

Returns:

A function that maps step counts to values.

Return type:

schedule
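The joining rule can be sketched in plain Python: at or after each boundary, the next schedule takes over and is called with the number of steps since that boundary. The function join_schedules_sketch and the two toy schedules are hypothetical illustrations, not optax code; treating a step equal to a boundary as belonging to the next schedule is an assumption of this sketch.

```python
def join_schedules_sketch(schedules, boundaries):
    """Hypothetical helper mirroring the joining rule described above."""
    def schedule(step):
        start = 0
        for fn, boundary in zip(schedules, boundaries):
            if step < boundary:
                # Each schedule sees the steps elapsed since the previous boundary.
                return fn(step - start)
            start = boundary
        # At or after the last boundary, the final schedule takes over.
        return schedules[-1](step - start)
    return schedule


def warmup(step):
    return step / 10      # rises from 0 to 1 over 10 steps


def decay(step):
    return 0.9 ** step    # its own count restarts at 0 at the boundary


joined = join_schedules_sketch([warmup, decay], [10])
print(joined(5), joined(10), joined(12))
```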

Inject hyperparameters

optax.schedules.inject_hyperparams(inner_factory: Callable[..., base.GradientTransformation], static_args: str | Iterable[str] = (), hyperparam_dtype: jnp.dtype | None = None) -> Callable[..., base.GradientTransformationExtraArgs]

Wrapper that injects stateful hyperparameters into GradientTransformations.

This wrapper allows you to pass schedules (i.e. a function that returns a numeric value given a step count) instead of constants for hyperparameters. You may only schedule numeric hyperparameters (i.e. boolean flags cannot be scheduled).

This function supports both simple schedules that depend only on the step count and stateful schedules that rely on a more complex internal state. The state update can rely on additional information fed to gradient transformations via extra_args.

For example, to use optax.scale_by_adam() with a piecewise linear schedule for beta_1 and constant for beta_2:

>>> import optax
>>> import jax.numpy as jnp
>>> # create a schedule for b1 that increases linearly from 0.1 to 0.9
... # over 100 iterations (scales multiply init_value: 0.1 * 9 = 0.9)
>>> linear_schedule = optax.piecewise_interpolate_schedule(
...    'linear', init_value=0.1, boundaries_and_scales={100: 9.})
>>> scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)(
...     b1=linear_schedule, b2=0.99)

Numeric hyperparameters that were not scheduled can be changed manually through the hyperparams dict of the InjectHyperparamsState:

>>> params, grads = jnp.array(0.), jnp.array(0.)
>>> state = scheduled_adam.init(params)
>>> updates, state = scheduled_adam.update(grads, state)
>>> state.hyperparams['b2'] = 0.95
>>> updates, state = scheduled_adam.update(updates, state)  # uses b2 = 0.95

Manually overriding scheduled hyperparameters will have no effect (e.g. in the code sample above, you cannot manually adjust b1).

Parameters:
  • inner_factory – a function that returns the inner optax.GradientTransformation with dynamic hyperparameters.

  • static_args – a string or iterable of strings specifying which callable parameters are not schedules. inject_hyperparams treats all callables as schedules by default, so if a hyperparameter is a non-schedule callable, you must specify that using this argument.

  • hyperparam_dtype – Optional datatype override. If specified, all float hyperparameters will be cast to this type.

Returns:

A callable that returns an optax.GradientTransformationExtraArgs. This callable accepts the same arguments as inner_factory, except that you may provide schedules in place of constant arguments.

Changed in version 0.1.9: Added the hyperparam_dtype parameter; the returned callable now outputs a GradientTransformationExtraArgs instead of a GradientTransformation.

class optax.schedules.InjectHyperparamsState(count: jnp.ndarray, hyperparams: dict[str, jax.typing.ArrayLike], inner_state: base.OptState)

Deprecated class kept for backwards compatibility.

Deprecated since version 0.1.9: Use InjectStatefulHyperparamsState instead.

Linear schedules

optax.schedules.linear_onecycle_schedule(transition_steps: int, peak_value: jax.typing.ArrayLike, pct_start: float = 0.3, pct_final: float = 0.85, div_factor: float = 25.0, final_div_factor: float = 10000.0) -> base.Schedule

Returns a learning rate with three linear phases.

  • Phase 1: from iteration 0 to pct_start * transition_steps, the learning rate increases linearly from peak_value / div_factor to peak_value.

  • Phase 2: from iteration pct_start * transition_steps to pct_final * transition_steps, the learning rate decreases linearly from peak_value back to the initial value peak_value / div_factor.

  • Phase 3: for the remaining steps, the learning rate interpolates linearly from peak_value / div_factor to (peak_value / div_factor) / final_div_factor, the final value set by final_div_factor. If final_div_factor is larger than 1, this is a decreasing phase.

Parameters:
  • transition_steps – Number of steps over which annealing takes place.

  • peak_value – Maximum value attained by schedule at pct_start percent of the cycle (in number of steps).

  • pct_start – The percentage of the cycle (in number of steps) spent increasing the learning rate.

  • pct_final – The percentage of the cycle (in number of steps) spent increasing to peak_value then decreasing back to init_value.

  • div_factor – Determines the initial value via init_value = peak_value / div_factor.

  • final_div_factor – Determines the final value via final_value = init_value / final_div_factor.

Returns:

schedule

A function that maps step counts to values.
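Tabulating the phase endpoints makes the three phases concrete. This sketch takes final_value = init_value / final_div_factor from the parameter list above as the value at the end of phase 3; the helper linear_onecycle_points is hypothetical and only lists endpoints, between which the schedule interpolates linearly.

```python
def linear_onecycle_points(transition_steps, peak_value, pct_start=0.3,
                           pct_final=0.85, div_factor=25.0,
                           final_div_factor=1e4):
    """Hypothetical helper: (step, value) at the end points of each phase."""
    init_value = peak_value / div_factor
    return [
        (0, init_value),                                    # phase 1 starts
        (pct_start * transition_steps, peak_value),         # phase 1 -> 2
        (pct_final * transition_steps, init_value),         # phase 2 -> 3
        (transition_steps, init_value / final_div_factor),  # phase 3 ends
    ]


points = linear_onecycle_points(1000, peak_value=0.1)
print(points)
```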

References

Smith et al., Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates, 2017

optax.schedules.linear_schedule(init_value: jax.typing.ArrayLike, end_value: jax.typing.ArrayLike, transition_steps: int, transition_begin: int = 0) -> base.Schedule

Schedule with linear transition from init_value to end_value.

More precisely, the learning rate at iteration \(t\) is given by:

\[\begin{cases} I, & \text{if } t < B \\ I + \frac{t - B}{T} (E - I), & \text{if } B \leq t < B + T \\ E, & \text{if } t \geq B + T \end{cases} \]

where \(I\) is the initial value, \(E\) is the end value, \(B\) is the transition begin, and \(T\) is the transition steps.

This schedule is equivalent to optax.polynomial_schedule() with power=1.

Parameters:
  • init_value – initial value for the scalar to be annealed.

  • end_value – end value of the scalar to be annealed.

  • transition_steps – number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.

  • transition_begin – must be non-negative. The number of steps to wait before annealing starts (until then the scalar value is held fixed at init_value).

Returns:

schedule

A function that maps step counts to values.

Examples

>>> schedule_fn = optax.linear_schedule(
...    init_value=1.0, end_value=0.01, transition_steps=100)
>>> schedule_fn(0)  # learning rate on the first iteration
Array(1., dtype=float32, weak_type=True)
>>> schedule_fn(100)  # learning rate on the last iteration
Array(0.01, dtype=float32, weak_type=True)

Piecewise schedules

optax.schedules.piecewise_constant_schedule(init_value: jax.typing.ArrayLike, boundaries_and_scales: dict[int, float] | None = None) -> base.Schedule

Piecewise constant schedule with scaled jumps at specific boundaries.

At each step t, this schedule returns init_value scaled by the product of all factors f_i such that t >= b_i, where (b_i, f_i) are the entries in boundaries_and_scales.

Parameters:
  • init_value – The starting value of the schedule.

  • boundaries_and_scales – Dictionary of {step: scale} where scale is multiplied into the schedule value at the given step. All scale values must be non-negative.

Returns:

A function that maps step index to schedule value.
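The product rule stated above fits in a few lines of plain Python. The helper piecewise_constant_value is a hypothetical sketch, not optax's implementation; with the scales from the example that follows, it gives the same values.

```python
def piecewise_constant_value(t, init_value, boundaries_and_scales):
    """Hypothetical helper: init_value times every scale whose boundary t has reached."""
    value = init_value
    for boundary, scale in sorted(boundaries_and_scales.items()):
        if t >= boundary:
            # Each boundary already reached multiplies in its scale.
            value *= scale
    return value


scales = {100: 0.1, 200: 0.01}
print([piecewise_constant_value(t, 1.0, scales) for t in (50, 150, 250)])
```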

Example

>>> sched = optax.piecewise_constant_schedule(
...     init_value=1.0, boundaries_and_scales={100: 0.1, 200: 0.01})
>>> print(sched(50))   # before first boundary
1.0
>>> print(sched(150))  # after first boundary
0.1
>>> print(sched(250))  # after second boundary
0.001

optax.schedules.piecewise_interpolate_schedule(interpolate_type: str, init_value: jax.typing.ArrayLike, boundaries_and_scales: dict[int, float] | None = None) -> base.Schedule

Piecewise interpolated schedule with linear or cosine transitions.

This schedule interpolates smoothly between values scaled at specified step boundaries. The interpolation occurs between the accumulated scaled values, not the raw multiplicative scales themselves.

At each boundary, the scaling factor is multiplied into the current value. Between these boundaries, the schedule applies either linear or cosine interpolation.

Parameters:
  • interpolate_type – Either 'linear' or 'cosine', specifying the interpolation method used between boundary segments.

  • init_value – Starting value for the schedule.

  • boundaries_and_scales – A dictionary {step: scale} that defines the boundaries and their corresponding multiplicative scaling factors.

Returns:

A function that maps step counts to interpolated values.

Example

>>> sched = optax.piecewise_interpolate_schedule(
...     interpolate_type='linear',
...     init_value=1.0,
...     boundaries_and_scales={4: 0.1, 8: 0.01}
... )
>>> for step in range(0, 11, 2):
...     print(f"Step {step}: {sched(step):.6f}")
Step 0: 1.000000
Step 2: 0.550000
Step 4: 0.100000
Step 6: 0.050500
Step 8: 0.001000
Step 10: 0.001000

Note

This schedule accumulates scaling factors and then interpolates between those accumulated values. It does not interpolate directly between the provided scales. This behavior can appear counterintuitive but allows for smooth and controlled transitions in learning rates.
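The accumulate-then-interpolate behaviour can be reproduced in plain Python for the 'linear' case. This is a hypothetical sketch, not optax code: it assumes all boundaries are positive and does not cover interpolate_type='cosine'. With the scales from the example above it prints the same values.

```python
def piecewise_linear_value(t, init_value, boundaries_and_scales):
    """Hypothetical sketch of interpolate_type='linear' (positive boundaries)."""
    # First accumulate the scales: the value at each boundary is the running product.
    steps, values = [0], [init_value]
    for boundary, scale in sorted(boundaries_and_scales.items()):
        steps.append(boundary)
        values.append(values[-1] * scale)
    # Then interpolate linearly inside the segment that contains t.
    for i in range(1, len(steps)):
        if t < steps[i]:
            s0, s1 = steps[i - 1], steps[i]
            v0, v1 = values[i - 1], values[i]
            return v0 + (t - s0) / (s1 - s0) * (v1 - v0)
    # After the last boundary, hold the last accumulated value.
    return values[-1]


scales = {4: 0.1, 8: 0.01}
print([round(piecewise_linear_value(s, 1.0, scales), 6) for s in range(0, 11, 2)])
# [1.0, 0.55, 0.1, 0.0505, 0.001, 0.001]
```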

Polynomial schedules

optax.schedules.polynomial_schedule(init_value: jax.typing.ArrayLike, end_value: jax.typing.ArrayLike, power: jax.typing.ArrayLike, transition_steps: int, transition_begin: int = 0) -> base.Schedule

Constructs a schedule with polynomial transition from init to end value.

This function transitions the learning rate from an initial value (init_value) to a final value (end_value) over a specified number of steps (transition_steps), following a polynomial with exponent power. The transition can optionally begin after a specified number of initial steps (transition_begin).

More precisely, the learning rate at iteration \(t\) is given by:

\[\begin{cases} I, & \text{if } t < B \\ (I - E) \left( 1 - \frac{t - B}{T} \right)^{P} + E, & \text{if } B \leq t < B + T \\ E, & \text{if } t \geq B + T \end{cases} \]

where \(I\) is the initial value, \(E\) is the end value, \(B\) is the transition begin, \(T\) is the transition steps, and \(P\) is the power used for the polynomial transition.

Parameters:
  • init_value – initial value for the scalar to be annealed.

  • end_value – end value of the scalar to be annealed.

  • power – the power of the polynomial used to transition from init to end.

  • transition_steps – number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.

  • transition_begin – must be non-negative. The number of steps to wait before annealing starts (until then the scalar value is held fixed at init_value).

Returns:

schedule

A function that maps step counts to values.

Examples

>>> schedule_fn = optax.polynomial_schedule(
...    init_value=1.0, end_value=0.01, transition_steps=100, power=2)
>>> schedule_fn(0)  # learning rate on the first iteration
Array(1., dtype=float32, weak_type=True)
>>> schedule_fn(100)  # learning rate on the last iteration
Array(0.01, dtype=float32, weak_type=True)

The following example uses a non-zero transition_begin. In this case the learning rate is kept constant for the first transition_begin iterations:

>>> schedule_fn = optax.polynomial_schedule(
...    init_value=1.0,
...    end_value=0.01,
...    transition_steps=100,
...    transition_begin=5,
...    power=2,
... )
>>> counts = [0, 5, 6, 104, 105, 110]
>>> print(
...    *[f'count:{i} value:{schedule_fn(i):.4f}' for i in counts],
...    sep='\n')
count:0 value:1.0000
count:5 value:1.0000
count:6 value:0.9803
count:104 value:0.0101
count:105 value:0.0100
count:110 value:0.0100
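The piecewise formula above can be transcribed into plain Python; it reproduces the values printed in the example, and setting power=1 gives the linear_schedule formula. The helper polynomial_value is a hypothetical sketch that, for brevity, assumes transition_steps > 0.

```python
def polynomial_value(t, init_value, end_value, power, transition_steps,
                     transition_begin=0):
    """Hypothetical helper evaluating the polynomial schedule formula."""
    # Steps elapsed since transition_begin, clamped to [0, transition_steps].
    elapsed = min(max(t - transition_begin, 0), transition_steps)
    frac = 1 - elapsed / transition_steps
    return (init_value - end_value) * frac ** power + end_value


for count in [0, 5, 6, 104, 105, 110]:
    value = polynomial_value(count, 1.0, 0.01, power=2,
                             transition_steps=100, transition_begin=5)
    print(f'count:{count} value:{value:.4f}')
```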

Reduce on plateau

optax.contrib.reduce_on_plateau(factor: float = 0.1, patience: jax.typing.ArrayLike = 10, rtol: float = 0.0001, atol: float = 0.0, cooldown: jax.typing.ArrayLike = 0, accumulation_size: jax.typing.ArrayLike = 1, min_scale: jax.typing.ArrayLike = 0.0) -> base.GradientTransformationExtraArgs

Reduce learning rate when a metric has stopped improving.

Models often benefit from reducing the learning rate once learning stagnates. This transformation monitors a metric value and, if no improvement is seen for patience iterations, reduces the learning rate by multiplying its scale by factor.

Parameters:
  • factor – Factor by which to reduce the learning rate. new_scale = scale * factor.

  • patience – Number of iterations with no improvement after which learning rate will be reduced.

  • rtol – Relative tolerance for measuring new optimum.

  • atol – Absolute tolerance for measuring new optimum.

  • cooldown – Number of iterations to wait before resuming normal operation after scale has been reduced.

  • accumulation_size – Number of values to aggregate before applying the reduce-on-plateau logic. If the value fed to the optimizer is a test value, use 1 (the default). If it is the loss on the current minibatch, consider using a larger accumulation size.

  • min_scale – Scale at which the learning rate decay stops.

Returns:

An optax.GradientTransformationExtraArgs object.
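The bookkeeping behind patience, cooldown, factor and min_scale can be sketched in plain Python. This is a simplified illustration, not optax.contrib's implementation: the class PlateauSketch is hypothetical, accumulation_size is ignored (each value is used directly), and the exact improvement test combining rtol and atol is an assumption.

```python
import math


class PlateauSketch:
    """Hypothetical, simplified reduce-on-plateau bookkeeping (no accumulation)."""

    def __init__(self, factor=0.1, patience=10, rtol=1e-4, atol=0.0,
                 cooldown=0, min_scale=0.0):
        self.factor, self.patience = factor, patience
        self.rtol, self.atol = rtol, atol
        self.cooldown, self.min_scale = cooldown, min_scale
        self.best = math.inf
        self.plateau_count = 0
        self.cooldown_count = 0
        self.scale = 1.0

    def update(self, value):
        # Assumed improvement test: beat the best value by rtol and atol.
        if value < (1 - self.rtol) * self.best - self.atol:
            self.best = value
            self.plateau_count = 0
        else:
            self.plateau_count += 1
        if self.cooldown_count > 0:
            # While cooling down, stagnation is not counted.
            self.cooldown_count -= 1
            self.plateau_count = 0
        if self.plateau_count >= self.patience:
            # Reduce the scale, but never below min_scale.
            self.scale = max(self.scale * self.factor, self.min_scale)
            self.plateau_count = 0
            self.cooldown_count = self.cooldown
        return self.scale


plateau = PlateauSketch(factor=0.5, patience=2)
scales = [plateau.update(v) for v in [1.0, 0.9, 0.9, 0.9, 0.9]]
print(scales)  # [1.0, 1.0, 1.0, 0.5, 0.5]
```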

Warm restarts

optax.schedules.sgdr_schedule(cosine_kwargs: Iterable[dict[str, jax.typing.ArrayLike]]) -> base.Schedule

SGD with warm restarts.

This learning rate schedule applies multiple joined cosine decay cycles.

Parameters:

cosine_kwargs – An Iterable of dicts, where each element specifies the arguments to pass to one cosine decay cycle. The decay_steps kwarg specifies how long each cycle lasts, and therefore when to transition to the next cycle.

Returns:

schedule

A function that maps step counts to values.
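Because the decay_steps entry sets each cycle's length, the restart points are running totals of those lengths. The last total marks the end of the final cycle rather than a restart, since joining n schedules needs only n - 1 boundaries. The helper sgdr_boundaries is a hypothetical sketch of that arithmetic; the example dicts show only decay_steps and omit the other per-cycle arguments.

```python
from itertools import accumulate


def sgdr_boundaries(cosine_kwargs):
    """Hypothetical helper: steps at which each new cycle (restart) begins."""
    lengths = [kwargs["decay_steps"] for kwargs in cosine_kwargs]
    # Running totals of cycle lengths; drop the last, which ends the final cycle.
    return list(accumulate(lengths))[:-1]


# Other per-cycle arguments omitted for brevity.
cycles = [{"decay_steps": 100}, {"decay_steps": 200}, {"decay_steps": 400}]
print(sgdr_boundaries(cycles))  # [100, 300]
```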

References

Loshchilov et al., SGDR: Stochastic Gradient Descent with Warm Restarts, 2017