optax.schedules.linear_schedule
- optax.schedules.linear_schedule(init_value: jax.typing.ArrayLike, end_value: jax.typing.ArrayLike, transition_steps: int, transition_begin: int = 0) → base.Schedule
Schedule with a linear transition from init_value to end_value. More precisely, the learning rate at iteration \(t\) is given by:
\[\begin{cases} I, & \text{if } t < B \\ I + \frac{t - B}{T} (E - I), & \text{if } B \leq t < B + T \\ E, & \text{if } t \geq B + T \end{cases} \]
where \(I\) is the initial value (init_value), \(E\) is the end value (end_value), \(B\) is the transition begin (transition_begin), and \(T\) is the number of transition steps (transition_steps).
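The cases formula can be transcribed directly into a few lines of plain Python. The sketch below is not optax's implementation, which operates on JAX arrays. The helper name linear_schedule_value is hypothetical, and the sketch assumes \(T > 0\) (the \(T \leq 0\) case is covered under Parameters).

```python
def linear_schedule_value(t, init_value, end_value, transition_steps, transition_begin=0):
    """Hypothetical helper: the piecewise-linear formula above, evaluated at step t.

    I = init_value, E = end_value, B = transition_begin, T = transition_steps.
    Assumes T > 0.
    """
    I, E, B, T = init_value, end_value, transition_begin, transition_steps
    if t < B:
        return I  # annealing has not started yet: hold at I
    if t < B + T:
        return I + (t - B) / T * (E - I)  # linear interpolation from I toward E
    return E  # transition complete: hold at E


# Worked example: I = 1.0, E = 0.0, B = 10, T = 100.
for step in (0, 10, 60, 110, 500):
    print(step, linear_schedule_value(step, 1.0, 0.0, 100, transition_begin=10))
# 0 1.0
# 10 1.0
# 60 0.5
# 110 0.0
# 500 0.0
```

At step 60, the value is halfway between \(I\) and \(E\) because \((60 - 10)/100 = 0.5\). Steps 110 and beyond fall in the last case and return \(E\).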
This schedule is equivalent to optax.polynomial_schedule() with power=1.
- Parameters:
init_value – initial value for the scalar to be annealed.
end_value – end value of the scalar to be annealed.
transition_steps – number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, the entire annealing process is disabled and the value is held fixed at init_value.
transition_begin – number of steps after which annealing starts; before this many steps, the scalar value is held fixed at init_value. Must be non-negative.
- Returns:
schedule – A function that maps step counts to values.
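Because the result is a function from step counts to values, the parameter semantics described above can be illustrated with a closure. The sketch below uses a hypothetical pure-Python stand-in, make_linear_schedule, that is not optax's code. It assumes non-negative transition_begin and short-circuits to a constant when transition_steps <= 0, as the Parameters entry states.

```python
from typing import Callable


def make_linear_schedule(
    init_value: float,
    end_value: float,
    transition_steps: int,
    transition_begin: int = 0,
) -> Callable[[int], float]:
    """Hypothetical stand-in that returns a step -> value function."""
    if transition_steps <= 0:
        # Annealing disabled: every step returns init_value.
        return lambda count: init_value

    def schedule(count: int) -> float:
        # Steps elapsed since annealing began, clipped to [0, transition_steps].
        elapsed = min(max(count - transition_begin, 0), transition_steps)
        return init_value + (end_value - init_value) * elapsed / transition_steps

    return schedule


# Hold at 0.1 for 100 steps, then decay linearly to 0.0 over the next 50 steps.
hold_then_decay = make_linear_schedule(0.1, 0.0, transition_steps=50, transition_begin=100)
print([hold_then_decay(step) for step in (0, 100, 125, 150, 1000)])

# transition_steps <= 0 disables annealing entirely.
disabled = make_linear_schedule(0.1, 0.0, transition_steps=0)
print(disabled(0), disabled(10_000))
```

Clipping the elapsed step count, rather than branching on it, covers all three cases of the formula with one expression. In a JAX version, the min/max pair would become jnp.clip, which avoids Python control flow on traced values.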
Examples
>>> import optax
>>> schedule_fn = optax.linear_schedule(
...     init_value=1.0, end_value=0.01, transition_steps=100)
>>> schedule_fn(0)  # value on the first iteration
Array(1., dtype=float32, weak_type=True)
>>> schedule_fn(100)  # value once the transition is complete
Array(0.01, dtype=float32, weak_type=True)
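The equivalence with optax.polynomial_schedule() at power=1 can be checked numerically. This sketch assumes that the polynomial schedule evaluates (init_value - end_value) * (1 - count / transition_steps) ** power + end_value, where count is the step minus transition_begin, clipped to [0, transition_steps]. Both helpers below are hypothetical pure-Python stand-ins, not optax functions.

```python
def polynomial_value(count, init_value, end_value, power, transition_steps, transition_begin=0):
    # Assumed polynomial form: clip elapsed steps into [0, transition_steps].
    count = min(max(count - transition_begin, 0), transition_steps)
    frac = 1 - count / transition_steps
    return (init_value - end_value) * frac**power + end_value


def linear_value(t, init_value, end_value, transition_steps, transition_begin=0):
    # The piecewise-linear cases formula from the top of this page.
    if t < transition_begin:
        return init_value
    if t < transition_begin + transition_steps:
        return init_value + (t - transition_begin) / transition_steps * (end_value - init_value)
    return end_value


# With power=1 the two agree at every step, up to floating-point rounding.
I, E, B, T = 1.0, 0.01, 5, 20
max_gap = max(
    abs(polynomial_value(t, I, E, 1, T, B) - linear_value(t, I, E, T, B))
    for t in range(40)
)
print(max_gap < 1e-12)  # True
```

Algebraically, with power=1 the polynomial form expands to \((I - E)(1 - c/T) + E = I + (E - I)\,c/T\), which is the middle case of the formula. Clipping \(c\) to \([0, T]\) reproduces the two constant cases.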