optax.stateless_with_tree_map
- optax.stateless_with_tree_map(f: Callable[[jax.typing.ArrayLike, jax.typing.ArrayLike | None], jax.typing.ArrayLike]) → GradientTransformation
Creates a stateless transformation from an update-like function for arrays.
This wrapper eliminates the boilerplate needed to create a transformation that does not require saved state between iterations, just like optax.stateless. In addition, it maps f over the leaves of the updates and params pytrees for you, so f only has to handle individual arrays.
- Parameters:
f – Update function that takes an update array (e.g., gradients) and a parameter array and returns an update array. The parameter array may be None.
- Returns: