Optimizer Schedules#

constant_schedule(value)

Constructs a constant schedule.

cosine_decay_schedule(init_value, decay_steps)

Returns a function which implements cosine learning rate decay.

cosine_onecycle_schedule(transition_steps, ...)

Returns a function which implements the onecycle learning rate schedule.

exponential_decay(init_value, ...[, ...])

Constructs a schedule with either continuous or discrete exponential decay.

join_schedules(schedules, boundaries)

Sequentially apply multiple schedules.

linear_onecycle_schedule(transition_steps, ...)

Returns a function which implements the onecycle learning rate schedule.

linear_schedule(init_value, end_value, ...)

Schedule with linear transition from init_value to end_value.

piecewise_constant_schedule(init_value[, ...])

Returns a function which implements a piecewise constant schedule.

piecewise_interpolate_schedule(...[, ...])

Returns a function which implements a piecewise interpolated schedule.

polynomial_schedule(init_value, end_value, ...)

Constructs a schedule with polynomial transition from init to end value.

sgdr_schedule(cosine_kwargs)

SGD with warm restarts.

warmup_cosine_decay_schedule(init_value, ...)

Linear warmup followed by cosine decay.

warmup_exponential_decay_schedule(...[, ...])

Linear warmup followed by exponential decay.

Schedule

alias of Callable[[Array | ndarray | bool_ | number | float | int], Array | ndarray | bool_ | number | float | int]

InjectHyperparamsState(count, hyperparams, ...)

Deprecated class kept for backwards compatibility.

inject_hyperparams(inner_factory[, ...])

Wrapper that injects stateful hyperparameters into GradientTransformations.

optax.Schedule#

alias of Callable[[Array | ndarray | bool_ | number | float | int], Array | ndarray | bool_ | number | float | int]

Constant schedule#

optax.constant_schedule(value)[source]#

Constructs a constant schedule.

Parameters:

value (Union[float, int]) – value to be held constant throughout.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
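
For illustration, a minimal usage sketch (the values are arbitrary):

>>> import optax
>>> schedule_fn = optax.constant_schedule(0.01)
>>> schedule_fn(0)    # value at the first step
0.01
>>> schedule_fn(500)  # the value never changes
0.01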

Cosine decay schedule#

optax.cosine_decay_schedule(init_value, decay_steps, alpha=0.0, exponent=1.0)[source]#

Returns a function which implements cosine learning rate decay.

This schedule smoothly decreases the learning rate over a specified number of steps (decay_steps). The decay follows a cosine function, with an optional exponent to modify the decay curve. A minimum value (alpha) ensures the learning rate does not drop entirely to zero.

More precisely, the learning rate at iteration \(t\) is given by:

\[I\left((1 - \alpha)\left(\frac{1 + \cos(\pi\,\frac{t}{T})}{2}\right)^p + \alpha\right)\,,\]

where \(T\) is the number of decay steps (decay_steps), \(p\) is the exponent (exponent), \(\alpha\) is the minimum multiplier value (alpha), and \(I\) is the initial value (init_value).

References

Loshchilov et al., SGDR: Stochastic Gradient Descent with Warm Restarts, 2017

Parameters:
  • init_value (float) – An initial value for the learning rate.

  • decay_steps (int) – Positive integer, the number of steps over which to apply the decay.

  • alpha (float) – The minimum value of the multiplier used to adjust the learning rate. Defaults to 0.0.

  • exponent (float) – The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is the decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
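
For illustration, a minimal usage sketch; the commented values follow from the formula above with the defaults alpha=0.0 and exponent=1.0:

>>> schedule_fn = optax.cosine_decay_schedule(init_value=1.0, decay_steps=100)
>>> lr = schedule_fn(0)    # 1.0, no decay yet
>>> lr = schedule_fn(50)   # 0.5, since 0.5 * (1 + cos(pi/2)) = 0.5
>>> lr = schedule_fn(100)  # 0.0, fully decayed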

optax.cosine_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, div_factor=25.0, final_div_factor=10000.0)[source]#

Returns a function which implements the onecycle learning rate schedule.

This schedule first increases the learning rate and then decreases it in a cosine-like manner. The number of steps over which the learning rate increases is determined by the pct_start argument, and the maximum value by the peak_value argument. The initial value is given by init_value = peak_value / div_factor, and the final value by final_value = init_value / final_div_factor.

References

Smith et al., Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates, 2017

Parameters:
  • transition_steps (int) – Number of steps over which annealing takes place.

  • peak_value (float) – Maximum value attained by schedule at pct_start percent of the cycle (in number of steps).

  • pct_start (float) – The percentage of the cycle (in number of steps) spent increasing the learning rate.

  • div_factor (float) – Determines the initial value via init_value = peak_value / div_factor.

  • final_div_factor (float) – Determines the final value via final_value = init_value / final_div_factor.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
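
For illustration, a minimal usage sketch; the commented values follow from the div_factor and final_div_factor defaults:

>>> schedule_fn = optax.cosine_onecycle_schedule(
...     transition_steps=1000, peak_value=1.0)
>>> # starts at init_value = 1.0 / 25 = 0.04,
>>> # peaks at 1.0 around step 300 (pct_start=0.3 of 1000 steps),
>>> # then anneals towards final_value = 0.04 / 10000 = 4e-06.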

Exponential decay schedule#

optax.exponential_decay(init_value, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None)[source]#

Constructs a schedule with either continuous or discrete exponential decay.

This function applies an exponential decay function to a provided initial value. When count >= transition_begin the function returns the decayed value as:

rate_factor = ((count - transition_begin) / transition_steps)
decayed_value = init_value * (decay_rate ** rate_factor)

If the argument staircase is True, then (count - transition_begin) / transition_steps is an integer division and the decayed value follows a staircase function.

Parameters:
  • init_value (float) – the initial learning rate.

  • transition_steps (int) – must be positive. See the decay computation above.

  • decay_rate (float) – must not be zero. The decay rate.

  • transition_begin (int) – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).

  • staircase (bool) – if True, decay the values at discrete intervals.

  • end_value (Optional[float]) – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
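
For illustration, a minimal usage sketch; the commented values follow from the decay computation above:

>>> schedule_fn = optax.exponential_decay(
...     init_value=1.0, transition_steps=100, decay_rate=0.5)
>>> lr = schedule_fn(100)  # 1.0 * 0.5 ** (100 / 100) = 0.5
>>> lr = schedule_fn(200)  # 1.0 * 0.5 ** (200 / 100) = 0.25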

Join schedules#

optax.join_schedules(schedules, boundaries)[source]#

Sequentially apply multiple schedules.

Parameters:
  • schedules (Sequence[Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]]) – A list of callables (expected to be optax schedules). Each schedule will receive a step count indicating the number of steps since the previous boundary transition.

  • boundaries (Sequence[int]) – A list of integers (of length one less than schedules) that indicate when to transition between schedules.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
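
For illustration, a minimal sketch joining a linear warmup with an exponential decay (illustrative values):

>>> warmup_fn = optax.linear_schedule(
...     init_value=0.0, end_value=0.1, transition_steps=100)
>>> decay_fn = optax.exponential_decay(
...     init_value=0.1, transition_steps=1000, decay_rate=0.5)
>>> schedule_fn = optax.join_schedules(
...     [warmup_fn, decay_fn], boundaries=[100])
>>> # steps 0-99 follow warmup_fn; from step 100 onwards decay_fn is
>>> # applied, receiving a step count restarted at zero.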

Inject hyperparameters#

optax.inject_hyperparams(inner_factory, static_args=(), hyperparam_dtype=None)[source]#

Wrapper that injects stateful hyperparameters into GradientTransformations.

This wrapper allows you to pass schedules (i.e. a function that returns a numeric value given a step count) instead of constants for hyperparameters. You may only schedule numeric hyperparameters (i.e. boolean flags cannot be scheduled).

This function supports both simple schedules, which are functions exclusively of the step count, and stateful schedules that rely on a complex internal state. The state update can rely on additional information fed to gradient transformations via extra_args.

For example, to use optax.scale_by_adam() with a piecewise linear schedule for beta_1 and constant for beta_2:

>>> import optax
>>> import jax.numpy as jnp
>>> # create a learning rate that increases linearly from 0.1 to 1.0
... # over 100 iterations
>>> linear_schedule = optax.piecewise_interpolate_schedule(
...    'linear', init_value=0.1, boundaries_and_scales={100: 10.})
>>> scheduled_adam = optax.inject_hyperparams(optax.scale_by_adam)(
...     b1=linear_schedule, b2=0.99)

You may manually change numeric hyperparameters that were not scheduled through the hyperparams dict in the InjectHyperparamsState:

>>> params, grads = jnp.array(0.), jnp.array(0.)
>>> state = scheduled_adam.init(params)
>>> updates, state = scheduled_adam.update(grads, state)
>>> state.hyperparams['b2'] = 0.95
>>> updates, state = scheduled_adam.update(grads, state)  # uses b2 = 0.95

Manually overriding scheduled hyperparameters will have no effect (e.g. in the code sample above, you cannot manually adjust b1).

Parameters:
  • inner_factory (Callable[..., GradientTransformation]) – a function that returns the inner optax.GradientTransformation with dynamic hyperparameters.

  • static_args (Union[str, Iterable[str]]) – a string or iterable of strings specifying which callable parameters are not schedules. inject_hyperparams treats all callables as schedules by default, so if a hyperparameter is a non-schedule callable, you must specify that using this argument.

  • hyperparam_dtype (Optional[dtype]) – Optional datatype override. If specified, all float hyperparameters will be cast to this type.

Return type:

Callable[..., GradientTransformationExtraArgs]

Returns:

A callable that returns an optax.GradientTransformationExtraArgs. This callable accepts the same arguments as inner_factory, except you may provide schedules in place of the constant arguments.

Changed in version 0.1.9: New parameter hyperparam_dtype; the returned callable now outputs a GradientTransformationExtraArgs instead of a GradientTransformation.

class optax.InjectHyperparamsState(count: jnp.ndarray, hyperparams: dict[str, chex.Numeric], inner_state: base.OptState)[source]#

Deprecated class kept for backwards compatibility.

Deprecated since version 0.1.9: Use InjectStatefulHyperparamsState instead.

count: jnp.ndarray#

Alias for field number 0

hyperparams: dict[str, chex.Numeric]#

Alias for field number 1

inner_state: base.OptState#

Alias for field number 2

__getnewargs__()[source]#

Return self as a plain tuple. Used by copy and pickle.

Linear schedules#

optax.linear_onecycle_schedule(transition_steps, peak_value, pct_start=0.3, pct_final=0.85, div_factor=25.0, final_div_factor=10000.0)[source]#

Returns a function which implements the onecycle learning rate schedule.

This function uses a linear annealing strategy.

References

Smith et al., Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates, 2017

Parameters:
  • transition_steps (int) – Number of steps over which annealing takes place.

  • peak_value (float) – Maximum value attained by schedule at pct_start percent of the cycle (in number of steps).

  • pct_start (float) – The percentage of the cycle (in number of steps) spent increasing the learning rate.

  • pct_final (float) – The percentage of the cycle (in number of steps) spent increasing to peak_value then decreasing back to init_value.

  • div_factor (float) – Determines the initial value via init_value = peak_value / div_factor.

  • final_div_factor (float) – Determines the final value via final_value = init_value / final_div_factor.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
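
For illustration, a minimal usage sketch; the commented behaviour follows from the parameter defaults:

>>> schedule_fn = optax.linear_onecycle_schedule(
...     transition_steps=1000, peak_value=1.0)
>>> # rises linearly from 0.04 (= 1.0 / 25) to 1.0 by step 300,
>>> # falls back to 0.04 by step 850 (pct_final=0.85), then anneals
>>> # towards 4e-06 (= 0.04 / 10000) by step 1000.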

optax.linear_schedule(init_value, end_value, transition_steps, transition_begin=0)[source]#

Schedule with linear transition from init_value to end_value.

More precisely, the learning rate at iteration \(t\) is given by:

\[\begin{cases} I, & \text{if } t < B \\ I + \frac{t - B}{T} (E - I), & \text{if } B \leq t < B + T \\ E, & \text{if } t \geq B + T \end{cases} \]

where \(I\) is the initial value, \(E\) is the end value, \(B\) is the transition begin, and \(T\) is the transition steps.

This schedule is equivalent to optax.polynomial_schedule() with power=1.

Examples

>>> schedule_fn = optax.linear_schedule(
...    init_value=1.0, end_value=0.01, transition_steps=100)
>>> schedule_fn(0)  # learning rate on the first iteration
Array(1., dtype=float32, weak_type=True)
>>> schedule_fn(100)  # learning rate on the last iteration
Array(0.01, dtype=float32, weak_type=True)
Parameters:
  • init_value (Union[float, int]) – initial value for the scalar to be annealed.

  • end_value (Union[float, int]) – end value of the scalar to be annealed.

  • transition_steps (int) – number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.

  • transition_begin (int) – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.

Piecewise schedules#

optax.piecewise_constant_schedule(init_value, boundaries_and_scales=None)[source]#

Returns a function which implements a piecewise constant schedule.

Parameters:
  • init_value (float) – An initial value init_v.

  • boundaries_and_scales (Optional[dict[int, float]]) – A map from boundaries b_i to non-negative scaling factors f_i. For any step count s, the schedule returns init_v scaled by the product of all factors f_i such that b_i < s.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
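
For illustration, a minimal usage sketch; each scale factor multiplies the value once its boundary has been passed:

>>> schedule_fn = optax.piecewise_constant_schedule(
...     init_value=1.0, boundaries_and_scales={50: 0.5, 100: 0.1})
>>> # steps 0-50: 1.0; steps 51-100: 0.5; steps 101 and later: 0.05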

optax.piecewise_interpolate_schedule(interpolate_type, init_value, boundaries_and_scales=None)[source]#

Returns a function which implements a piecewise interpolated schedule.

Parameters:
  • interpolate_type (str) – 'linear' or 'cosine', specifying the interpolation strategy.

  • init_value (float) – An initial value init_v.

  • boundaries_and_scales (Optional[dict[int, float]]) – A map from boundaries b_i to non-negative scaling factors f_i. At boundary step b_i, the schedule returns init_v scaled by the product of all factors f_j such that b_j <= b_i. The values in between each boundary will be interpolated as per type.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
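
For illustration, a minimal usage sketch (illustrative values):

>>> schedule_fn = optax.piecewise_interpolate_schedule(
...     'cosine', init_value=1.0, boundaries_and_scales={100: 0.1})
>>> # interpolates from 1.0 at step 0 down to 0.1 at step 100 along a
>>> # cosine curve, then holds 0.1 for all later steps.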

Polynomial schedules#

optax.polynomial_schedule(init_value, end_value, power, transition_steps, transition_begin=0)[source]#

Constructs a schedule with polynomial transition from init to end value.

Parameters:
  • init_value (Union[float, int]) – initial value for the scalar to be annealed.

  • end_value (Union[float, int]) – end value of the scalar to be annealed.

  • power (Union[float, int]) – the power of the polynomial used to transition from init to end.

  • transition_steps (int) – number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.

  • transition_begin (int) – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
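
For illustration, a minimal usage sketch; the commented formula follows from the parameters (with power=1 this reduces to optax.linear_schedule()):

>>> schedule_fn = optax.polynomial_schedule(
...     init_value=1.0, end_value=0.01, power=2, transition_steps=100)
>>> # value at step t, for 0 <= t <= 100:
>>> #   0.01 + (1.0 - 0.01) * (1 - t / 100) ** 2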

Reduce on plateau#

optax.contrib.reduce_on_plateau(factor=0.1, patience=10, rtol=0.0001, atol=0.0, cooldown=0, accumulation_size=1)[source]#

Reduce learning rate when a metric has stopped improving.

Models often benefit from reducing the learning rate once learning stagnates. This scheduler reads a metric quantity and, if no improvement is seen for a patience number of iterations, reduces the learning rate.

Parameters:
  • factor (float) – Factor by which to reduce the learning rate. new_scale = scale * factor.

  • patience (int) – Number of iterations with no improvement after which learning rate will be reduced.

  • rtol (float) – Relative tolerance for measuring new optimum.

  • atol (float) – Absolute tolerance for measuring new optimum.

  • cooldown (int) – Number of iterations to wait before resuming normal operation after scale has been reduced.

  • accumulation_size (int) – Number of values to aggregate before applying the logic of reduce on plateau. If the value fed to the optimizer is a test value, simply take 1 (default). If the value fed to the optimizer is the loss on the current minibatch, consider using a larger accumulation size.

Return type:

GradientTransformationExtraArgs

Returns:

A GradientTransformationExtraArgs object.
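
For illustration, a minimal sketch chaining this transformation with an optimizer (illustrative hyperparameters); the monitored metric is assumed to be passed to update() via the value extra argument:

>>> opt = optax.chain(
...     optax.adam(1e-3),
...     optax.contrib.reduce_on_plateau(factor=0.5, patience=10))
>>> # at each step, feed the monitored metric (e.g. the loss):
>>> # updates, state = opt.update(grads, state, params, value=loss)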

Schedules with warm-up#

optax.warmup_cosine_decay_schedule(init_value, peak_value, warmup_steps, decay_steps, end_value=0.0, exponent=1.0)[source]#

Linear warmup followed by cosine decay.

Parameters:
  • init_value (float) – Initial value for the scalar to be annealed.

  • peak_value (float) – Peak value for scalar to be annealed at end of warmup.

  • warmup_steps (int) – Positive integer, the length of the linear warmup.

  • decay_steps (int) – Positive integer, the total length of the schedule. Note that this includes the warmup time, so the number of steps during which cosine annealing is applied is decay_steps - warmup_steps.

  • end_value (float) – End value of the scalar to be annealed.

  • exponent (float) – The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
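
For illustration, a minimal usage sketch (illustrative values):

>>> schedule_fn = optax.warmup_cosine_decay_schedule(
...     init_value=0.0, peak_value=1.0, warmup_steps=50,
...     decay_steps=1000, end_value=0.01)
>>> # linear warmup from 0.0 to 1.0 over the first 50 steps, then
>>> # cosine decay from 1.0 to 0.01 over the remaining 950 steps.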

optax.warmup_exponential_decay_schedule(init_value, peak_value, warmup_steps, transition_steps, decay_rate, transition_begin=0, staircase=False, end_value=None)[source]#

Linear warmup followed by exponential decay.

Parameters:
  • init_value (float) – Initial value for the scalar to be annealed.

  • peak_value (float) – Peak value for scalar to be annealed at end of warmup.

  • warmup_steps (int) – Positive integer, the length of the linear warmup.

  • transition_steps (int) – must be positive. See optax.exponential_decay() for more details.

  • decay_rate (float) – must not be zero. The decay rate.

  • transition_begin (int) – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at peak_value).

  • staircase (bool) – if True, decay the values at discrete intervals.

  • end_value (Optional[float]) – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
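
For illustration, a minimal usage sketch (illustrative values):

>>> schedule_fn = optax.warmup_exponential_decay_schedule(
...     init_value=0.0, peak_value=1.0, warmup_steps=50,
...     transition_steps=100, decay_rate=0.5)
>>> # linear warmup from 0.0 to 1.0 over 50 steps; afterwards the
>>> # value is halved every 100 steps.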

Warm restarts#

optax.sgdr_schedule(cosine_kwargs)[source]#

SGD with warm restarts.

This learning rate schedule applies multiple joined cosine decay cycles.

References

Loshchilov et al., SGDR: Stochastic Gradient Descent with Warm Restarts, 2017

Parameters:

cosine_kwargs (Iterable[dict[str, Union[Array, ndarray, bool_, number, float, int]]]) – An Iterable of dicts, where each element specifies the arguments to pass to each cosine decay cycle. The decay_steps kwarg will specify how long each cycle lasts for, and therefore when to transition to the next cycle.

Return type:

Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]

Returns:

schedule

A function that maps step counts to values.
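
For illustration, a minimal sketch of two consecutive warm-restart cycles, assuming each dict supplies arguments for optax.warmup_cosine_decay_schedule() (illustrative values):

>>> schedule_fn = optax.sgdr_schedule([
...     {'init_value': 0.0, 'peak_value': 1.0, 'warmup_steps': 10,
...      'decay_steps': 100, 'end_value': 0.0},
...     {'init_value': 0.0, 'peak_value': 0.5, 'warmup_steps': 10,
...      'decay_steps': 200, 'end_value': 0.0},
... ])
>>> # cycle 1 runs for 100 steps; cycle 2 then restarts the warmup and
>>> # runs for a further 200 steps.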