optax.skip_large_updates
- optax.skip_large_updates(updates: optax.Updates, gradient_step: jax.typing.ArrayLike, params: optax.Params | None, max_squared_norm: jax.typing.ArrayLike) → tuple[Array, optax.ArrayTree]
Returns True if the global squared norm of updates is too large, meaning the update should be skipped.
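The check can be sketched in plain JAX. This is a minimal illustration of the documented contract, not the library's source: skip_large_updates_sketch and global_squared_norm are hypothetical names, and the comparison follows the strict "greater than" wording used for should_skip below.

```python
import jax
import jax.numpy as jnp


def global_squared_norm(tree):
    """Sum of the squares of every entry across all leaves of a pytree."""
    leaves = jax.tree_util.tree_leaves(tree)
    return sum(jnp.sum(jnp.square(x)) for x in leaves)


def skip_large_updates_sketch(updates, gradient_step, params, max_squared_norm):
    """Hypothetical re-implementation of the documented behaviour."""
    # gradient_step and params only exist to match the
    # ShouldSkipUpdateFunction call signature; this sketch does not use them.
    del gradient_step, params
    norm_squared = global_squared_norm(updates)
    # Skip when ||updates||^2 exceeds the threshold.
    should_skip = norm_squared > max_squared_norm
    return should_skip, {"should_skip": should_skip, "norm_squared": norm_squared}


# Nested pytree whose global squared norm is 3^2 + 4^2 = 25.
updates = {"a": jnp.array([3.0]), "b": {"c": jnp.array([[4.0]])}}
skip, info = skip_large_updates_sketch(updates, 0, None, 10.0)
print(bool(skip), float(info["norm_squared"]))
```

Because the norm is taken over all leaves at once, a single oversized leaf is enough to reject the whole update.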
- Parameters:
updates – see ShouldSkipUpdateFunction.
gradient_step – see ShouldSkipUpdateFunction.
params – see ShouldSkipUpdateFunction.
max_squared_norm – largest squared global norm of updates that is still accepted.
- Returns:
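First element is a scalar boolean array that is True when the update should be skipped.
Second element is a dictionary with keys:
  - should_skip: True iff ||updates||^2 is greater than max_squared_norm.
  - norm_squared: global squared norm of updates, i.e. the sum of the squares of all entries across all leaves.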
- Return type:
tuple[Array, optax.ArrayTree]