optax.adaptive_grad_clip

optax.adaptive_grad_clip(clipping: jax.typing.ArrayLike, eps: jax.typing.ArrayLike = 0.001, axis: int | tuple[int, ...] | None = None) → optax.GradientTransformation

Clips updates unit-wise, so that the norm of each unit's update is at most clipping * max(parameter_norm, eps), where both norms are computed over the same unit.
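The rule can be sketched for a single 2-D kernel. This is a from-scratch JAX illustration, not optax's own implementation. It assumes an (inputs, outputs) kernel in which each output column is one unit, and the helper name agc_clip_2d is hypothetical.

```python
import jax.numpy as jnp


def agc_clip_2d(grad, param, clipping, eps=1e-3):
    """Hypothetical helper: unit-wise adaptive gradient clipping for a 2-D kernel.

    Assumes grad and param have shape (inputs, outputs) and that each output
    column is one unit, so norms are reduced over axis 0.
    """
    # Per-unit L2 norms, kept as shape (1, outputs) so they broadcast.
    g_norm = jnp.sqrt(jnp.sum(grad ** 2, axis=0, keepdims=True))
    p_norm = jnp.sqrt(jnp.sum(param ** 2, axis=0, keepdims=True))
    # eps floors the parameter norm so zero-initialized units can still move.
    max_norm = clipping * jnp.maximum(p_norm, eps)
    # Rescale only units whose update norm exceeds the allowed maximum;
    # the 1e-6 floor just guards against division by zero.
    scale = max_norm / jnp.maximum(g_norm, 1e-6)
    return jnp.where(g_norm > max_norm, grad * scale, grad)
```

For example, with param = jnp.ones((4, 2)) and clipping=0.1, each column's parameter norm is 2, so any column whose gradient norm exceeds 0.2 is rescaled to norm exactly 0.2, while smaller gradients pass through unchanged.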

Parameters:
  • clipping – The maximum allowed ratio of a unit's update norm to its parameter norm.

  • eps – A lower bound on the parameter norm, so that updates to zero-initialized parameters (whose norm is 0) are not clipped to zero.

  • axis – Axis or axes along which to compute the unit-wise norm. If None, a default is chosen from the number of dimensions of each parameter (the defaults extend to ndim=5, e.g. Conv3D kernels). Pass axis explicitly for parameter shapes the defaults do not cover.
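The idea behind a shape-based default can be sketched as follows. This is an assumption-laden illustration, not optax's exact table: it assumes kernels store output units on the last axis (IO, HWIO, DHWIO layouts) and reduces over every other axis, treats scalars and vectors as a single unit, and simply refuses other ranks. The names default_unit_axes and unitwise_norm_sketch are hypothetical.

```python
import jax.numpy as jnp


def default_unit_axes(ndim):
    """Hypothetical helper: axes to reduce over for a unit-wise norm.

    Assumes output units live on the last axis (IO, HWIO, DHWIO kernels).
    Returns None to mean "reduce over every axis" (scalars and vectors).
    """
    if ndim <= 1:
        return None                    # scalars / biases: one norm for the whole array
    if ndim in (2, 4, 5):
        return tuple(range(ndim - 1))  # every axis except the output axis
    raise ValueError(f"No default for ndim={ndim}; pass axis explicitly.")


def unitwise_norm_sketch(x, axis=None):
    """Hypothetical helper: per-unit L2 norm, broadcast back to x.shape."""
    if axis is None:
        axis = default_unit_axes(x.ndim)
    sq = jnp.sum(x ** 2, axis=axis, keepdims=True)
    return jnp.broadcast_to(jnp.sqrt(sq), x.shape)
```

For a 3x3 conv kernel with 2 input and 4 output channels (shape (3, 3, 2, 4)), the sketch reduces over axes (0, 1, 2), giving one norm per output channel. A parameter with an unusual layout would need an explicit axis.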

Returns:

An optax.GradientTransformation object.

References

Brock et al., High-Performance Large-Scale Image Recognition Without Normalization, 2021