optax.skip_not_finite
- optax.skip_not_finite(updates: optax.Updates, gradient_step: jax.typing.ArrayLike, params: optax.Params | None) → tuple[Array, optax.ArrayTree]
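Returns True iff any leaf of the updates pytree contains an inf or a NaN. The gradient_step and params arguments are accepted to match ShouldSkipUpdateFunction but are ignored.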
- Parameters:
updates – see ShouldSkipUpdateFunction.
gradient_step – see ShouldSkipUpdateFunction.
params – see ShouldSkipUpdateFunction.
- Returns:
First element is a scalar array of type bool.
Second element is a dictionary with keys:
  - should_skip: True iff updates contains an inf or a NaN.
  - num_not_finite: total number of inf and NaN entries found across all leaves of updates.
- Return type:
tuple[Array, optax.ArrayTree]
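To make the documented return structure concrete, here is a sketch in plain JAX of how such a check can be computed. The function name skip_not_finite_sketch is hypothetical and it is not the library's source; it only reproduces the behaviour described above: count non-finite entries per leaf, sum over the pytree, and flag the update when the total is positive.

```python
import jax
import jax.numpy as jnp


def skip_not_finite_sketch(updates, gradient_step, params):
    """Hypothetical re-implementation of the documented outputs (illustration only)."""
    del gradient_step, params  # only `updates` is inspected
    # Per leaf: count entries that are not finite (+inf, -inf or NaN).
    per_leaf = [jnp.sum(~jnp.isfinite(leaf)) for leaf in jax.tree_util.tree_leaves(updates)]
    # Total over the whole pytree (0 when there are no leaves).
    num_not_finite = jnp.sum(jnp.array(per_leaf, dtype=jnp.int32))
    should_skip = num_not_finite > 0
    return should_skip, {"should_skip": should_skip, "num_not_finite": num_not_finite}


updates = {"a": jnp.array([0.0, jnp.inf]), "b": jnp.array([[jnp.nan, 1.0], [2.0, 3.0]])}
should_skip, info = skip_not_finite_sketch(updates, gradient_step=0, params=None)
print(bool(should_skip), int(info["num_not_finite"]))  # True 2
```

Returning the count alongside the boolean lets a caller log how many entries were non-finite, not just whether the step was skipped.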