optax.adaptive_grad_clip
- optax.adaptive_grad_clip(clipping: jax.typing.ArrayLike, eps: jax.typing.ArrayLike = 0.001, axis: int | tuple[int, ...] | None = None) → optax.GradientTransformation
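Clips updates to be at most clipping * parameter_norm, unit-wise: the norm of each unit of an update is bounded by clipping times the norm of the corresponding unit of the parameter, rather than by a single norm over the whole parameter.
- Parameters: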
clipping – The maximum allowed ratio of update norm to parameter norm, per unit.
eps – Lower bound applied to the parameter norm when computing the threshold, so that updates to zero-initialized parameters are not clipped to zero.
axis – Axis or axes along which to compute the unit-wise norm. If None, the axes are inferred from each parameter's number of dimensions; the defaults cover shapes up to and including Conv3D kernels (ndim=5). Pass axis explicitly for parameter shapes the defaults do not handle.
- Returns:
An optax.GradientTransformation object.
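The clipping rule can be sketched for a single dense kernel of shape (in, out). This is a minimal illustration, not optax's implementation: it assumes the unit-wise norm is taken over axis 0 (one norm per output unit), and the helpers `unitwise_norm_2d` and `agc_clip` are hypothetical names. Each unit's gradient is rescaled to norm `clipping * max(||param_unit||, eps)` only when it exceeds that bound.

```python
import jax.numpy as jnp


def unitwise_norm_2d(x):
    # Hypothetical helper: L2 norm of each output unit (column) of an
    # (in, out) kernel, broadcast back to x's shape. Assumes reduction over axis 0.
    return jnp.broadcast_to(
        jnp.sqrt(jnp.sum(jnp.square(x), axis=0, keepdims=True)), x.shape
    )


def agc_clip(grad, param, clipping, eps=1e-3):
    # Hypothetical helper sketching adaptive gradient clipping for one kernel.
    g_norm = unitwise_norm_2d(grad)
    # eps floors the parameter norm, so a zero-initialized unit still gets
    # a nonzero budget of clipping * eps instead of being clipped to zero.
    max_norm = clipping * jnp.maximum(unitwise_norm_2d(param), eps)
    # Rescale to norm max_norm; the inner maximum only guards against division by zero.
    scaled = grad * (max_norm / jnp.maximum(g_norm, 1e-6))
    # Keep units within budget unchanged; rescale only those that exceed it.
    return jnp.where(g_norm > max_norm, scaled, grad)


param = jnp.array([[3.0, 0.0], [4.0, 0.0]])  # column norms: 5.0 and 0.0
grad = jnp.array([[1.0, 1.0], [1.0, 0.0]])   # column norms: ~1.414 and 1.0
clipped = agc_clip(grad, param, clipping=0.1)
unclipped = agc_clip(grad, param, clipping=1.0)
```

With clipping=0.1, both columns are rescaled, to norms 0.5 and 0.1 * 1e-3. With clipping=1.0, the first column (gradient norm about 1.414, budget 5.0) passes through unchanged, while the zero-initialized second column is still capped at 1e-3.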
References
Brock et al., High-Performance Large-Scale Image Recognition Without Normalization, 2021