optax.clip_by_block_rms

optax.clip_by_block_rms(threshold: jax.typing.ArrayLike) → optax.GradientTransformation

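Clips updates to a maximum RMS (root mean square) for the gradient of each parameter vector or matrix. A block whose RMS exceeds the threshold is rescaled so that its RMS equals the threshold; blocks already within the threshold are left unchanged.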

Here, a block is a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer) that appears as a leaf in the gradient/parameter pytree. The RMS is computed separately for each leaf.
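The per-block rule can be sketched in plain JAX. This is a minimal re-implementation, not the library's own code. It assumes that each leaf u is divided by max(1, rms(u) / threshold), where rms(u) = sqrt(mean(|u|²)) is taken over all elements of that leaf. The helper name `block_rms_clip` is hypothetical.

```python
import jax
import jax.numpy as jnp


def block_rms_clip(updates, threshold):
    """Hypothetical helper: rescale each pytree leaf so its RMS is at most `threshold`.

    Assumes the rule u / max(1, rms(u) / threshold), applied independently per leaf.
    """

    def _clip_leaf(u):
        # RMS over every element of this leaf (vector, matrix, or higher-rank array).
        rms = jnp.sqrt(jnp.mean(jnp.abs(u) ** 2))
        # The denominator is 1 when the leaf is already within the threshold,
        # so such leaves pass through unchanged.
        denom = jnp.maximum(1.0, rms / threshold)
        return u / denom

    return jax.tree_util.tree_map(_clip_leaf, updates)


grads = {
    "w": jnp.array([[3.0, 4.0], [3.0, 4.0]]),  # RMS = sqrt(12.5), about 3.54
    "b": jnp.array([0.1, -0.1]),               # RMS = 0.1
}
clipped = block_rms_clip(grads, threshold=1.0)
# clipped["w"] is grads["w"] / sqrt(12.5), so its RMS is 1.0.
# clipped["b"] is unchanged because its RMS is already below 1.0.
```

Because the denominator never drops below 1, this rule only ever shrinks a block; it never scales a small block up to the threshold.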

Parameters:

threshold – The maximum RMS for the gradient of each parameter vector or matrix.

Returns:

An optax.GradientTransformation object.