optax.schedules.exponential_decay
- optax.schedules.exponential_decay(init_value: jax.typing.ArrayLike, transition_steps: int, decay_rate: float, transition_begin: int = 0, staircase: bool = False, end_value: jax.typing.ArrayLike | None = None) → base.Schedule
Constructs a schedule with either continuous or discrete exponential decay.
This function applies an exponential decay function to a provided initial value. When count >= transition_begin, the function returns the decayed value as:
    rate_factor = (count - transition_begin) / transition_steps
    decayed_value = init_value * (decay_rate ** rate_factor)
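As a quick check of the arithmetic, the formula can be restated in plain Python. This is only an illustrative sketch: the helper name exponential_decay_sketch is hypothetical, it is not the library's implementation, and it covers only the continuous case without end_value.

```python
def exponential_decay_sketch(count, init_value, transition_steps,
                             decay_rate, transition_begin=0):
    """Plain-Python restatement of the continuous decay formula (illustrative only)."""
    # Before the transition begins, the value is held at init_value.
    if count < transition_begin:
        return init_value
    rate_factor = (count - transition_begin) / transition_steps
    return init_value * (decay_rate ** rate_factor)


# Halve the value every 10 steps, starting the decay at step 5.
values = [
    exponential_decay_sketch(c, init_value=1.0, transition_steps=10,
                             decay_rate=0.5, transition_begin=5)
    for c in (0, 5, 15, 25)
]
print(values)
```

With these settings the value stays at init_value through step 5, then is multiplied by decay_rate once for every transition_steps steps that elapse.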
If the argument staircase is True, then rate_factor is rounded down to an integer (integer division by transition_steps) and the decayed value follows a staircase function.
- Parameters:
init_value – the initial learning rate.
transition_steps – must be positive. See the decay computation above.
decay_rate – must not be zero. The decay rate.
transition_begin – must be non-negative. After how many steps to start annealing (before this many steps the scalar value is held fixed at init_value).
staircase – if True, decay the values at discrete intervals.
end_value – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.
- Returns:
- schedule
A function that maps step counts to values.