optax.tree_utils.tree_cast
- optax.tree_utils.tree_cast(tree: optax.ArrayTree, dtype: str | type[Any] | dtype | SupportsDType | None) → optax.ArrayTree
Cast each leaf of tree to the given dtype; if dtype is None, the cast is skipped.
- Parameters:
tree – the tree to cast.
dtype – the dtype to cast to, or None to skip.
- Returns:
the tree, with leaves cast to dtype.
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_cast(tree, dtype=jnp.bfloat16)
{'a': {'b': Array(1, dtype=bfloat16)}, 'c': Array(2, dtype=bfloat16)}