optax.TransformInitFn
class optax.TransformInitFn(*args, **kwargs)
A callable type for the init step of a GradientTransformation.
The init step takes a tree of params and uses it to construct an arbitrarily structured initial state for the gradient transformation. This state may hold statistics of past updates or any other non-static information.