optax.tree_utils.tree_cast_like
- optax.tree_utils.tree_cast_like(tree: T, other_tree: optax.ArrayTree) → T
Cast each leaf of tree to the dtype of the corresponding leaf in other_tree.
- Parameters:
tree – the tree to cast.
other_tree – reference array tree; each leaf of tree is cast to the dtype of the matching leaf here.
- Returns:
the tree, with each leaf cast to the dtype of the corresponding leaf of other_tree.
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> other_tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...               'c': jnp.array(2.0, dtype=jnp.bfloat16)}
>>> optax.tree_utils.tree_cast_like(tree, other_tree)
{'a': {'b': Array(1., dtype=float32)}, 'c': Array(2, dtype=bfloat16)}
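To make the leaf-wise semantics concrete, here is a minimal sketch that reproduces the behaviour of the example above using only jax.tree_util.tree_map. The helper name cast_like_sketch is hypothetical and this is not optax's actual implementation; it assumes both trees share the same structure, so that tree_map can pair each leaf of tree with its reference leaf.

```python
import jax
import jax.numpy as jnp


def cast_like_sketch(tree, other_tree):
    """Hypothetical helper (not part of optax): cast each leaf of `tree`
    to the dtype of the matching leaf in `other_tree`.

    Assumes both trees have identical structure; tree_map raises an
    error otherwise.
    """
    return jax.tree_util.tree_map(
        # Pair leaves positionally and cast x to the reference leaf's dtype.
        lambda x, ref: jnp.asarray(x).astype(jnp.asarray(ref).dtype),
        tree,
        other_tree,
    )


tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
        'c': jnp.array(2.0, dtype=jnp.float32)}
other_tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
              'c': jnp.array(2.0, dtype=jnp.bfloat16)}

result = cast_like_sketch(tree, other_tree)
# Inspect only the dtypes of the resulting leaves.
print(jax.tree_util.tree_map(lambda x: x.dtype, result))
```

Only the dtype of each reference leaf is used; its value and shape are ignored, which is why other_tree can hold any arrays whose dtypes you want tree to match.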