optax.tree_utils.tree_max

optax.tree_utils.tree_max(tree: Any) → jax.typing.ArrayLike

Compute the maximum over all elements of all leaves in a pytree.

Parameters:

tree – pytree.

Returns:

a scalar value: the largest element found across all leaves of tree.
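For illustration, here is a minimal JAX sketch of the computation described above: reduce each leaf to its own maximum, then take the maximum of those per-leaf results. `tree_max_sketch` is a hypothetical helper, not part of optax, and it assumes every leaf is an array or a Python scalar. Unlike the library function, this sketch makes no attempt to handle an empty pytree (jnp.stack raises an error on an empty list).

```python
import jax
import jax.numpy as jnp


def tree_max_sketch(tree):
  """Hypothetical re-implementation: max over every element of every leaf."""
  # Reduce each leaf (array or Python scalar) to a 0-d array holding its max.
  leaf_maxes = [jnp.max(leaf) for leaf in jax.tree_util.tree_leaves(tree)]
  # Combine the per-leaf maxima into a single scalar.
  return jnp.max(jnp.stack(leaf_maxes))


params = {
    "w": jnp.array([[1.0, -2.0], [7.5, 0.0]]),
    "b": jnp.array([3.0, 4.0]),
    "scale": 2.0,
}
print(tree_max_sketch(params))  # 7.5
```

Reducing each leaf first keeps the intermediate list small (one scalar per leaf) instead of flattening and concatenating every element of the tree.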