optax.stateless_with_tree_map

optax.stateless_with_tree_map(f: Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike | None], jax.typing.ArrayLike]) → GradientTransformation

Creates a stateless transformation from an update-like function that operates on individual arrays.

Like optax.stateless, this wrapper eliminates the boilerplate needed to create a transformation that does not keep state between iterations. In addition, it applies the tree map over the updates and params for you, so f only ever receives one array (leaf) at a time.

Parameters:

f – Update function that takes an update array (e.g. gradients) and a parameter array and returns an update array. The parameter array may be None.

Returns:

An optax.GradientTransformation.