optax.tree_utils.tree_dtype
- optax.tree_utils.tree_dtype(tree: optax.ArrayTree, mixed_dtype_handler: str | None = None) → str | type[Any] | dtype | SupportsDType
Fetch the dtype of a tree.
If the tree is empty, returns the default dtype of JAX arrays.
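To make "the default dtype of JAX arrays" concrete, the sketch below flattens a tree that has no array leaves with jax.tree_util.tree_leaves and builds an array from the resulting empty list. It assumes JAX's default 32-bit mode (jax_enable_x64 off), where the result is float32. It illustrates the documented behavior and is not necessarily how optax computes it.

```python
import jax
import jax.numpy as jnp

# A tree whose containers hold no array leaves.
empty_tree = {"a": [], "b": {}}
leaves = jax.tree_util.tree_leaves(empty_tree)  # an empty list

# Default dtype of an array built from an empty list:
# float32 in JAX's default 32-bit mode, float64 if jax_enable_x64 is set.
default_dtype = jnp.asarray(leaves).dtype
print(default_dtype)
```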
- Parameters:
tree – the tree to fetch the dtype of.
mixed_dtype_handler – how to handle mixed dtypes in the tree.
- If mixed_dtype_handler=None, returns the dtype shared by all leaves of the tree; if the leaves have different dtypes, raises an error.
- If mixed_dtype_handler='promote', promotes the dtypes of the leaves of the tree to a common dtype using jax.numpy.promote_types().
- If mixed_dtype_handler='highest' or mixed_dtype_handler='lowest', returns the highest or lowest dtype among the leaves of the tree. We consider a partial ordering of dtypes in which dtype1 <= dtype2 if dtype1 is promoted to dtype2, that is, if jax.numpy.promote_types(dtype1, dtype2) == dtype2. Since some dtypes cannot be promoted to one another, this is not a total ordering, and the 'highest' and 'lowest' options may not be applicable. These options raise an error if the dtypes of the leaves of the tree cannot be promoted to one another.
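The partial ordering used by the 'highest' and 'lowest' options can be sketched directly with jax.numpy.promote_types. The helpers dtype_leq and extreme_dtype below are hypothetical names for illustration, not part of optax. This sketch checks every pair of distinct dtypes for comparability before picking the extreme element, which is one plausible reading of the documented behavior rather than optax's actual implementation.

```python
import itertools

import jax.numpy as jnp
import numpy as np


def dtype_leq(d1, d2) -> bool:
    """Hypothetical helper: d1 <= d2 iff promoting d1 with d2 yields d2."""
    return jnp.promote_types(d1, d2) == np.dtype(d2)


def extreme_dtype(dtypes, highest: bool = True) -> np.dtype:
    """Hypothetical helper returning the highest (or lowest) dtype in `dtypes`.

    Raises ValueError if any two dtypes are incomparable under dtype_leq.
    """
    unique = list(dict.fromkeys(np.dtype(d) for d in dtypes))
    for a, b in itertools.combinations(unique, 2):
        if not (dtype_leq(a, b) or dtype_leq(b, a)):
            raise ValueError(f"{a} and {b} cannot be promoted to one another")
    # Every pair is comparable, so the dtypes form a chain and the
    # maximum (or minimum) element is the one that bounds all others.
    for candidate in unique:
        if highest and all(dtype_leq(d, candidate) for d in unique):
            return candidate
        if not highest and all(dtype_leq(candidate, d) for d in unique):
            return candidate
    raise ValueError("no highest/lowest dtype found")


print(extreme_dtype(["float16", "float32"]))                 # float32
print(extreme_dtype(["float16", "float32"], highest=False))  # float16
```

JAX promotes int8 and uint8 to int16, so neither dtype_leq direction holds for that pair and extreme_dtype(["int8", "uint8"]) raises ValueError, analogous to the int32/uint32 case in the Examples section.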
- Returns:
the dtype of the tree.
- Raises:
ValueError – If mixed_dtype_handler is set to None and multiple dtypes are found in the tree.
ValueError – If mixed_dtype_handler is set to 'highest' or 'lowest' and the dtypes of some leaves in the tree cannot be promoted to one another.
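The first ValueError, together with the 'promote' handler, can be sketched in a few lines of JAX. The function common_dtype is a hypothetical name for illustration; it is not optax's implementation and may differ from it in edge cases such as how promoted integer dtypes are canonicalized.

```python
import functools

import jax
import jax.numpy as jnp


def common_dtype(tree, promote: bool = False):
    """Hypothetical sketch of the None and 'promote' handlers."""
    dtypes = [jnp.asarray(leaf).dtype for leaf in jax.tree_util.tree_leaves(tree)]
    if not dtypes:
        # Empty tree: fall back to JAX's default dtype.
        return jnp.asarray([]).dtype
    if not promote and len(set(dtypes)) > 1:
        # Mirrors the documented ValueError for mixed_dtype_handler=None.
        raise ValueError(
            f"Multiple dtypes found in tree: {sorted({str(d) for d in dtypes})}"
        )
    # Pairwise promotion; with a single dtype this returns it unchanged.
    return functools.reduce(jnp.promote_types, dtypes)


tree = {"a": jnp.array(1.0, dtype=jnp.float16), "b": jnp.array(2.0, dtype=jnp.float32)}
print(common_dtype(tree, promote=True))  # float32
try:
    common_dtype(tree)
except ValueError as err:
    print(err)
```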
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_dtype(tree)
dtype('float32')
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float16)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_dtype(tree, 'lowest')
dtype('float16')
>>> optax.tree_utils.tree_dtype(tree, 'highest')
dtype('float32')
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.int32)},
...         'c': jnp.array(2.0, dtype=jnp.uint32)}
>>> # optax.tree_utils.tree_dtype(tree, 'highest')
>>> # -> will throw an error because int32 and uint32
>>> # cannot be promoted to one another.
>>> optax.tree_utils.tree_dtype(tree, 'promote')
dtype('int32')
Added in version 0.2.4.