optax.ScaleByScheduleState

class optax.ScaleByScheduleState(count: jax.typing.ArrayLike)

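Maintains the step count for scale scheduling. `count` is the number of updates applied so far; `optax.scale_by_schedule` passes it to its schedule function to compute the current scale.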