optax.skip_large_updates


optax.skip_large_updates(updates: optax.Updates, gradient_step: jax.typing.ArrayLike, params: optax.Params | None, max_squared_norm: jax.typing.ArrayLike) -> tuple[Array, optax.ArrayTree]

Returns True if the squared global norm of updates is too large, that is, if the update should be skipped.
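The check can be sketched in plain JAX as follows. `large_update_check` is a hypothetical helper, not part of optax: it sums the squares of every leaf of the update pytree and flags the update when that total exceeds the threshold. Treating a NaN norm as too large is a choice made in this sketch.

```python
import jax
import jax.numpy as jnp


def large_update_check(updates, max_squared_norm):
    """Hypothetical helper mirroring the documented check (not an optax API)."""
    leaves = jax.tree_util.tree_leaves(updates)
    # Squared global norm: sum of squared entries across every leaf of the pytree.
    norm_squared = sum((jnp.sum(jnp.square(leaf)) for leaf in leaves), jnp.zeros(()))
    # Skip iff the squared norm exceeds the threshold; a NaN norm fails the
    # `<=` comparison, so it is also flagged (a choice made in this sketch).
    should_skip = jnp.logical_not(norm_squared <= max_squared_norm)
    return should_skip, {"should_skip": should_skip, "norm_squared": norm_squared}


# Squared global norm is 3 * 1**2 + 1 * 1**2 = 4.0, which is below 10.0.
skip, info = large_update_check({"w": jnp.ones(3), "b": jnp.ones(1)}, 10.0)
```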

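Parameters:

  • updates (optax.Updates): the updates whose squared global norm is checked.

  • gradient_step (jax.typing.ArrayLike): not used by the check; accepted so the function matches the should_skip_update_fn signature of optax.MultiSteps.

  • params (optax.Params | None): not used by the check; accepted for the same reason.

  • max_squared_norm (jax.typing.ArrayLike): the largest squared norm that is accepted for updates.
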
Returns:

  • First element is a scalar array of type bool (the same value as should_skip).

  • Second element is a dictionary with keys:
      - should_skip: True iff ||updates||^2 is greater than max_squared_norm.
      - norm_squared: the overall squared norm of the updates.

Return type:

A tuple