optax.contrib.reduce_on_plateau
- optax.contrib.reduce_on_plateau(factor: float = 0.1, patience: jax.typing.ArrayLike = 10, rtol: float = 0.0001, atol: float = 0.0, cooldown: jax.typing.ArrayLike = 0, accumulation_size: jax.typing.ArrayLike = 1, min_scale: jax.typing.ArrayLike = 0.0) → base.GradientTransformationExtraArgs
Reduce learning rate when a metric has stopped improving.
Models often benefit from reducing the learning rate once learning stagnates. This scheduler reads a metric quantity and, if no improvement is seen for a number of iterations given by patience, reduces the learning rate.
- Parameters:
factor – Factor by which to reduce the learning rate: new_scale = scale * factor.
patience – Number of iterations with no improvement after which the learning rate will be reduced.
rtol – Relative tolerance used to decide whether a new value counts as a new optimum.
atol – Absolute tolerance used to decide whether a new value counts as a new optimum.
cooldown – Number of iterations to wait before resuming normal operation after the scale has been reduced.
accumulation_size – Number of values to aggregate before applying the reduce-on-plateau logic. If the value fed to the optimizer is a test value, the default of 1 is appropriate. If the value fed to the optimizer is the loss on the current minibatch, consider using a larger accumulation size.
min_scale – Scale at which the learning rate decay stops; the scale is never reduced below this value.
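To make the interaction of the parameters concrete, here is a minimal pure-Python sketch of the plateau logic they describe. It is not the optax implementation: `PlateauState` and `plateau_step` are hypothetical names, and the improvement test `avg < (1 - rtol) * best - atol`, the averaging of accumulated values, and the cooldown bookkeeping are assumptions that may differ in detail from the library.

```python
import math
from dataclasses import dataclass


@dataclass
class PlateauState:
    # Hypothetical state for this sketch; not optax's own state class.
    scale: float = 1.0
    best_value: float = math.inf
    plateau_count: int = 0
    cooldown_count: int = 0
    count: int = 0
    avg_value: float = 0.0


def plateau_step(state, value, *, factor=0.1, patience=10, rtol=1e-4,
                 atol=0.0, cooldown=0, accumulation_size=1, min_scale=0.0):
    """Feed one metric value and return the updated state (sketch only)."""
    # Running mean of the values received since the last scale decision.
    state.count += 1
    state.avg_value += (value - state.avg_value) / state.count
    if state.count < accumulation_size:
        return state  # still accumulating, so no decision yet

    avg = state.avg_value
    state.count, state.avg_value = 0, 0.0

    # Assumed improvement test combining relative and absolute tolerance.
    if avg < (1 - rtol) * state.best_value - atol:
        state.best_value = avg
        state.plateau_count = 0
    else:
        state.plateau_count += 1

    if state.cooldown_count > 0:
        # During cooldown, non-improving values are not counted.
        state.cooldown_count -= 1
        state.plateau_count = 0
    elif state.plateau_count >= patience:
        # Plateau detected: shrink the scale, but never below min_scale.
        state.scale = max(state.scale * factor, min_scale)
        state.plateau_count = 0
        state.cooldown_count = cooldown
    return state


state = PlateauState()
for loss in [1.0, 0.5, 0.5, 0.5, 0.5]:
    state = plateau_step(state, loss, factor=0.5, patience=3)
print(state.scale)  # 0.5
```

In this trace the loss improves twice and then stays flat for three consecutive iterations, which reaches patience=3 and halves the scale.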
- Returns:
An optax.GradientTransformationExtraArgs object.
See also