optax.schedules.warmup_constant_schedule


optax.schedules.warmup_constant_schedule(init_value: jax.typing.ArrayLike, peak_value: jax.typing.ArrayLike, warmup_steps: int) → base.Schedule

Linear warmup from init_value to peak_value over warmup_steps steps, followed by a constant schedule that holds peak_value, i.e. no decay.
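Read literally, the description amounts to the piecewise definition below, writing t for the step count (starting at 0) and W for warmup_steps. This is a sketch derived from the description, not a transcription of the library source.

```latex
v(t) =
\begin{cases}
\text{init\_value} + (\text{peak\_value} - \text{init\_value})\,\dfrac{t}{W}, & 0 \le t < W, \\[4pt]
\text{peak\_value}, & t \ge W.
\end{cases}
```

Both branches give peak_value at t = W, so the schedule is continuous. For example, with init_value = 0, peak_value = 1 and W = 4, the values at t = 0, 1, 2, 3, 4, 5 are 0, 0.25, 0.5, 0.75, 1, 1.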

Parameters:
  • init_value – Initial value for the scalar to be annealed.

  • peak_value – Peak value for the scalar to be annealed, reached at the end of warmup.

  • warmup_steps – Positive integer, the length of the linear warmup in steps.

Returns:

schedule

A function that maps step counts to values.