optax.per_example_layer_norm_clip
- optax.per_example_layer_norm_clip(grads: optax.ArrayTree, global_l2_norm_clip: jax.typing.ArrayLike, uniform: bool = True) -> tuple[optax.ArrayTree, optax.ArrayTree]
Applies gradient clipping per-example using per-layer norms.
If len(grads) == 1, this function is equivalent to optax.per_example_global_norm_clip. If len(grads) > 1, each array in grads is clipped independently, per example, to an L2 norm of at most C_i, where C_i is defined below.
Let C be the value of global_l2_norm_clip. Per-layer clipping is then done as follows:
1. If uniform is True, each of the K layers has an individual clip norm of C / sqrt(K).
2. If uniform is False, each of the K layers has an individual clip norm of C * sqrt(D_i / D), where D_i is the number of parameters in layer i and D is the total number of parameters in the model.
- Parameters:
grads – flattened update, i.e. a list of gradients in which each item is the gradient for one layer. Each item is expected to have a batch dimension on axis 0.
global_l2_norm_clip – overall L2 clip norm to use.
uniform – If True, the per-layer clip norm is global_l2_norm_clip / sqrt(K), where K is the number of layers. Otherwise, the per-layer clip norm is global_l2_norm_clip * sqrt(f), where f is the fraction of total model parameters that are in this layer.
- Returns:
A tuple containing the sum of the clipped per-example grads and the number of per-example grads that were clipped, each given as a list with one entry per layer.
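Both per-layer rules divide the overall budget C so that the squared per-layer clip norms add up to C^2. The short derivation below is a sanity check written in the C, K, D_i, D notation of this page, not a statement taken from the library; g_i (one example's clipped gradient for layer i) and g (that example's layers concatenated) are symbols introduced here, and the non-uniform case uses the fact that the D_i sum to D. It shows that every clipped per-example gradient, taken across all layers, has an L2 norm of at most C.

```latex
\begin{gather*}
C_i =
\begin{cases}
  C / \sqrt{K} & \text{if } \texttt{uniform} \text{ is True},\\
  C \sqrt{D_i / D} & \text{otherwise,}
\end{cases}
\qquad
\sum_{i=1}^{K} C_i^2 =
\begin{cases}
  K \cdot C^2 / K = C^2,\\
  C^2 \sum_{i=1}^{K} D_i / D = C^2,
\end{cases}
\\[4pt]
\lVert g \rVert_2^2 = \sum_{i=1}^{K} \lVert g_i \rVert_2^2
  \le \sum_{i=1}^{K} C_i^2 = C^2 .
\end{gather*}
```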
Example
>>> import optax
>>> import jax.numpy as jnp
>>> grads = [jnp.array([[0, 0, 0], [0, 3, 4], [4, 0, 3], [3, 4, 0]])]
>>> optax.per_example_layer_norm_clip(grads, jnp.inf)
([Array([7., 7., 7.], dtype=float32)], [Array(0, dtype=int32)])
>>> optax.per_example_layer_norm_clip(grads, 0.0)
([Array([0., 0., 0.], dtype=float32)], [Array(3, dtype=int32)])
>>> optax.per_example_layer_norm_clip(grads, 1.25)
([Array([1.75, 1.75, 1.75], dtype=float32)], [Array(3, dtype=int32)])
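To make the per-layer rule concrete, here is a minimal from-scratch sketch in JAX. It is not optax's own implementation: the function name layer_norm_clip_sketch is hypothetical, the inputs are cast to float32, and mapping the 0/0 case (a zero gradient with a zero clip norm) to a scale factor of 0 is a choice made here. Like the real function, it assumes every array in grads carries the batch dimension on axis 0. On the single-layer grads from the example it reproduces the sums and clip counts shown there.

```python
import jax.numpy as jnp


def layer_norm_clip_sketch(grads, global_l2_norm_clip, uniform=True):
  """Hypothetical sketch of per-example, per-layer L2 clipping.

  Args:
    grads: list of arrays, one per layer, each shaped (batch, ...).
    global_l2_norm_clip: overall clip norm C.
    uniform: if True, each layer gets C / sqrt(K); otherwise
      C * sqrt(D_i / D).

  Returns:
    (list of per-layer clipped sums, list of per-layer clip counts).
  """
  num_layers = len(grads)
  # D_i: parameters per layer for a single example (batch axis dropped).
  layer_sizes = [g[0].size for g in grads]
  total_size = sum(layer_sizes)  # D

  if uniform:
    clip_norms = [global_l2_norm_clip / num_layers ** 0.5] * num_layers
  else:
    clip_norms = [global_l2_norm_clip * (d / total_size) ** 0.5
                  for d in layer_sizes]

  sums, counts = [], []
  for g, c in zip(grads, clip_norms):
    g = jnp.asarray(g, dtype=jnp.float32)
    batch = g.shape[0]
    # L2 norm of each example's gradient, for this layer only.
    norms = jnp.linalg.norm(g.reshape(batch, -1), axis=-1)
    # Scale factor min(c / norm, 1); nan_to_num turns the 0/0 case
    # (zero gradient, zero clip norm) into a factor of 0.
    factors = jnp.nan_to_num(jnp.minimum(c / norms, 1.0))
    # Broadcast the per-example factor over the remaining axes, then sum.
    factors = factors.reshape((batch,) + (1,) * (g.ndim - 1))
    sums.append(jnp.sum(factors * g, axis=0))
    # An example counts as clipped when its norm exceeds the clip norm.
    counts.append(jnp.sum(norms > c))
  return sums, counts
```

As a two-layer illustration, with D_1 = 1, D_2 = 3 and C = 2, uniform=False gives per-layer clip norms 2 * sqrt(1/4) = 1 and 2 * sqrt(3/4) ≈ 1.732, while uniform=True gives 2 / sqrt(2) ≈ 1.414 for both layers; in each case the squared clip norms add up to C^2 = 4.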
References
McMahan et al., Learning Differentially Private Recurrent Language Models, 2017