optax.apply_if_finite

optax.apply_if_finite(inner: optax.GradientTransformation, max_consecutive_errors: int) -> optax.GradientTransformation

A function that wraps an optimizer to make it robust to a few NaNs or Infs.

The purpose of this function is to prevent any optimization step from happening when the gradients contain NaNs or Infs. That is, when a NaN or Inf is detected in the gradients, the wrapped optimizer ignores that gradient update. If the NaNs or Infs persist after max_consecutive_errors consecutive updates have been ignored, the wrapped optimizer gives up and accepts the update.
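The skip-then-give-up rule described above can be sketched in plain Python. This is only an illustration of the documented behaviour, not the library's implementation: the helper `should_apply` and its counter are hypothetical names, and resetting the counter on a finite gradient is an assumption inferred from the word "consecutive".

```python
import math


def should_apply(grads, notfinite_count, max_consecutive_errors):
    """Hypothetical helper: decide whether one update is applied.

    Returns (apply_update, new_count), where new_count is the number of
    consecutive non-finite gradient updates seen so far.
    """
    is_finite = all(math.isfinite(g) for g in grads)
    # Assumption: a finite gradient resets the consecutive-error counter.
    new_count = 0 if is_finite else notfinite_count + 1
    # Apply when the gradient is finite, or when more than
    # max_consecutive_errors bad updates in a row have been seen.
    apply_update = is_finite or new_count > max_consecutive_errors
    return apply_update, new_count


count = 0
decisions = []
steps = [[1.0], [float("nan")], [float("inf")], [float("nan")], [0.5]]
for grads in steps:
    apply_update, count = should_apply(grads, count, max_consecutive_errors=2)
    decisions.append(apply_update)
```

With max_consecutive_errors=2, the first two non-finite updates are skipped and the third consecutive one is accepted.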

Parameters:
  • inner – Inner transformation to be wrapped.

  • max_consecutive_errors – Maximum number of consecutive gradient updates containing NaNs or Infs that the wrapped optimizer will ignore. After that many ignored updates, the optimizer will give up and accept the update.

Returns:

New optax.GradientTransformationExtraArgs.