optax.GradientTransformationExtraArgs

class optax.GradientTransformationExtraArgs(init: TransformInitFn, update: TransformUpdateFn)

A specialization of GradientTransformation that supports extra args.

Extends the existing GradientTransformation interface by adding support for passing extra arguments to the update function.

Note that if no extra args are provided, the API of the update function is identical to that of TransformUpdateFn. This means that any gradient transformation that does not support extra args can be safely wrapped as one that does: the wrapped transformation accepts (and ignores) any extra arguments a user might pass to it. This is the behavior implemented by optax.with_extra_args_support().

update

Overrides the type signature of the update function in the base type so that it also accepts extra keyword arguments.

Type:

optax._src.base.TransformUpdateExtraArgsFn
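For orientation, the shape of this signature can be sketched as a `typing.Protocol`. This is an illustrative mirror, not Optax's own definition: the protocol name `UpdateExtraArgsFn` and the `Any`-based aliases `Updates`, `OptState`, `Params` are stand-ins assumed here, while the argument order `(updates, state, params=None, **extra_args)` returning `(updates, state)` follows the update interface described above.

```python
from typing import Any, Optional, Protocol, Tuple

# Illustrative aliases; optax defines its own, more specific types.
Updates = Any
OptState = Any
Params = Any


class UpdateExtraArgsFn(Protocol):
    """Hypothetical mirror of the extra-args update signature."""

    def __call__(
        self,
        updates: Updates,
        state: OptState,
        params: Optional[Params] = None,
        **extra_args: Any,
    ) -> Tuple[Updates, OptState]:
        ...


def passthrough_update(
    updates: Updates,
    state: OptState,
    params: Optional[Params] = None,
    **extra_args: Any,
) -> Tuple[Updates, OptState]:
    """Returns updates and state unchanged; extra keyword arguments are ignored."""
    del params, extra_args
    return updates, state


# A static type checker accepts this assignment because the signatures match.
update: UpdateExtraArgsFn = passthrough_update
```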