optax.tree_utils.tree_sum
- optax.tree_utils.tree_sum(tree: Any, associative_reduction: bool = False) → jax.typing.ArrayLike
Compute the sum of all the elements in a pytree.
- Parameters:
tree – pytree.
associative_reduction – If True, use reduce_associative for a potential compilation time speedup with large pytrees (requires JAX >= 0.6.0). This changes the order of summation which may result in slightly different floating-point values. Default is False.
- Returns:
a scalar value.