optax.tree_utils.tree_random_like


optax.tree_utils.tree_random_like(rng_key: base.PRNGKey, target_tree: base.ArrayTree, sampler: Union[Callable[[base.PRNGKey, base.Shape, jax.typing.DTypeLike], jax.typing.ArrayLike], Callable[[base.PRNGKey, base.Shape, jax.typing.DTypeLike, jax.sharding.Sharding], jax.typing.ArrayLike]] = jax.random.normal, dtype: Optional[jax.typing.DTypeLike] = None) -> base.ArrayTree

Create a tree with random entries that has the same structure and leaf shapes as the target tree.
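The behaviour described above can be sketched in plain JAX. This is a hypothetical re-implementation (`random_like_sketch` is an illustrative name, not part of optax). It assumes one independent subkey is split off per leaf, so it shows the idea rather than optax's actual code, and its random values will not match tree_random_like bit for bit.

```python
import jax
import jax.numpy as jnp


def random_like_sketch(rng_key, target_tree, sampler=jax.random.normal, dtype=None):
    """Hypothetical helper, NOT optax's implementation: one subkey per leaf."""
    leaves, treedef = jax.tree_util.tree_flatten(target_tree)
    # Split the key so that every leaf is sampled with its own subkey.
    keys = jax.random.split(rng_key, len(leaves))
    new_leaves = [
        # Use the requested dtype, or fall back to the leaf's own dtype.
        sampler(key, leaf.shape, dtype if dtype is not None else leaf.dtype)
        for key, leaf in zip(keys, leaves)
    ]
    # Rebuild a tree with the same structure as target_tree.
    return jax.tree_util.tree_unflatten(treedef, new_leaves)


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
key = jax.random.PRNGKey(0)

noise = random_like_sketch(key, params)  # Gaussian leaves, float32 like params
ints = random_like_sketch(key, params, jax.random.rademacher, jnp.int32)  # +/-1 leaves
print({name: leaf.shape for name, leaf in noise.items()})
```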

Parameters:
  • rng_key – the key for the random number generator.

  • target_tree – the tree whose structure and leaf shapes the output matches. Leaves must be arrays.

  • sampler – the noise sampling function, by default jax.random.normal.

  • dtype – the desired dtype for the random numbers, passed to sampler. If None, the dtype of each leaf of target_tree is used if possible.

Returns:

a random tree with the same structure as target_tree, whose leaves are drawn using sampler.

Warning

The possible dtypes may be limited by the sampler. For example, jax.random.rademacher only supports integer dtypes and will raise an error if dtype is not an integer type or, when dtype is None, if the dtype of the target tree is not an integer type.

Added in version 0.2.1.