optax.scale_by_param_block_rms
- optax.scale_by_param_block_rms(min_scale: jax.typing.ArrayLike = 0.001) → optax.GradientTransformation
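Scale updates by the RMS of the parameters, computed separately for each param vector or matrix.
Here a block is a single leaf of the grads/params pytree, such as a weight vector (e.g. in a Linear layer) or a weight matrix (e.g. in a convolutional layer). Each update leaf is multiplied by the root mean square of the corresponding parameter leaf, floored at min_scale. The transformation keeps no state, and params must be passed to update.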
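The per-block rule can be sketched in plain JAX as below. This is a minimal illustration, not optax's implementation: `block_rms_scale` is a hypothetical helper. It assumes real-valued arrays and reproduces only the forward values; optax computes the floored RMS with a numerically safe helper that also keeps gradients well defined.

```python
import jax
import jax.numpy as jnp


def block_rms_scale(updates, params, min_scale=1e-3):
    """Hypothetical helper: multiply each update leaf by max(RMS(param leaf), min_scale)."""

    def scale_leaf(u, p):
        # RMS over every element of this parameter block (one pytree leaf).
        rms = jnp.sqrt(jnp.mean(jnp.square(p)))
        # Floor the scale so all-zero (or tiny) parameters do not zero out the update.
        return u * jnp.maximum(rms, min_scale)

    return jax.tree_util.tree_map(scale_leaf, updates, params)


params = {"w": jnp.array([[3.0, 4.0], [3.0, 4.0]]), "b": jnp.zeros(2)}
updates = {"w": jnp.ones((2, 2)), "b": jnp.ones(2)}
scaled = block_rms_scale(updates, params)
# "w": RMS = sqrt((9 + 16 + 9 + 16) / 4) = sqrt(12.5), so each entry becomes sqrt(12.5).
# "b": RMS = 0, so the floor min_scale = 1e-3 is used instead.
```

Because the scale is a single scalar per leaf, every element of a block is rescaled by the same factor, which preserves the direction of that block's update.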
- Parameters:
min_scale – Minimum scaling factor.
- Returns:
An optax.GradientTransformation object.