optax.tree_utils.tree_split_key_like


optax.tree_utils.tree_split_key_like(rng_key: base.PRNGKey, target_tree: base.ArrayTree) -> base.ArrayTree

Split a PRNG key into one key per leaf of target_tree, arranged in the same tree structure.

Parameters:
  • rng_key – the PRNG key to split.

  • target_tree – the tree whose structure to match.

Returns:

A tree of PRNG keys with the same structure as target_tree.
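A typical use is drawing independent random values for every leaf of a parameter tree, for example to add noise to each parameter. The sketch below assumes the function splits the key once into as many keys as target_tree has leaves and unflattens them into the tree's structure. The helper `split_key_like` is a hypothetical stand-in written with plain JAX, not optax's own source; with a recent optax installed, `optax.tree_utils.tree_split_key_like` should be usable in its place.

```python
import jax
import jax.numpy as jnp


def split_key_like(rng_key, target_tree):
  """Hypothetical stand-in: one fresh key per leaf, in target_tree's structure."""
  treedef = jax.tree_util.tree_structure(target_tree)
  # Split once into as many keys as the tree has leaves.
  keys = jax.random.split(rng_key, treedef.num_leaves)
  # Iterate over the leading axis so each leaf receives its own key.
  return jax.tree_util.tree_unflatten(treedef, list(keys))


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
keys = split_key_like(jax.random.PRNGKey(0), params)

# Pair each parameter with its own key and add small Gaussian noise.
noisy = jax.tree_util.tree_map(
    lambda p, k: p + 0.01 * jax.random.normal(k, p.shape, p.dtype),
    params,
    keys,
)
```

Because every leaf gets a distinct key, the noise drawn for "w" and "b" is statistically independent, which reusing a single key across leaves would not guarantee.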