optax.schedules.exponential_decay

optax.schedules.exponential_decay(init_value: jax.typing.ArrayLike, transition_steps: int, decay_rate: float, transition_begin: int = 0, staircase: bool = False, end_value: jax.typing.ArrayLike | None = None) -> base.Schedule

Constructs a schedule with either continuous or discrete exponential decay.

This function applies an exponential decay function to a provided initial value. When count >= transition_begin the function returns the decayed value as:

rate_factor = ((count - transition_begin) / transition_steps)
decayed_value = init_value * (decay_rate ** rate_factor)

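If the argument staircase is True, rate_factor is computed with integer division, that is, (count - transition_begin) / transition_steps is rounded down to a whole number, so the decayed value follows a staircase function that changes only once every transition_steps steps.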
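The formula above can be sketched in plain Python to make the difference between continuous and staircase decay concrete. The helper name exponential_decay_sketch is hypothetical and is not part of optax; it only mirrors the documented computation and ignores end_value.

```python
import math


def exponential_decay_sketch(count, init_value, transition_steps, decay_rate,
                             transition_begin=0, staircase=False):
    """Hypothetical reference helper mirroring the documented formula (not optax code)."""
    elapsed = count - transition_begin
    if elapsed <= 0:
        # Before the transition begins, the value is held at init_value.
        return init_value
    rate_factor = elapsed / transition_steps
    if staircase:
        # Integer division: the exponent only changes every transition_steps steps.
        rate_factor = math.floor(rate_factor)
    return init_value * decay_rate ** rate_factor


# With decay_rate=0.5 and transition_steps=10, the value halves every 10 steps.
continuous = exponential_decay_sketch(15, 1.0, 10, 0.5)
stepped = exponential_decay_sketch(15, 1.0, 10, 0.5, staircase=True)
print(continuous, stepped)
```

At step 15 the continuous schedule has already dropped below 0.5, while the staircase schedule stays at 0.5 until step 20.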

Parameters:
  • init_value – the initial learning rate.

  • transition_steps – must be positive. See the decay computation above.

  • decay_rate – must not be zero. The decay rate.

  • transition_begin – must be non-negative (the default is 0). The number of steps to wait before annealing starts; until then the value is held fixed at init_value.

  • staircase – if True, decay the values at discrete intervals.

  • end_value – the value at which the exponential decay stops. When decay_rate < 1, end_value is treated as a lower bound, otherwise as an upper bound. Has no effect when decay_rate = 0.
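The end_value rule can be illustrated with a small plain-Python sketch. The function clip_to_end_value is a hypothetical name introduced here for illustration, and the sketch assumes the bound is applied to the decayed value after the exponential is computed.

```python
def clip_to_end_value(decayed_value, decay_rate, end_value=None):
    """Hypothetical helper showing how end_value bounds the decayed value."""
    if end_value is None:
        # No bound requested: the decayed value is returned unchanged.
        return decayed_value
    if decay_rate < 1:
        # Decaying schedule: end_value acts as a lower bound (a floor).
        return max(decayed_value, end_value)
    # Growing schedule (decay_rate >= 1): end_value acts as an upper bound.
    return min(decayed_value, end_value)


print(clip_to_end_value(0.05, decay_rate=0.5, end_value=0.1))  # floor applies
print(clip_to_end_value(8.0, decay_rate=2.0, end_value=4.0))   # ceiling applies
```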

Returns:

schedule

A function that maps step counts to values.