optax.tree_utils.tree_max

optax.tree_utils.tree_max(tree: Any) → jax.typing.ArrayLike

Compute the max of all the elements in a pytree.

Parameters: tree – pytree.

Returns: a scalar value.
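The reduction described above can be illustrated with a small JAX sketch. `tree_max_sketch` is a hypothetical helper written for this example, not optax's own implementation. It assumes a non-empty pytree whose leaves are numeric arrays or scalars: it first reduces each leaf to its own maximum, then combines those per-leaf maxima into one scalar.

```python
import functools

import jax
import jax.numpy as jnp


def tree_max_sketch(tree):
    """Hypothetical illustration: max over every element of every leaf.

    Assumes `tree` has at least one leaf; an empty tree makes
    functools.reduce raise a TypeError here.
    """
    leaves = jax.tree_util.tree_leaves(tree)
    # Reduce each leaf (array or scalar) to its own maximum.
    per_leaf_max = [jnp.max(leaf) for leaf in leaves]
    # Combine the per-leaf maxima pairwise into a single scalar.
    return functools.reduce(jnp.maximum, per_leaf_max)


params = {
    "w": jnp.array([[1.0, -2.0], [3.0, 0.5]]),
    "b": jnp.array([4.0, -1.0]),
}
result = tree_max_sketch(params)
print(result)  # 4.0, the largest element across both leaves
```

With optax installed, `optax.tree_utils.tree_max(params)` is expected to return the same scalar for this tree.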