optax.scale_by_param_block_norm


optax.scale_by_param_block_norm(min_scale: jax.typing.ArrayLike = 0.001) → optax.GradientTransformation

Scale updates for each param block by the norm of that block's parameters.

Here a block is a single leaf of the grads/params pytree, such as a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer).
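To make the per-block rule concrete, here is a minimal plain-JAX sketch of the forward computation. It is not optax's source. The helper `block_norm_scale` is hypothetical, and the sketch assumes that "norm" means the L2 (Frobenius) norm of the flattened leaf and that `min_scale` acts as a floor on that norm via `max`.

```python
import jax
import jax.numpy as jnp


def block_norm_scale(updates, params, min_scale=1e-3):
    """Hypothetical sketch: scale each update leaf by the norm of its matching param leaf."""

    def scale_leaf(u, p):
        # Assumption: one leaf is one block, and its norm is the L2 norm over all
        # entries (the Frobenius norm for a matrix).
        block_norm = jnp.linalg.norm(jnp.ravel(p))
        # Assumption: min_scale floors the factor so a zero-norm block keeps a
        # small, non-zero update instead of being zeroed out.
        return u * jnp.maximum(block_norm, min_scale)

    return jax.tree_util.tree_map(scale_leaf, updates, params)


params = {"w": jnp.array([[3.0, 4.0]]), "b": jnp.zeros(2)}
grads = {"w": jnp.ones((1, 2)), "b": jnp.ones(2)}
scaled = block_norm_scale(grads, params)
# scaled["w"]: every entry is 5.0, because ||w|| = sqrt(3^2 + 4^2) = 5
# scaled["b"]: every entry is 0.001, because ||b|| = 0 is floored at min_scale
```

Because the factor is computed once per leaf, all entries of a block share the same scale. Larger parameter blocks therefore receive proportionally larger updates.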

Parameters:

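min_scale – Minimum scaling factor. The block's norm is floored at this value, so a block whose parameters have (near-)zero norm still receives a non-zero update.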

Returns:

An optax.GradientTransformation object.