optax.flatten
- optax.flatten(inner: base.GradientTransformation) → base.GradientTransformationExtraArgs
Flattens parameters and gradients for init and update of inner transform.
This can reduce the overhead of performing many calculations on lots of small variables, at the cost of slightly increased memory usage.
- Parameters:
inner – Inner transformation to flatten inputs for.
- Returns: