optax.schedules.linear_schedule

optax.schedules.linear_schedule(init_value: jax.typing.ArrayLike, end_value: jax.typing.ArrayLike, transition_steps: int, transition_begin: int = 0) -> base.Schedule

Schedule with linear transition from init_value to end_value.

More precisely, the learning rate at iteration \(t\) is given by:

\[\begin{cases} I, & \text{if } t < B \\ I + \frac{t - B}{T} (E - I), & \text{if } B \leq t < B + T \\ E, & \text{if } t \geq B + T \end{cases} \]

where \(I\) is the initial value, \(E\) is the end value, \(B\) is the transition begin, and \(T\) is the transition steps.
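The piecewise definition can be sketched in plain Python as a sanity check. This is a minimal illustration of the formula, not the optax implementation; the name linear_schedule_reference is hypothetical, and the early return for transition_steps <= 0 follows the parameter description on this page.

```python
def linear_schedule_reference(t, init_value, end_value,
                              transition_steps, transition_begin=0):
    """Plain-Python sketch of the piecewise formula (not optax's own code)."""
    # A non-positive transition_steps disables annealing entirely.
    if transition_steps <= 0:
        return init_value
    # Case t < B: hold the initial value.
    if t < transition_begin:
        return init_value
    # Case t >= B + T: hold the end value.
    if t >= transition_begin + transition_steps:
        return end_value
    # Case B <= t < B + T: interpolate linearly between I and E.
    fraction = (t - transition_begin) / transition_steps
    return init_value + fraction * (end_value - init_value)


# Print the value before, during, and after the transition.
for step in (0, 50, 100, 150):
    print(step, linear_schedule_reference(step, 1.0, 0.01, 100))
```

Unlike the optax schedule, this sketch works on Python scalars only and is not meant to be traced by JAX.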

This schedule is equivalent to optax.polynomial_schedule() with power=1.
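
As a quick check of this equivalence, assume that optax.polynomial_schedule() takes the following form inside the transition window, with power \(P\) (this form comes from that function's documentation, not from this page):

\[ (I - E) \left(1 - \frac{t - B}{T}\right)^{P} + E \]

Setting \(P = 1\) gives

\[ (I - E) - \frac{t - B}{T} (I - E) + E = I + \frac{t - B}{T} (E - I), \]

which matches the middle case of the definition above.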

Parameters:
  • init_value – initial value for the scalar to be annealed.

  • end_value – end value of the scalar to be annealed.

  • transition_steps – number of steps over which annealing takes place. The scalar starts changing at transition_begin steps and completes the transition by transition_begin + transition_steps steps. If transition_steps <= 0, then the entire annealing process is disabled and the value is held fixed at init_value.

  • transition_begin – number of steps to wait before annealing starts; until then the scalar is held fixed at init_value. Must be non-negative (the default is 0).

Returns:

  • schedule – a function that maps step counts to values.

Examples

>>> import optax
>>> schedule_fn = optax.linear_schedule(
...    init_value=1.0, end_value=0.01, transition_steps=100)
>>> schedule_fn(0)  # learning rate on the first iteration
Array(1., dtype=float32, weak_type=True)
>>> schedule_fn(100)  # learning rate once the transition has completed
Array(0.01, dtype=float32, weak_type=True)