optax.contrib.reduce_on_plateau


optax.contrib.reduce_on_plateau(factor: float = 0.1, patience: jax.typing.ArrayLike = 10, rtol: float = 0.0001, atol: float = 0.0, cooldown: jax.typing.ArrayLike = 0, accumulation_size: jax.typing.ArrayLike = 1, min_scale: jax.typing.ArrayLike = 0.0) → base.GradientTransformationExtraArgs

Reduce learning rate when a metric has stopped improving.

Models often benefit from reducing the learning rate once learning stagnates. This scheduler reads a metric value (for example, a validation loss) and, if the metric does not improve for patience consecutive evaluations, multiplies the learning-rate scale by factor.
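The plateau bookkeeping can be sketched in plain Python. This is a simplified illustration, not optax's source: the class PlateauScaler is hypothetical, it consumes one already-aggregated metric value per call, and it assumes that "improvement" means value < (1 - rtol) * best - atol and that non-improving evaluations are ignored during cooldown.

```python
import math


class PlateauScaler:
    """Hypothetical sketch of reduce-on-plateau bookkeeping (not optax code).

    Shows how factor, patience, rtol/atol, cooldown and min_scale interact
    when the tracker is fed one metric value per evaluation.
    """

    def __init__(self, factor=0.1, patience=10, rtol=1e-4, atol=0.0,
                 cooldown=0, min_scale=0.0):
        self.factor = factor
        self.patience = patience
        self.rtol = rtol
        self.atol = atol
        self.cooldown = cooldown
        self.min_scale = min_scale
        self.scale = 1.0         # multiplier that would be applied to updates
        self.best = math.inf     # best metric value seen so far
        self.plateau_count = 0   # consecutive evaluations without improvement
        self.cooldown_count = 0  # evaluations left before counting resumes

    def step(self, value):
        # Assumed criterion: beat the best value by both tolerance margins.
        if value < (1 - self.rtol) * self.best - self.atol:
            self.best = value
            self.plateau_count = 0
        else:
            self.plateau_count += 1

        if self.cooldown_count > 0:
            # In cooldown: count down and discard non-improving evaluations.
            self.cooldown_count -= 1
            self.plateau_count = 0
        elif self.plateau_count == self.patience:
            # Plateau detected: shrink the scale, but never below min_scale.
            self.scale = max(self.scale * self.factor, self.min_scale)
            self.plateau_count = 0
            self.cooldown_count = self.cooldown
        return self.scale


scaler = PlateauScaler(factor=0.5, patience=2)
print([scaler.step(v) for v in [1.0, 0.9, 0.9, 0.9, 0.9, 0.9]])
# [1.0, 1.0, 1.0, 0.5, 0.5, 0.25]
```

In the usage line, 0.9 stops counting as an improvement once it is the best value, so every second repeat (patience=2) halves the scale.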

Parameters:
  • factor – Factor by which to reduce the learning rate: new_scale = scale * factor.

  • patience – Number of iterations with no improvement after which the learning rate will be reduced.

  • rtol – Relative tolerance for measuring a new optimum.

  • atol – Absolute tolerance for measuring a new optimum.

  • cooldown – Number of iterations to wait after the scale has been reduced before resuming normal operation.

  • accumulation_size – Number of values to aggregate before applying the reduce-on-plateau logic. If the value fed to the optimizer is a validation or test metric, keep the default of 1. If it is the loss on the current minibatch, consider a larger accumulation size so that the plateau check sees a less noisy signal.

  • min_scale – Lower bound on the scale; once it is reached, the learning rate is no longer reduced.
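With accumulation_size greater than 1, the plateau check runs on an aggregate of several consecutive values rather than on each value individually, which damps the noise of per-minibatch losses. A minimal sketch, assuming the aggregate is a plain mean over non-overlapping windows; accumulate_means is a hypothetical helper, not an optax function:

```python
def accumulate_means(values, accumulation_size):
    """Hypothetical helper: yield the mean of each complete, non-overlapping
    window of `accumulation_size` consecutive values."""
    total, count = 0.0, 0
    for v in values:
        total += v
        count += 1
        if count == accumulation_size:
            # One aggregated value = one evaluation for the plateau logic.
            yield total / count
            total, count = 0.0, 0
    # Values in an incomplete trailing window are not emitted.


print(list(accumulate_means([4.0, 2.0, 3.0, 1.0, 5.0], accumulation_size=2)))
# [3.0, 2.0]
```

Under this assumption, patience and cooldown effectively count aggregated evaluations, so with accumulation_size=2 a patience of 10 spans 20 raw values.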

Returns:

An optax.GradientTransformationExtraArgs object whose update function expects the monitored metric as the value keyword argument.