optax.schedules.warmup_cosine_decay_schedule
- optax.schedules.warmup_cosine_decay_schedule(init_value: jax.typing.ArrayLike, peak_value: jax.typing.ArrayLike, warmup_steps: int, decay_steps: int, end_value: jax.typing.ArrayLike = 0.0, exponent: jax.typing.ArrayLike = 1.0) → base.Schedule
Linear warmup followed by cosine decay.
- Parameters:
init_value – Initial value for the scalar to be annealed.
peak_value – Peak value for scalar to be annealed at end of warmup.
warmup_steps – Positive integer, the length of the linear warmup.
decay_steps – Positive integer, the total length of the schedule. Note that this includes the warmup time, so the number of steps during which cosine annealing is applied is decay_steps - warmup_steps.
end_value – End value of the scalar to be annealed.
exponent – The default decay is 0.5 * (1 + cos(pi * t/T)), where t is the current timestep and T is decay_steps. The exponent modifies this to be (0.5 * (1 + cos(pi * t/T))) ** exponent. Defaults to 1.0.
- Returns:
- schedule
A function that maps step counts to values
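
To make the parameter semantics concrete, here is a minimal plain-Python sketch of the curve described above. It is not optax's implementation: the helper name warmup_cosine_sketch is hypothetical, it works on Python floats rather than JAX arrays, and it assumes the cosine phase is timed from the end of warmup, so that t counts steps since warmup ended and T is decay_steps - warmup_steps (in line with the note on decay_steps). It also assumes decay_steps > warmup_steps > 0.

```python
import math


def warmup_cosine_sketch(step, init_value, peak_value, warmup_steps,
                         decay_steps, end_value=0.0, exponent=1.0):
    """Hypothetical reference sketch: linear warmup, then cosine decay.

    Assumption: the cosine phase starts when warmup ends and lasts
    decay_steps - warmup_steps steps. Not the library's own code.
    """
    if step < warmup_steps:
        # Linear interpolation from init_value up to peak_value.
        frac = step / warmup_steps
        return init_value + frac * (peak_value - init_value)

    # t = steps since warmup ended, clipped so the value holds at end_value.
    T = decay_steps - warmup_steps
    t = min(step - warmup_steps, T)
    decay = (0.5 * (1.0 + math.cos(math.pi * t / T))) ** exponent

    # decay runs from 1 (value is peak_value) down to 0 (value is end_value).
    return end_value + (peak_value - end_value) * decay


# Example: 100 warmup steps inside a 1000-step schedule.
for s in (0, 50, 100, 550, 1000):
    print(s, warmup_cosine_sketch(s, 0.0, 1e-3, 100, 1000))
```

With the real function, the returned schedule is a callable of the step count and can be passed as the learning_rate argument of an optimizer such as optax.adam.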