optax.tree_utils.tree_set
- optax.tree_utils.tree_set(tree: optax.PyTree, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any], bool] | None = None, /, **kwargs: Any) -> optax.PyTree
Creates a copy of tree with some values replaced as specified by kwargs.
Search in the tree for keys in **kwargs (which can be a key from a dictionary, a field from a NamedTuple, or the name of a NamedTuple). If such a key is found, replace the corresponding value with the one given in **kwargs.
Raises a KeyError if some keys in **kwargs are not present in the tree.
- Parameters:
tree: pytree whose values are to be replaced.
filtering: optional callable to further filter values in tree that match the keys to replace. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches a given key.
**kwargs: dictionary of keys with values to replace in tree.
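The search-and-replace behaviour described above, including how filtering receives a (path, value) pair for each match, can be illustrated with a minimal pure-Python sketch. This is not optax's implementation: tree_set_sketch and AdamLike are hypothetical names, the sketch only handles dicts, lists, tuples and NamedTuples, it represents a path as a plain tuple of keys, field names and indices (optax instead uses jax.tree_util key entries such as DictKey and GetAttrKey), and it omits the KeyError check.

```python
from typing import Any, Callable, NamedTuple, Optional, Tuple

# Simplified path: a tuple of dict keys, field names or sequence indices.
# (Assumption: optax uses jax.tree_util key entries instead.)
Path = Tuple[Any, ...]
Filter = Callable[[Path, Any], bool]


def tree_set_sketch(
    tree: Any, filtering: Optional[Filter] = None, /, **kwargs: Any
) -> Any:
    """Hypothetical sketch: returns a copy of `tree` with matching entries replaced."""

    def keep(path: Path, value: Any) -> bool:
        # Without a filter, every match is replaced.
        return filtering is None or filtering(path, value)

    def replace_or_visit(key: Any, value: Any, path: Path) -> Any:
        # Replace a matching key/field; otherwise keep searching below it.
        if key in kwargs and keep(path + (key,), value):
            return kwargs[key]
        return visit(value, path + (key,))

    def visit(node: Any, path: Path) -> Any:
        if isinstance(node, tuple) and hasattr(node, "_fields"):  # NamedTuple
            name = type(node).__name__
            if name in kwargs and keep(path + (name,), node):
                return kwargs[name]  # the whole NamedTuple is replaced
            return node._replace(
                **{f: replace_or_visit(f, v, path) for f, v in zip(node._fields, node)}
            )
        if isinstance(node, dict):
            return {k: replace_or_visit(k, v, path) for k, v in node.items()}
        if isinstance(node, (list, tuple)):
            return type(node)(visit(v, path + (i,)) for i, v in enumerate(node))
        return node  # leaves are shared; containers are rebuilt

    return visit(tree, ())


class AdamLike(NamedTuple):  # hypothetical stand-in for an optimizer state
    count: int
    mu: float


state = ({"lr": 1.0, "inner": {"lr": 2.0}}, AdamLike(count=0, mu=0.5))
print(tree_set_sketch(state, count=2))
# ({'lr': 1.0, 'inner': {'lr': 2.0}}, AdamLike(count=2, mu=0.5))
print(tree_set_sketch(state, lambda path, value: path == (0, "inner", "lr"), lr=0.1))
# ({'lr': 1.0, 'inner': {'lr': 0.1}}, AdamLike(count=0, mu=0.5))
```

One consequence of making tree and filtering positional-only (the / in the signature, mirrored from the documented one) is that **kwargs can still carry keys literally named "tree" or "filtering": in this sketch, tree_set_sketch({"tree": 1}, tree=2) returns {"tree": 2}.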
- Returns:
new_tree: new pytree with the same structure as tree. For each element in tree whose key/field matches a key in **kwargs, its value is set to the corresponding value in **kwargs.
- Raises:
KeyError: If no values of some key in **kwargs are found in tree, or if none of the values satisfy the filtering operation.
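The two KeyError conditions above (a key with no match at all, or matches that are all rejected by filtering) can be sketched as a check that runs before any replacement. This is a hypothetical illustration under simplifying assumptions, not optax's code: find_all_with_path and check_keys are made-up names (find_all_with_path only loosely resembles optax.tree_utils.tree_get_all_with_path), paths are plain tuples of keys, field names and indices, and the error message wording is invented.

```python
from typing import Any, Callable, List, Optional, Tuple

# Simplified path: a tuple of dict keys, field names or sequence indices
# (an assumption; optax uses jax.tree_util key entries instead).
Path = Tuple[Any, ...]
Filter = Callable[[Path, Any], bool]


def find_all_with_path(
    tree: Any, key: str, filtering: Optional[Filter] = None
) -> List[Tuple[Path, Any]]:
    """Hypothetical helper: all (path, value) pairs matching `key` that pass `filtering`."""
    found: List[Tuple[Path, Any]] = []

    def record(path: Path, value: Any) -> None:
        if filtering is None or filtering(path, value):
            found.append((path, value))

    def visit(node: Any, path: Path) -> None:
        if isinstance(node, tuple) and hasattr(node, "_fields"):  # NamedTuple
            if type(node).__name__ == key:
                record(path + (key,), node)  # match on the NamedTuple's name
            children = list(zip(node._fields, node))
        elif isinstance(node, dict):
            children = list(node.items())
        elif isinstance(node, (list, tuple)):
            children = list(enumerate(node))
        else:
            return  # a leaf has no keys to match
        for k, v in children:
            if k == key:
                record(path + (k,), v)
            visit(v, path + (k,))

    visit(tree, ())
    return found


def check_keys(tree: Any, filtering: Optional[Filter] = None, /, **kwargs: Any) -> None:
    """Raise KeyError if some key in kwargs has no match that passes `filtering`."""
    for key in kwargs:
        if not find_all_with_path(tree, key, filtering):
            reason = " given the filtering operation" if filtering else ""
            raise KeyError(f"Found no values matching {key!r}{reason} in the tree.")


state = {"lr": 0.1, "inner": {"lr": 0.2}}
check_keys(state, lr=1.0)  # passes: "lr" has two matches
try:
    check_keys(state, lambda path, value: value > 1.0, lr=1.0)
except KeyError as err:
    print(err.args[0])
    # Found no values matching 'lr' given the filtering operation in the tree.
```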
Examples
Basic usage
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> print(state)
(ScaleByAdamState(count=Array(0, dtype=int32), mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
>>> new_state = optax.tree_utils.tree_set(state, count=2.)
>>> print(new_state)
(ScaleByAdamState(count=2.0, mu=Array([0., 0., 0.], dtype=float32), nu=Array([0., 0., 0.], dtype=float32)), EmptyState())
Usage with a filtering operation
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = opt.init(params)
>>> print(state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> new_state = optax.tree_utils.tree_set(
...     state, filtering, learning_rate=jnp.asarray(0.1)
... )
>>> print(new_state)
InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
Note
The recommended way to inject hyperparameter schedules is through optax.inject_hyperparams(). This function is a helper for other purposes.
Added in version 0.2.2.