optax.GradientTransformation

class optax.GradientTransformation(init: TransformInitFn, update: TransformUpdateFn)

A pair of pure functions implementing a gradient transformation.

Prefer GradientTransformationExtraArgs for new optimizers.

Optax optimizers are all implemented as gradient transformations. A gradient transformation is a pair of pure functions, init and update, combined in a NamedTuple so that each can be referred to by name.

Note that an extended API is provided for users wishing to build optimizers that take additional arguments during the update step. For more details, see optax.GradientTransformationExtraArgs().

Because gradient transformations hold no internal state, all stateful optimizer properties (such as the step count used by learning-rate schedules, or momentum values) are carried through optax gradient transformations in the optimizer state pytree. Each time a gradient transformation is applied, it computes and returns a new state, ready to be passed to the next call.

Because gradient transformations are pure functions, the only way to change the behavior of a gradient transformation between steps is to change the values in the optimizer state. For an example of modifying the optimizer state to control the behavior of an optax gradient transformation, see the meta-learning example in the optax documentation.

init

A pure function which, when called with an example instance of the parameters whose gradients will be transformed, returns a pytree containing the initial value for the optimizer state.

Type:

optax._src.base.TransformInitFn

update

A pure function which takes as input a pytree of updates (with the same tree structure as the original params pytree passed to init), the previous optimizer state (which may have been initialized using the init function), and, optionally, the current params. It returns the computed gradient updates and a new optimizer state.

Type:

optax._src.base.TransformUpdateFn