optax.tree_utils.tree_get_all_with_path
optax.tree_utils.tree_get_all_with_path(tree: optax.PyTree, key: Any, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any], bool] | None = None) -> list[tuple[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any]]
Extract values of a pytree matching a given key.
Search in a pytree `tree` for a specific `key`, which can be a key from a dictionary, a field from a NamedTuple, or the name of a NamedTuple. That key or field may appear more than once in `tree`, so this function returns a list of all values corresponding to `key`, each paired with the path to that value. The path is a sequence of KeyEntry objects that can be converted to a readable string with jax.tree_util.keystr(), as shown in the examples below.

Parameters:
tree – the pytree to search in.
key – key or field name to search for in `tree`.
filtering – optional callable to further filter the values in `tree` that match `key`. Its signature is filtering(path: Key_Path, value: Any) -> bool: it takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches the given key, and returns True to keep that value.
Returns:
values_with_path – list of tuples, each of the form (path_to_value, value). Here value is an entry of the tree that corresponds to `key`, and path_to_value is a tuple of KeyEntry objects, each of which is a jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, or optax.tree_utils.NamedTupleKey.
Examples
Basic usage
>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...   state, 'learning_rate'
... )
>>> print(
... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
... sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams['learning_rate']", Array(1., dtype=float32))
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))
Usage with a filtering operation
>>> import jax
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> solver = optax.inject_hyperparams(optax.sgd)(
...   learning_rate=lambda count: 1/(count+1)
... )
>>> state = solver.init(params)
>>> filtering = lambda path, value: isinstance(value, tuple)
>>> found_values_with_path = optax.tree_utils.tree_get_all_with_path(
...   state, 'learning_rate', filtering
... )
>>> print(
... *[(jax.tree_util.keystr(p), v) for p, v in found_values_with_path],
... sep="\n",
... )
("InjectStatefulHyperparamsState.hyperparams_states['learning_rate']", WrappedScheduleState(count=Array(0, dtype=int32)))
Added in version 0.2.2.