optax.schedules.warmup_exponential_decay_schedule


optax.schedules.warmup_exponential_decay_schedule(init_value: jax.typing.ArrayLike, peak_value: jax.typing.ArrayLike, warmup_steps: int, transition_steps: int, decay_rate: jax.typing.ArrayLike, transition_begin: int = 0, staircase: bool = False, end_value: jax.typing.ArrayLike | None = None) → base.Schedule

Linear warmup followed by exponential decay.
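Written as a formula, the value at step $t$ can be sketched as below. The symbols are introduced here for illustration only: $v_0$ = init_value, $v_p$ = peak_value, $W$ = warmup_steps, $b$ = transition_begin, $T$ = transition_steps, $r$ = decay_rate. The sketch assumes transition_begin is counted from the end of warmup (which is what lets the value be held at peak_value) and ignores staircase and end_value.

$$
v(t) =
\begin{cases}
v_0 + (v_p - v_0)\,\dfrac{t}{W}, & t < W,\\[4pt]
v_p, & W \le t \le W + b,\\[4pt]
v_p \, r^{(t - W - b)/T}, & t > W + b.
\end{cases}
$$

With staircase=True the exponent becomes $\lfloor (t - W - b)/T \rfloor$, and when end_value is given the decayed value is clipped to it: from below when $r < 1$, from above otherwise.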

Parameters:
  • init_value – Initial value for the scalar to be annealed.

  • peak_value – Peak value for the scalar to be annealed, reached at the end of warmup.

  • warmup_steps – Positive integer, the length of the linear warmup.

  • transition_steps – Must be positive. See optax.schedules.exponential_decay() for more details.

  • decay_rate – The decay rate. Must not be zero.

  • transition_begin – Must be non-negative. Number of steps, counted from the end of warmup, before annealing starts; until then the scalar value is held fixed at peak_value.

  • staircase – If True, decay the values at discrete intervals (once every transition_steps steps) instead of continuously.

  • end_value – The value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound; otherwise, as an upper bound. Has no effect when decay_rate = 0.

Returns:

schedule – A function that maps step counts to values.
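The following is a hypothetical pure-Python re-implementation of the behavior described above, handy for checking expected values without JAX. It is not optax's source: the name warmup_exponential_decay_sketch is invented for illustration, it assumes valid arguments (transition_steps > 0, decay_rate != 0), and it skips any input validation the library performs. Like the real function, it returns a callable mapping a step count to a value.

```python
import math
from typing import Callable, Optional


def warmup_exponential_decay_sketch(
    init_value: float,
    peak_value: float,
    warmup_steps: int,
    transition_steps: int,
    decay_rate: float,
    transition_begin: int = 0,
    staircase: bool = False,
    end_value: Optional[float] = None,
) -> Callable[[int], float]:
    """Hypothetical sketch: linear warmup, then exponential decay."""

    def schedule(step: int) -> float:
        if step < warmup_steps:
            # Linear interpolation from init_value up to peak_value.
            return init_value + (peak_value - init_value) * step / warmup_steps
        # Decay steps are counted from the end of warmup, offset by
        # transition_begin; before that the value stays at peak_value.
        decayed_count = step - warmup_steps - transition_begin
        if decayed_count <= 0:
            return peak_value
        p = decayed_count / transition_steps
        if staircase:
            # Only decay once per full block of transition_steps.
            p = math.floor(p)
        value = peak_value * decay_rate**p
        if end_value is not None:
            # Lower bound when decaying (rate < 1), upper bound otherwise.
            value = max(value, end_value) if decay_rate < 1 else min(value, end_value)
        return value

    return schedule
```

For example, with init_value=0.0, peak_value=1.0, warmup_steps=10, transition_steps=5 and decay_rate=0.5, the sketch gives 0.5 at step 5, 1.0 at step 10, 0.5 at step 15 and 0.25 at step 20.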