optax.centralize
- optax.centralize() → optax.GradientTransformation
Centralizes gradients by subtracting, for each index along the leading dimension, the mean taken over all remaining dimensions.
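To make the axis convention concrete, here is a minimal sketch of the per-array operation written directly with jax.numpy and applied to a gradient pytree with jax.tree_util.tree_map. The helpers `subtract_mean` and `centralize_tree` are hypothetical names for illustration, not part of optax, and leaving 0-D and 1-D arrays (such as biases) unchanged is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


def subtract_mean(g):
    """Hypothetical helper: centralize a single gradient array."""
    if g.ndim > 1:
        # Mean over every axis except the leading one, kept broadcastable
        # so each slice g[i] has its own mean subtracted.
        return g - g.mean(axis=tuple(range(1, g.ndim)), keepdims=True)
    # Assumption of this sketch: 0-D and 1-D arrays (e.g. biases) pass through.
    return g


def centralize_tree(grads):
    """Hypothetical helper: apply subtract_mean to every leaf of a pytree."""
    return jax.tree_util.tree_map(subtract_mean, grads)


grads = {
    "w": jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]),  # dense-layer-like
    "k": jnp.arange(8.0).reshape(2, 2, 2),  # leading axis of size 2
    "b": jnp.array([1.0, 2.0]),  # 1-D, left unchanged by this sketch
}
centered = centralize_tree(grads)
print(centered["w"])
print(centered["k"])
```

On the 2-D leaf this matches the Example below: each row has its own mean (2 and 5) removed. On the 3-D leaf, each `k[i]` is centered over its remaining two axes.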
- Returns:
A optax.GradientTransformation object.
Example
>>> import jax.numpy as jnp
>>> import optax
>>> grad = jnp.array([[1, 2, 3], [4, 5, 6]])
>>> opt = optax.centralize()
>>> state = opt.init(grad)
>>> updates, state = opt.update(grad, state)
>>> print(updates)
[[-1.  0.  1.]
 [-1.  0.  1.]]
>>> print(state)
EmptyState()
References
Yong et al., Gradient Centralization: A New Optimization Technique for Deep Neural Networks, 2020.