optax.schedules.piecewise_constant_schedule


optax.schedules.piecewise_constant_schedule(init_value: jax.typing.ArrayLike, boundaries_and_scales: dict[int, float] | None = None) -> base.Schedule

Piecewise constant schedule with scaled jumps at specific boundaries.

At each step t, this schedule returns init_value scaled by the product of all factors f_i such that t >= b_i, where (b_i, f_i) are the entries in boundaries_and_scales.
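The rule above can be restated in a few lines of plain Python. This is a minimal sketch for illustration only, not optax's own implementation; the helper name `piecewise_constant_value` is hypothetical.

```python
def piecewise_constant_value(t, init_value, boundaries_and_scales=None):
    """Hypothetical helper: init_value times every scale f_i whose boundary b_i <= t."""
    value = init_value
    for boundary, scale in (boundaries_and_scales or {}).items():
        if t >= boundary:  # this boundary has been reached, so its factor applies
            value *= scale
    return value


bands = {100: 0.1, 200: 0.01}
print(piecewise_constant_value(99, 1.0, bands))   # 1.0
print(piecewise_constant_value(100, 1.0, bands))  # 0.1 (the boundary step itself counts)
print(piecewise_constant_value(250, 1.0, bands))  # approximately 0.001 (0.1 * 0.01)
```

Because the factors are multiplied together, the order in which the boundaries are visited does not change the result beyond floating-point rounding.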

Parameters:
  • init_value – The starting value of the schedule, used until the first boundary is reached.

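  • boundaries_and_scales – Optional dictionary {step: scale}. From the given step onward, the schedule value is multiplied by scale, so the scales of successive boundaries compound. All scale values must be non-negative. If None, the schedule returns init_value at every step.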
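Since each scale multiplies the value already in effect, the dictionary holds relative factors rather than absolute values. If you know the absolute value you want after each boundary, one way to derive the scales is to divide each target by the previous one. The helper `scales_from_targets` below is a hypothetical sketch, not part of optax, and assumes all values are positive.

```python
def scales_from_targets(init_value, targets):
    """Hypothetical helper: turn {step: absolute_value} into {step: relative_scale}.

    Assumes init_value and every target are positive, so each ratio is well
    defined and non-negative, as boundaries_and_scales requires.
    """
    scales = {}
    previous = init_value
    for step in sorted(targets):
        scales[step] = targets[step] / previous  # ratio to the value in effect before this step
        previous = targets[step]
    return scales


# Start at 1e-3, drop to 5e-4 at step 1000 and to 1e-4 at step 2000.
scales = scales_from_targets(1e-3, {1000: 5e-4, 2000: 1e-4})
print({step: round(s, 6) for step, s in scales.items()})  # {1000: 0.5, 2000: 0.2}
```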

Returns:

A function that maps step index to schedule value.
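The returned function takes the current step and returns the value for that step; the caller is responsible for passing the step. A minimal sketch of how such a callable is typically consumed, assuming a toy gradient-descent loop on f(x) = x**2. The `lr_schedule` function is a hypothetical stand-in for a piecewise constant schedule, written in plain Python rather than with optax.

```python
# Hypothetical stand-in for a piecewise constant schedule:
# 0.1 before step 5, then 0.1 * 0.5 = 0.05 from step 5 onward.
def lr_schedule(step):
    return 0.1 * (0.5 if step >= 5 else 1.0)


def grad(x):
    return 2.0 * x  # derivative of f(x) = x**2


x = 1.0
lrs = []
for step in range(10):
    lr = lr_schedule(step)  # query the schedule with the current step index
    lrs.append(lr)
    x -= lr * grad(x)

print(lrs)  # [0.1, 0.1, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05, 0.05, 0.05]
print(x)    # about 0.19
```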

Example

>>> import optax
>>> sched = optax.piecewise_constant_schedule(
...     init_value=1.0, boundaries_and_scales={100: 0.1, 200: 0.01})
>>> print(sched(50))   # before first boundary
1.0
>>> print(sched(150))  # after first boundary
0.1
>>> print(sched(250))  # after second boundary: 0.1 * 0.01
0.001
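Because the value only changes at the boundaries, an equivalent formulation precomputes the value on each interval [b_k, b_{k+1}) and finds the interval containing t with a binary search from the standard `bisect` module. This is a plain-Python sketch of that alternative technique, not how optax evaluates the schedule; `make_lookup_schedule` is a hypothetical name. It reproduces the example values above and also checks that the boundary step itself already uses the new value.

```python
import bisect


def make_lookup_schedule(init_value, boundaries_and_scales=None):
    """Hypothetical sketch: precompute the value on each interval [b_k, b_{k+1})."""
    items = sorted((boundaries_and_scales or {}).items())
    if any(scale < 0 for _, scale in items):
        raise ValueError("scales must be non-negative")
    boundaries = [b for b, _ in items]
    values = [init_value]  # values[k] applies once k boundaries have been reached
    for _, scale in items:
        values.append(values[-1] * scale)

    def schedule(t):
        # bisect_right counts the boundaries b with b <= t, i.e. t >= b.
        return values[bisect.bisect_right(boundaries, t)]

    return schedule


sched = make_lookup_schedule(1.0, {100: 0.1, 200: 0.01})
print(sched(50), sched(100), sched(150), sched(250))
```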