optax.tree_utils.tree_min

optax.tree_utils.tree_min(tree: Any) → jax.typing.ArrayLike

Compute the min of all the elements in a pytree.

Parameters:

tree – pytree.

Returns:

a scalar value, the minimum over all elements of the pytree.
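
Example:

The following is a minimal sketch of the computation described above, written with plain JAX so it does not depend on a particular optax version. The helper name `tree_min_reference` and the `params` pytree are hypothetical and used for illustration only. This is not the library's implementation, and it assumes the pytree has at least one non-empty leaf.

```python
import jax
import jax.numpy as jnp


def tree_min_reference(tree):
    """Hypothetical reference helper: minimum over every element of every leaf.

    Assumes the pytree contains at least one non-empty leaf.
    """
    # Flatten the pytree into a list of its array leaves.
    leaves = jax.tree_util.tree_leaves(tree)
    # Reduce each leaf to its own scalar minimum.
    leaf_mins = [jnp.min(jnp.asarray(leaf)) for leaf in leaves]
    # Reduce the per-leaf minima to a single scalar.
    return jnp.min(jnp.stack(leaf_mins))


# Illustrative pytree: a dict with two array leaves.
params = {
    "w": jnp.array([[3.0, -1.5], [0.25, 7.0]]),
    "b": jnp.array([2.0, 4.0]),
}

result = tree_min_reference(params)
print(result)  # -1.5
```

With optax installed, `optax.tree_utils.tree_min(params)` is expected to give the same scalar for this input. Handling of edge cases such as an empty pytree may differ from this sketch.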