optax.schedules.cosine_decay_schedule
- optax.schedules.cosine_decay_schedule(init_value: jax.typing.ArrayLike, decay_steps: int, alpha: jax.typing.ArrayLike = 0.0, exponent: jax.typing.ArrayLike = 1.0) -> base.Schedule
Returns a function which implements cosine learning rate decay.
This schedule smoothly decreases the learning rate over a specified number of steps (decay_steps). The decay follows a cosine function, with an optional exponent to modify the decay curve. The minimum multiplier (alpha) sets a floor: the learning rate never falls below alpha * init_value, which is zero with the default alpha = 0.0.
More precisely, the learning rate at iteration \(t\) is given by:
\[\begin{cases} \alpha I + (1 - \alpha) I \left[ \frac{1}{2} \left( 1 + \cos \left( \pi\,\frac{t}{T} \right) \right) \right]^p & \text{if } t \leq T \\ \alpha I & \text{if } t > T \end{cases}\]
where \(T\) is the number of decay steps (decay_steps), \(p\) is the exponent, and \(I\) is the initial value (init_value).
- Parameters:
init_value – An initial value for the learning rate.
decay_steps – Positive integer - the number of steps over which to apply the decay.
alpha – The minimum value of the multiplier used to adjust the learning rate. Defaults to 0.0.
exponent – The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.
- Returns:
- schedule
A function that maps step counts to values.
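As a rough illustration of the formula above, the following pure-Python sketch evaluates the schedule by hand. The function name cosine_decay_reference is hypothetical and this is not Optax's implementation; it only mirrors the documented equation, clamping the step to decay_steps so that later steps return alpha * init_value.

```python
import math


def cosine_decay_reference(step, init_value, decay_steps, alpha=0.0, exponent=1.0):
    """Hypothetical reference for the documented formula (not Optax's code)."""
    # Beyond decay_steps the schedule stays at its floor, so clamp t to T.
    t = min(step, decay_steps)
    # Cosine factor: 1.0 at t = 0, falling smoothly to 0.0 at t = T.
    cosine = 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
    # alpha * I + (1 - alpha) * I * cosine ** p, with I factored out.
    return init_value * (alpha + (1.0 - alpha) * cosine ** exponent)


if __name__ == "__main__":
    # Illustrative settings: start at 1.0, decay over 100 steps, floor at 0.1.
    for step in (0, 25, 50, 75, 100, 150):
        value = cosine_decay_reference(step, init_value=1.0, decay_steps=100, alpha=0.1)
        print(step, value)
```

With Optax installed, optax.schedules.cosine_decay_schedule(init_value=1.0, decay_steps=100, alpha=0.1) should return a schedule that produces matching values, up to floating-point precision, when called with the same step counts.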
References
Loshchilov et al., SGDR: Stochastic Gradient Descent with Warm Restarts, 2017