optax.schedules.sgdr_schedule
- optax.schedules.sgdr_schedule(cosine_kwargs: Iterable[dict[str, jax.typing.ArrayLike]]) → base.Schedule
SGD with warm restarts.
This learning rate schedule applies multiple joined cosine decay cycles.
- Parameters:
cosine_kwargs – An Iterable of dicts, where each element specifies the arguments to pass to one cosine decay cycle. The decay_steps kwarg specifies how long each cycle lasts, and therefore when to transition to the next cycle.
- Returns:
- schedule
A function that maps step counts to values
References
Loshchilov et al., SGDR: Stochastic Gradient Descent with Warm Restarts, 2017