optax.per_example_global_norm_clip

optax.per_example_global_norm_clip(grads: optax.ArrayTree, l2_norm_clip: jax.typing.ArrayLike) -> tuple[optax.ArrayTree, Array]

Clips each example's gradient to a maximum global L2 norm, then sums the clipped gradients over the batch.

Parameters:
  • grads – flattened update; the function expects each array in this list to have a batch dimension on the 0th axis.

  • l2_norm_clip – maximum L2 norm of the per-example gradients.

Returns:

A tuple containing the sum of the clipped per-example grads and the number of per-example grads that were clipped.

Example

>>> import optax
>>> import jax.numpy as jnp
>>> grads = [jnp.array([[0, 0, 0], [0, 3, 4], [4, 0, 3], [3, 4, 0]])]
>>> optax.per_example_global_norm_clip(grads, jnp.inf)
([Array([7., 7., 7.], dtype=float32)], Array(0, dtype=int32))
>>> optax.per_example_global_norm_clip(grads, 0.0)
([Array([0., 0., 0.], dtype=float32)], Array(3, dtype=int32))
>>> optax.per_example_global_norm_clip(grads, 1.25)
([Array([1.75, 1.75, 1.75], dtype=float32)], Array(3, dtype=int32))

References

Abadi et al., Deep Learning with Differential Privacy, 2016

See also

optax.contrib.differentially_private_aggregate() for more realistic usage examples.