optax.schedules.sgdr_schedule



optax.schedules.sgdr_schedule(cosine_kwargs: Iterable[dict[str, jax.typing.ArrayLike]]) → base.Schedule

SGD with warm restarts.

This learning rate schedule joins multiple cosine decay cycles end to end: when one cycle finishes, the learning rate restarts at the beginning of the next cycle.
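To make the joining concrete, here is a minimal pure-Python sketch of the idea; it is not optax's implementation. It assumes each cycle is a plain cosine decay from a hypothetical init_value down to zero over decay_steps steps (no warmup, no final floor), and that the cumulative decay_steps of the earlier cycles decides where each restart happens. The names cosine_cycle and sgdr_sketch are illustrative only.

```python
import math
from typing import Callable, Iterable


def cosine_cycle(init_value: float, decay_steps: int) -> Callable[[int], float]:
    """Hypothetical single cycle: init_value at step 0, decaying to 0 at decay_steps."""
    def schedule(step: int) -> float:
        t = min(max(step, 0), decay_steps)  # clamp the local step into the cycle
        return init_value * 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))
    return schedule


def sgdr_sketch(cycle_kwargs: Iterable[dict]) -> Callable[[int], float]:
    """Hypothetical SGDR: join cosine cycles; each decay_steps sets the next restart."""
    cycles = []  # list of (start_step, cycle_schedule) pairs
    start = 0
    for kwargs in cycle_kwargs:
        cycles.append((start, cosine_cycle(**kwargs)))
        start += kwargs["decay_steps"]  # the next cycle restarts here

    def schedule(step: int) -> float:
        # Evaluate the last cycle whose start is <= step, at its local step.
        value = cycles[0][1](step)
        for cycle_start, cycle in cycles[1:]:
            if step >= cycle_start:
                value = cycle(step - cycle_start)
        return value

    return schedule


lr = sgdr_sketch([
    {"init_value": 1.0, "decay_steps": 100},
    {"init_value": 0.5, "decay_steps": 200},
])
for step in (0, 50, 99, 100, 200, 300):
    print(step, lr(step))
```

Here the rate decays from 1.0 toward 0 over steps 0 to 99, jumps back up to 0.5 at step 100 (the warm restart), and decays to 0 by step 300, where the last cycle stays clamped.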

Parameters:

cosine_kwargs – An iterable of dicts, each holding the keyword arguments passed to one cosine decay cycle. The decay_steps entry sets how many steps its cycle lasts, and therefore when the schedule transitions to the next cycle.

Returns:

schedule

A function that maps step counts to learning rate values.

References

Loshchilov & Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts, ICLR 2017