optax.apply_if_finite
- optax.apply_if_finite(inner: optax.GradientTransformation, max_consecutive_errors: int) → optax.GradientTransformation
A function that wraps an optimizer to make it robust to a few NaNs or Infs.
The purpose of this function is to prevent any optimization step from happening when the gradients contain NaNs or Infs. That is, when a NaN or Inf is detected in the gradients, the wrapped optimizer ignores that gradient update. If the NaNs or Infs persist for more than max_consecutive_errors consecutive updates, the wrapped optimizer gives up and accepts the update.
- Parameters:
inner – Inner transformation to be wrapped.
max_consecutive_errors – Maximum number of consecutive gradient updates containing NaNs or Infs that the wrapped optimizer will ignore. After that many ignored updates, the optimizer will give up and accept the update.
- Returns: