optax.schedules.warmup_exponential_decay_schedule
- optax.schedules.warmup_exponential_decay_schedule(init_value: jax.typing.ArrayLike, peak_value: jax.typing.ArrayLike, warmup_steps: int, transition_steps: int, decay_rate: jax.typing.ArrayLike, transition_begin: int = 0, staircase: bool = False, end_value: jax.typing.ArrayLike | None = None) → base.Schedule
Linear warmup followed by exponential decay.
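The schedule can be written piecewise. This is a sketch that assumes the documented composition of a linear warmup followed by optax.schedules.exponential_decay(), with the decay's step count restarting at the end of the warmup. The symbols are shorthand introduced here: v_0 = init_value, v_peak = peak_value, w = warmup_steps, T = transition_steps, r = decay_rate, b = transition_begin, and t is the step count.

```latex
v(t) =
\begin{cases}
v_0 + (v_{\text{peak}} - v_0)\,\dfrac{t}{w}, & 0 \le t < w,\\[4pt]
v_{\text{peak}}, & w \le t < w + b,\\[4pt]
v_{\text{peak}}\, r^{\,p(t)}, & t \ge w + b,
\end{cases}
\qquad
p(t) = \frac{t - w - b}{T}
```

With staircase=True, p(t) is replaced by its floor. If end_value is given, the value for t ≥ w is clipped from below by end_value when r < 1 and from above otherwise.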
- Parameters:
init_value – Initial value for the scalar to be annealed.
peak_value – Peak value of the scalar, reached at the end of the warmup.
warmup_steps – Positive integer, the length of the linear warmup.
transition_steps – Must be positive. The value is multiplied by decay_rate once every transition_steps steps. See optax.schedules.exponential_decay() for more details.
decay_rate – Must not be zero. The decay rate.
transition_begin – Must be non-negative. Number of steps after the end of the warmup before the decay starts; until then the scalar value is held fixed at peak_value.
staircase – If True, the exponent is floored, so the value decays at discrete intervals of transition_steps steps.
end_value – The value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.
- Returns:
- schedule
A function that maps step counts to schedule values.
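To make the parameter interactions concrete, here is a minimal pure-Python sketch of the same arithmetic. The helper name warmup_exponential_decay_sketch is hypothetical and not part of optax; it assumes non-negative integer steps and returns Python floats rather than JAX arrays. It mirrors the behavior described above (decay step count restarting at the end of the warmup, floored exponent with staircase, end_value clipping) rather than reproducing the library's implementation.

```python
import math


def warmup_exponential_decay_sketch(init_value, peak_value, warmup_steps,
                                    transition_steps, decay_rate,
                                    transition_begin=0, staircase=False,
                                    end_value=None):
    """Hypothetical pure-Python stand-in for the optax schedule (not the optax API)."""

    def schedule(step):
        if step < warmup_steps:
            # Linear warmup from init_value to peak_value over warmup_steps steps.
            return init_value + (peak_value - init_value) * step / warmup_steps
        # Steps elapsed since the end of warmup, minus the hold period.
        decay_step = step - warmup_steps - transition_begin
        if decay_step <= 0:
            # Held at peak_value until transition_begin steps after warmup.
            value = peak_value
        else:
            p = decay_step / transition_steps
            if staircase:
                p = math.floor(p)  # decay in discrete jumps
            value = peak_value * decay_rate ** p
        if end_value is not None:
            # end_value is a lower bound when decaying, an upper bound when growing.
            value = max(value, end_value) if decay_rate < 1 else min(value, end_value)
        return value

    return schedule


schedule = warmup_exponential_decay_sketch(
    init_value=0.0, peak_value=1.0, warmup_steps=10,
    transition_steps=10, decay_rate=0.5, transition_begin=5, end_value=0.1)
print([schedule(s) for s in (0, 5, 10, 15, 25, 35)])
# → [0.0, 0.5, 1.0, 1.0, 0.5, 0.25]
```

In real code one would instead call optax.schedules.warmup_exponential_decay_schedule with the same arguments and pass the returned schedule, for example, as the learning_rate of an optax optimizer. Comparing its outputs with a sketch like this at a few steps is a quick way to check one's reading of the parameters.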