optax.skip_not_finite


optax.skip_not_finite(updates: optax.Updates, gradient_step: jax.typing.ArrayLike, params: optax.Params | None) -> tuple[Array, optax.ArrayTree]

Checks whether any leaf of updates contains an inf or a NaN. The first element of the returned tuple is True iff it does.
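To make the check concrete, here is a minimal sketch that re-implements it with plain JAX, assuming only the behavior described above: count the non-finite entries (inf, -inf, NaN) across every leaf of the updates pytree and skip if the count is positive. The name skip_not_finite_sketch is hypothetical and not part of optax; in real code, call optax.skip_not_finite itself.

```python
import jax
import jax.numpy as jnp


def skip_not_finite_sketch(updates, gradient_step=None, params=None):
    """Hypothetical re-implementation of the inf/NaN check (illustration only)."""
    del gradient_step, params  # not needed to decide whether to skip
    # Number of non-finite entries in each leaf of the pytree.
    per_leaf = [jnp.sum(~jnp.isfinite(leaf))
                for leaf in jax.tree_util.tree_leaves(updates)]
    num_not_finite = jnp.sum(jnp.array(per_leaf))
    should_skip = num_not_finite > 0  # scalar bool array
    return should_skip, {"should_skip": should_skip,
                         "num_not_finite": num_not_finite}


updates = {"w": jnp.array([1.0, jnp.nan, jnp.inf]), "b": jnp.array(0.5)}
should_skip, info = skip_not_finite_sketch(updates)
print(bool(should_skip), int(info["num_not_finite"]))  # True 2
```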

Parameters:
  • updates – see ShouldSkipUpdateFunction.

  • gradient_step – see ShouldSkipUpdateFunction. Not used by this function.

  • params – see ShouldSkipUpdateFunction. Not used by this function.

Returns:

  • First element is a scalar array of type bool.

  • Second element is a dictionary with keys:

    - should_skip: True iff updates contains an inf or a NaN.

    - num_not_finite: total number of inf and NaN values found in updates.

Return type:

tuple[Array, optax.ArrayTree]