optax.tree_utils.tree_cast


optax.tree_utils.tree_cast(tree: optax.ArrayTree, dtype: str | type[Any] | dtype | SupportsDType | None) → optax.ArrayTree

Cast every leaf of tree to the given dtype. If dtype is None, skip the cast and return tree unchanged.

Parameters:
  • tree – the tree to cast.

  • dtype – the dtype to cast to, or None to skip casting.

Returns:

The tree with every leaf cast to dtype, or the input tree unchanged if dtype is None.
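The behavior described above can be sketched in plain JAX. This is a minimal, hypothetical re-implementation (tree_cast_sketch is not part of optax), assuming every leaf is an array that supports .astype; it is meant to illustrate the semantics, not to reproduce the library source.

```python
import jax
import jax.numpy as jnp


def tree_cast_sketch(tree, dtype):
  """Hypothetical sketch of tree_cast; not the optax source."""
  # dtype=None means "skip": hand the tree back untouched.
  if dtype is None:
    return tree
  # Otherwise call .astype(dtype) on every leaf; the tree structure is kept.
  return jax.tree_util.tree_map(lambda leaf: leaf.astype(dtype), tree)


params = {'w': jnp.ones((2, 3), dtype=jnp.float32),
          'b': jnp.zeros((3,), dtype=jnp.float32)}
half = tree_cast_sketch(params, jnp.bfloat16)  # all leaves become bfloat16
same = tree_cast_sketch(params, None)          # the original tree is returned
```

Because only the leaves are mapped, shapes and the nesting of the tree are preserved; only the dtype of each leaf changes.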

Examples

>>> import jax.numpy as jnp
>>> import optax
>>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
...         'c': jnp.array(2.0, dtype=jnp.float32)}
>>> optax.tree_utils.tree_cast(tree, dtype=jnp.bfloat16)
{'a': {'b': Array(1, dtype=bfloat16)}, 'c': Array(2, dtype=bfloat16)}