optax.tree_utils.tree_set


optax.tree_utils.tree_set(tree: optax.PyTree, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any], bool] | None = None, /, **kwargs: Any) -> optax.PyTree

Creates a copy of tree with some values replaced as specified by kwargs.

Searches the tree for each key in **kwargs. A key can be a key of a dictionary, a field of a NamedTuple, or the name of a NamedTuple. When such a key is found, the corresponding value is replaced with the one given in **kwargs.

Raises a KeyError if any key in **kwargs is not present in the tree.

Parameters:
  • tree – pytree whose values are to be replaced.

  • filtering – optional callable to further filter the values in tree that match the keys to replace. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches a given key.

  • **kwargs – dictionary of keys with values to replace in tree.

Returns:

new_tree

new pytree with the same structure as tree. For each element in tree whose key/field matches a key in **kwargs, its value is set to the corresponding value in **kwargs.

Raises:

KeyError – If, for some key in **kwargs, no value is found in tree, or none of the values found satisfies the filtering operation.

Examples

Basic usage

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> print(state)
(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
>>> new_state = optax.tree_utils.tree_set(state, count=2.)
>>> print(new_state)
(ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())

Usage with a filtering operation

>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = opt.init(params)
>>> print(state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> new_state = optax.tree_utils.tree_set(
...   state, filtering, learning_rate=jnp.asarray(0.1)
... )
>>> print(new_state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))

Note

The recommended way to inject hyperparameter schedules is through optax.inject_hyperparams(). This function is a helper for other purposes.

Added in version 0.2.2.