optax.tree_utils.tree_full_like


optax.tree_utils.tree_full_like(tree: Any, fill_value: jax.typing.ArrayLike, dtype: str | type[Any] | dtype | SupportsDType | None = None) -> Any

Creates a tree with the same structure as the input tree, in which every tensor (leaf) is replaced by a tensor of the same shape filled with fill_value.

Parameters:
  • tree – pytree.

  • fill_value – the fill value for all tensors in the tree.

  • dtype – optional dtype to use for the tensors in the tree. If None, each tensor keeps the dtype of the corresponding input leaf.

Returns:

a tree with the same structure as the input tree.