optax.tree_utils.NamedTupleKey


class optax.tree_utils.NamedTupleKey(tuple_name: str, name: str)

Key type identifying a field of a NamedTuple in a pytree.

When a filtering function filtering(path: KeyPath, value: Any) -> bool is passed to optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), or optax.tree_utils.tree_set(), it can check whether an entry of the path is a NamedTupleKey and, if so, whether that entry's tuple_name is the name of the named tuple being searched for.

tuple_name

Name of the named tuple containing the key.

Type:

str

name

Name of the key, i.e. the field name within the named tuple.

Type:

str

See also

jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey, jax.tree_util.GetAttrKey, jax.tree_util.SequenceKey, optax.tree_utils.tree_get_all_with_path(), optax.tree_utils.tree_get(), optax.tree_utils.tree_set()

Added in version 0.2.2.