optax.tree_utils.tree_map_params
- optax.tree_utils.tree_map_params(initable: Callable[[optax.Params], optax.OptState] | Initable, f: Callable[..., Any], state: optax.OptState, /, *rest: Any, transform_non_params: Callable[..., Any] | None = None, is_leaf: Callable[[optax.Params], bool] | None = None) -> optax.OptState
Apply a callable over all params in the given optimizer state.
This function helps construct partition specs for optimizer states when a partition spec is already known for the parameters.
For example, the following replaces every parameter tree in the optimizer state with a copy of the given partition spec. The transform_non_params argument can be used to replace any remaining fields as required; here, those fields are replaced by None.
>>> params, specs = jnp.array(0.), jnp.array(0.)  # Trees with the same shape
>>> opt = optax.sgd(1e-3)
>>> state = opt.init(params)
>>> opt_specs = optax.tree_map_params(
...     opt,
...     lambda _, spec: spec,
...     state,
...     specs,
...     transform_non_params=lambda _: None,
... )
- Parameters:
initable – A callable that takes parameters and returns an optimizer state, or an object whose init attribute is such a callable (for example, an optax.GradientTransformation).
f – A callable applied leaf-wise to every copy of the parameter tree within the optimizer state.
state – The optimizer state to map over.
*rest – Additional trees with the same structure as the parameter tree, whose leaves are passed to f as extra arguments.
transform_non_params – An optional function that is called on all non-parameter fields within the optimizer state.
is_leaf – Passed through to jax.tree.map. This makes it possible to ignore parts of the parameter tree, e.g. when a gradient transformation modifies the structure of the original pytree, as optax.masked does.
- Returns:
The result of applying f, together with the given optional extra arguments, to all trees in the optimizer's state that have the same shape as the parameter tree.