optax.schedules.polynomial_schedule
- optax.schedules.polynomial_schedule(init_value: jax.typing.ArrayLike, end_value: jax.typing.ArrayLike, power: jax.typing.ArrayLike, transition_steps: int, transition_begin: int = 0) → base.Schedule
Constructs a schedule with polynomial transition from init to end value.
This function transitions the learning rate from an initial value (init_value) to a final value (end_value) over a specified number of steps (transition_steps), using a polynomial function whose power is given by power. The transition can optionally begin after a specified number of initial steps (transition_begin).
More precisely, the learning rate at iteration \(t\) is given by:
\[
\begin{cases}
I, & \text{if } t < B \\
(I - E) \left( 1 - \frac{t - B}{T} \right)^{P} + E, & \text{if } B \leq t < B + T \\
E, & \text{if } t \geq B + T
\end{cases}
\]
where \(I\) is the initial value, \(E\) is the end value, \(B\) is the transition begin, \(T\) is the transition steps, and \(P\) is the power used for the polynomial transition.
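The piecewise definition above can be sketched in plain Python to make the three cases explicit. This is only an illustrative reference, not the library's implementation: the helper name polynomial_value is hypothetical, and the sketch assumes transition_steps > 0.

```python
def polynomial_value(t, init_value, end_value, power,
                     transition_steps, transition_begin=0):
    """Evaluate the piecewise formula at step t (hypothetical helper).

    Assumes transition_steps > 0; the non-positive case is not handled here.
    """
    if t < transition_begin:
        # Before the transition begins: hold the initial value I.
        return init_value
    if t >= transition_begin + transition_steps:
        # After the transition ends: hold the end value E.
        return end_value
    # During the transition: (I - E) * (1 - (t - B) / T) ** P + E.
    frac = 1 - (t - transition_begin) / transition_steps
    return (init_value - end_value) * frac ** power + end_value


# I = 1.0, E = 0.01, P = 2, T = 100, B = 5
for t in [0, 5, 6, 104, 105, 110]:
    print(f'count:{t} value:{polynomial_value(t, 1.0, 0.01, 2, 100, 5):.4f}')
```

For example, at t = 6 the remaining fraction is 1 - 1/100 = 0.99, so the value is 0.99 * 0.99^2 + 0.01, or about 0.9803.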
- Parameters:
init_value – initial value for the scalar to be annealed.
end_value – end value of the scalar to be annealed.
power – the power of the polynomial used to transition from init to end.
transition_steps – number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.
transition_begin – must be non-negative (the default is 0). Number of steps after which annealing starts; before this many steps the scalar value is held fixed at init_value.
- Returns:
- schedule
A function that maps step counts to values.
Examples
>>> import optax
>>> schedule_fn = optax.polynomial_schedule(
...    init_value=1.0, end_value=0.01, transition_steps=100, power=2)
>>> schedule_fn(0)  # learning rate on the first iteration
Array(1., dtype=float32, weak_type=True)
>>> schedule_fn(100)  # learning rate on the last iteration
Array(0.01, dtype=float32, weak_type=True)
The following example uses a non-zero transition_begin. In this case the learning rate is kept constant for the first transition_begin iterations:
>>> schedule_fn = optax.polynomial_schedule(
...    init_value=1.0,
...    end_value=0.01,
...    transition_steps=100,
...    transition_begin=5,
...    power=2,
... )
>>> counts = [0, 5, 6, 104, 105, 110]
>>> print(
...    *[f'count:{i} value:{schedule_fn(i):.4f}' for i in counts],
...    sep='\n')
count:0 value:1.0000
count:5 value:1.0000
count:6 value:0.9803
count:104 value:0.0101
count:105 value:0.0100
count:110 value:0.0100