optax.centralize



optax.centralize() → optax.GradientTransformation

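Centralizes gradients by subtracting from each gradient array its mean over all dimensions except the leading one, so that every slice along the leading dimension ends up with zero mean.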
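The per-array computation can be sketched in plain JAX as below. This is a minimal illustration, not the optax source: `centralize_leaf` and `centralize_tree` are hypothetical helper names, and leaving scalars and 1-D leaves (such as biases) unchanged is an assumption that mirrors the paper's focus on weight matrices and convolution kernels.

```python
import jax
import jax.numpy as jnp


def centralize_leaf(g: jax.Array) -> jax.Array:
    """Hypothetical helper: subtract the mean over every axis except axis 0."""
    if g.ndim <= 1:
        # Assumption: scalars and 1-D leaves (e.g. biases) pass through unchanged.
        return g
    # Average over axes 1..ndim-1; keepdims=True lets the mean broadcast back over g.
    mean = jnp.mean(g, axis=tuple(range(1, g.ndim)), keepdims=True)
    return g - mean


def centralize_tree(grads):
    """Hypothetical helper: apply centralize_leaf to every leaf of a gradient pytree."""
    return jax.tree_util.tree_map(centralize_leaf, grads)


grads = {
    "w": jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),  # 2-D: centralized row by row
    "b": jnp.array([1.0, 2.0, 3.0]),                      # 1-D: left as-is (assumption)
}
out = centralize_tree(grads)
print(out["w"])  # each row now has zero mean
print(out["b"])
```

For a 2-D leaf this matches the doctest in the Example section: each row has its own mean removed, giving `[-1, 0, 1]` for both rows.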

Returns:

An optax.GradientTransformation object.

Example

>>> import jax.numpy as jnp
>>> import optax
>>> grad = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> opt = optax.centralize()
>>> state = opt.init(grad)
>>> updates, state = opt.update(grad, state)
>>> print(updates)
[[-1.  0.  1.]
 [-1.  0.  1.]]
>>> print(state)
EmptyState()

References

Yong et al., Gradient Centralization: A New Optimization Technique for Deep Neural Networks, 2020.