optax.schedules.piecewise_constant_schedule
- optax.schedules.piecewise_constant_schedule(init_value: jax.typing.ArrayLike, boundaries_and_scales: dict[int, float] | None = None) → base.Schedule
Piecewise constant schedule with scaled jumps at specific boundaries.
At each step t, this schedule returns init_value scaled by the product of all factors f_i such that t >= b_i, where (b_i, f_i) are the entries in boundaries_and_scales.
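The rule above can be worked through by hand. The following is a minimal pure-Python sketch of that formula, not optax's implementation: `piecewise_constant_value` is a hypothetical helper that operates on plain Python numbers rather than JAX arrays, and it multiplies `init_value` by every scale whose boundary is less than or equal to the step.

```python
from typing import Mapping, Optional


def piecewise_constant_value(
    step: int,
    init_value: float,
    boundaries_and_scales: Optional[Mapping[int, float]] = None,
) -> float:
    """Hypothetical helper: init_value times every scale whose boundary <= step."""
    value = init_value
    if boundaries_and_scales is not None:
        # Sorting is not needed for a product; it just visits boundaries in step order.
        for boundary, scale in sorted(boundaries_and_scales.items()):
            if step >= boundary:  # the jump already applies at the boundary step
                value *= scale
    return value


scales = {100: 0.1, 200: 0.01}
print(piecewise_constant_value(50, 1.0, scales))   # no boundary reached
print(piecewise_constant_value(100, 1.0, scales))  # exactly at the first boundary
print(piecewise_constant_value(250, 1.0, scales))  # both boundaries passed: 1.0 * 0.1 * 0.01
```

Because the condition is `t >= b_i`, a step that lands exactly on a boundary already sees that boundary's scale.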
- Parameters:
init_value – The starting value of the schedule.
boundaries_and_scales – Dictionary of {step: scale}; each scale is multiplied into the schedule value from the given step onward, including that step itself. All scale values must be non-negative. If None (the default), the schedule stays at init_value.
- Returns:
A function that maps step index to schedule value.
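The returned object is a callable that closes over the configuration. As a rough sketch under stated assumptions, a plain-Python factory with the same shape might look like this; `make_piecewise_constant_schedule` is a hypothetical name, it works on Python numbers rather than JAX arrays, and it applies the non-negativity check described for `boundaries_and_scales`.

```python
from typing import Callable, Mapping, Optional


def make_piecewise_constant_schedule(
    init_value: float,
    boundaries_and_scales: Optional[Mapping[int, float]] = None,
) -> Callable[[int], float]:
    """Hypothetical pure-Python stand-in that returns a step -> value function."""
    items = sorted((boundaries_and_scales or {}).items())
    # Enforce the documented constraint that every scale is non-negative.
    if any(scale < 0 for _, scale in items):
        raise ValueError("scale factors must be non-negative")

    def schedule(step: int) -> float:
        value = init_value
        for boundary, scale in items:
            if step < boundary:
                break  # boundaries are sorted, so no later jump applies yet
            value *= scale
        return value

    return schedule


sched = make_piecewise_constant_schedule(1.0, {100: 0.1, 200: 0.01})
print([sched(t) for t in (0, 99, 100, 199, 200)])
```

This sketch branches on the step with a Python `if`, which only works for concrete numbers. Inside `jax.jit` the step is a traced value, so a JAX-compatible schedule would need array operations (for example `jnp.where`) instead of Python control flow.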
Example
>>> import optax
>>> sched = optax.piecewise_constant_schedule(
...     init_value=1.0, boundaries_and_scales={100: 0.1, 200: 0.01})
>>> print(sched(50))  # before first boundary
1.0
>>> print(sched(150))  # after first boundary
0.1
>>> print(sched(250))  # after second boundary
0.001