optax.tree_utils.tree_get
optax.tree_utils.tree_get(tree: optax.PyTree, key: Any, default: Any | None = None, filtering: Callable[[Tuple[DictKey | FlattenedIndexKey | GetAttrKey | SequenceKey | NamedTupleKey, ...], Any], bool] | None = None) -> Any
Extract a value from a pytree matching a given key.

Search in the tree for a specific key, which can be a key from a dictionary, a field from a NamedTuple, or the name of a NamedTuple. If the tree does not contain key, returns default. Raises a KeyError if multiple values of key are found in tree.

Generally, you may first get all pairs (path_to_value, value) for a given key using optax.tree_utils.tree_get_all_with_path(). You may then define a filtering operation filtering(path: Key_Path, value: Any) -> bool: ... that lets you select the specific values you want to fetch by looking at the type of the value or at the path to that value. Note that, unlike the paths returned by jax.tree_util.tree_leaves_with_path(), the paths analyzed by the filtering operation in optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set() include the names of the named tuples along the path. Concretely, if the value considered is in the attribute key of a named tuple called MyNamedTuple, the last element of the path will be an optax.tree_utils.NamedTupleKey containing both name=key and tuple_name='MyNamedTuple'. That way you can distinguish between identical values in different named tuples, which arise, for example, when chaining transformations in optax. See the last example below.

Parameters:
- tree – tree to search in.
- key – keyword or field to search for in tree.
- default – default value to return if key is not found in tree.
- filtering – optional callable to further filter the values in tree that match key. filtering(path: Key_Path, value: Any) -> bool: ... takes as arguments both the path to the value (as returned by optax.tree_utils.tree_get_all_with_path()) and the value that matches the given key.
Returns:
- value – value in tree matching the given key. If none is found, returns default. If multiple are found, raises a KeyError.
Raises:
- KeyError – If multiple values of key are found in tree.
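The lookup rules above can be illustrated with a small, dependency-free sketch. These rules are: match on a dictionary key, a NamedTuple field name, or a NamedTuple class name; return default when nothing matches; raise KeyError on several matches; and optionally narrow the matches with filtering(path, value). This is not optax's implementation. toy_tree_get_all_with_path, toy_tree_get, ToyDictKey, ToyNamedTupleKey and the two toy state classes are hypothetical names, and the walk only handles dicts, lists, tuples and NamedTuples rather than arbitrary JAX pytrees.

```python
from typing import Any, Callable, List, NamedTuple, Optional, Tuple


class ToyDictKey(NamedTuple):
    """Hypothetical path entry for a dictionary key."""
    key: Any


class ToyNamedTupleKey(NamedTuple):
    """Hypothetical path entry for a NamedTuple field; it records the
    enclosing tuple's class name, in the spirit of NamedTupleKey."""
    tuple_name: str
    name: str


Path = Tuple[Any, ...]


def _is_namedtuple(x: Any) -> bool:
    return isinstance(x, tuple) and hasattr(x, "_fields")


def _children(node: Any) -> List[Tuple[Any, Any]]:
    """Return (path_entry, child) pairs for the containers this toy understands."""
    if _is_namedtuple(node):  # must be checked before plain tuples
        name = type(node).__name__
        return [(ToyNamedTupleKey(name, f), getattr(node, f)) for f in node._fields]
    if isinstance(node, dict):
        return [(ToyDictKey(k), v) for k, v in node.items()]
    if isinstance(node, (list, tuple)):
        return list(enumerate(node))  # plain indices never match a key
    return []


def toy_tree_get_all_with_path(tree: Any, key: Any) -> List[Tuple[Path, Any]]:
    """Collect every (path, value) whose dict key, field name or tuple name is `key`."""
    found: List[Tuple[Path, Any]] = []

    def visit(node: Any, path: Path) -> None:
        last = path[-1] if path else None
        by_key = (isinstance(last, ToyNamedTupleKey) and last.name == key) or (
            isinstance(last, ToyDictKey) and last.key == key
        )
        by_name = _is_namedtuple(node) and type(node).__name__ == key
        if by_key or by_name:
            found.append((path, node))
        for entry, child in _children(node):
            visit(child, path + (entry,))

    visit(tree, ())
    return found


def toy_tree_get(
    tree: Any,
    key: Any,
    default: Any = None,
    filtering: Optional[Callable[[Path, Any], bool]] = None,
) -> Any:
    """Return the single match, `default` if none, or raise KeyError if several."""
    found = toy_tree_get_all_with_path(tree, key)
    if filtering is not None:
        found = [(p, v) for p, v in found if filtering(p, v)]
    if len(found) > 1:
        raise KeyError(f"Found {len(found)} values for {key!r}.")
    return found[0][1] if found else default


class ToyNoiseState(NamedTuple):  # hypothetical stand-in for a noise state
    count: int
    rng_key: int


class ToyAdamState(NamedTuple):  # hypothetical stand-in for an Adam state
    count: int
    mu: float
    nu: float


state = (ToyNoiseState(count=0, rng_key=42), ToyAdamState(count=5, mu=0.1, nu=0.2))

# 'count' lives in both named tuples, so an unfiltered lookup is ambiguous.
try:
    toy_tree_get(state, "count")
    ambiguous = False
except KeyError:
    ambiguous = True


def adam_only(path: Path, value: Any) -> bool:
    """Keep matches whose last path entry belongs to ToyAdamState."""
    last = path[-1] if path else None
    return isinstance(last, ToyNamedTupleKey) and last.tuple_name == "ToyAdamState"


print(ambiguous)                                          # True
print(toy_tree_get(state, "count", filtering=adam_only))  # 5
print(toy_tree_get(state, "ToyNoiseState"))               # ToyNoiseState(count=0, rng_key=42)
print(toy_tree_get(state, "missing", default=-1))         # -1
```

Recording the tuple name in each path entry, rather than only the field name, is what lets adam_only tell the two count fields apart. A path built only from attribute names would see two identical ".count" entries.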
Examples
Basic usage
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.adam(learning_rate=1.)
>>> state = opt.init(params)
>>> count = optax.tree_utils.tree_get(state, 'count')
>>> print(count)
0
Usage with a filtering operation
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.inject_hyperparams(optax.sgd)(
...     learning_rate=lambda count: 1/(count+1)
... )
>>> state = opt.init(params)
>>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
>>> lr = optax.tree_utils.tree_get(
...     state, 'learning_rate', filtering=filtering
... )
>>> print(lr)
1.0
Extracting a named tuple by its name
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...     optax.add_noise(1.0, 0.9, key=0),
...     optax.scale_by_adam()
... )
>>> state = opt.init(params)
>>> noise_state = optax.tree_utils.tree_get(state, 'AddNoiseState')
>>> print(noise_state)
AddNoiseState(count=Array(0, dtype=int32), rng_key=Array((), dtype=key<fry>) overlaying:
[0 0])
Differentiating between two values by the name of their named tuples.
>>> import jax.numpy as jnp
>>> import optax
>>> params = jnp.array([1., 2., 3.])
>>> opt = optax.chain(
...     optax.add_noise(1.0, 0.9, key=0),
...     optax.scale_by_adam()
... )
Added in version 0.2.2.