optax.TransformUpdateFn


class optax.TransformUpdateFn(*args, **kwargs)

A callable type for the update step of a GradientTransformation.

The update step takes a tree of candidate parameter updates (e.g. the gradients of some loss with respect to the parameters), an arbitrary structured state, and the current params of the model being optimized. It returns the transformed updates together with the new state. The params argument is optional; however, it must be provided when using transformations that require access to the current values of the parameters.

When additional arguments are required, an alternative interface can be used; see TransformUpdateExtraArgsFn for details.