optax.per_example_global_norm_clip
- optax.per_example_global_norm_clip(grads: optax.ArrayTree, l2_norm_clip: jax.typing.ArrayLike) -> tuple[optax.ArrayTree, Array]
Clips each example's gradients to a maximum global L2 norm, then sums the clipped gradients over the batch.
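Written out as formulas, the computation looks roughly as follows. This is a sketch inferred from the parameter and return descriptions and from the example below; the symbols are introduced here for illustration only. For example i in a batch of size B, let g_i^(k) be its slice of the k-th array in grads, and let C = l2_norm_clip. A zero-norm example is taken to have scale 1, and the count uses a strict inequality, which is consistent with the example (at C = 0 the all-zero example is not counted).

```latex
n_i = \sqrt{\sum_{k} \lVert g_i^{(k)} \rVert_2^2}, \qquad
s_i = \min\!\left(1, \frac{C}{n_i}\right) \quad (s_i = 1 \text{ if } n_i = 0),
\\[4pt]
\text{clipped\_sum}^{(k)} = \sum_{i=1}^{B} s_i \, g_i^{(k)}, \qquad
\text{num\_clipped} = \bigl|\{\, i : n_i > C \,\}\bigr|.
```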
- Parameters:
grads – flattened update, i.e. a list of arrays; each array in the list is expected to have the batch dimension on axis 0.
l2_norm_clip – maximum L2 norm allowed for each example's gradients, measured jointly (globally) across all arrays in grads.
- Returns:
A tuple containing the sum of the clipped per-example gradients, and the number of examples whose gradients were clipped.
Example
>>> import optax
>>> import jax.numpy as jnp
>>> grads = [jnp.array([[0, 0, 0], [0, 3, 4], [4, 0, 3], [3, 4, 0]])]
>>> optax.per_example_global_norm_clip(grads, jnp.inf)
([Array([7., 7., 7.], dtype=float32)], Array(0, dtype=int32))
>>> optax.per_example_global_norm_clip(grads, 0.0)
([Array([0., 0., 0.], dtype=float32)], Array(3, dtype=int32))
>>> optax.per_example_global_norm_clip(grads, 1.25)
([Array([1.75, 1.75, 1.75], dtype=float32)], Array(3, dtype=int32))
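To make the semantics concrete, here is a minimal plain jax.numpy re-implementation, assuming grads is a list of arrays that share the batch size on axis 0. The function name clip_and_sum_sketch is hypothetical and this is not Optax's source; it has only been checked against the values in the example above.

```python
import jax.numpy as jnp


def clip_and_sum_sketch(grads, l2_norm_clip):
    """Hypothetical sketch of per-example global-norm clipping (not Optax's code).

    grads: list of arrays, each with the batch dimension on axis 0.
    Returns (list of clipped sums over the batch, number of clipped examples).
    """
    # Squared L2 norm of each example, summed over every array ("global" norm).
    sq_norms = sum(
        jnp.sum(jnp.reshape(g, (g.shape[0], -1)).astype(jnp.float32) ** 2, axis=1)
        for g in grads
    )
    norms = jnp.sqrt(sq_norms)
    # Scale factor min(1, C / n_i). When C = 0 and n_i = 0, 0/0 gives nan,
    # which is mapped to 1 so an all-zero gradient simply stays zero.
    scales = jnp.nan_to_num(jnp.minimum(l2_norm_clip / norms, 1.0), nan=1.0)
    # Contract the batch axis: sum_i scales[i] * g[i] for each array.
    clipped_sums = [jnp.tensordot(scales, g.astype(jnp.float32), axes=1) for g in grads]
    # Count examples whose norm strictly exceeds the clipping threshold.
    num_clipped = jnp.sum(norms > l2_norm_clip)
    return clipped_sums, num_clipped
```

The norm is shared across arrays, which is what distinguishes this from clipping each array on its own. For instance, with a batch of two examples where w = [[3.0, 0.0], [0.0, 0.0]] and b = [4.0, 0.0], the first example has global norm 5, so at l2_norm_clip = 1.0 both of its slices are scaled by 0.2 (giving approximately [0.6, 0.0] and 0.8), whereas clipping each array separately would have produced [1.0, 0.0] and 1.0.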
References
Abadi et al., Deep Learning with Differential Privacy, 2016
See also
optax.contrib.differentially_private_aggregate() for more realistic usage examples.