optax.tree_utils.tree_sum



optax.tree_utils.tree_sum(tree: Any, associative_reduction: bool = False) → jax.typing.ArrayLike

Compute the sum of all the elements in a pytree.
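To illustrate what "all the elements" covers, the sketch below uses a hypothetical helper, `tree_sum_sketch`, which is not part of Optax. It reduces each leaf to a scalar with `jnp.sum` and then adds the per-leaf results. It assumes this matches the default (`associative_reduction=False`) behaviour up to summation order; the real implementation may differ in detail.

```python
import operator

import jax
import jax.numpy as jnp


def tree_sum_sketch(tree):
    """Hypothetical re-implementation for illustration; not Optax's code."""
    # Collapse every leaf (array or scalar) to the sum of its entries.
    leaf_sums = jax.tree_util.tree_map(jnp.sum, tree)
    # Fold the per-leaf sums together with +, starting from 0.
    return jax.tree_util.tree_reduce(operator.add, leaf_sums, 0)


params = {"w": jnp.ones((2, 3)), "b": jnp.array([1.0, 2.0])}
total = tree_sum_sketch(params)  # 6.0 from "w" plus 3.0 from "b"
```

With Optax installed, `optax.tree_utils.tree_sum(params)` should return the same value for this tree.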

Parameters:
  • tree – the pytree whose elements are summed.

  • associative_reduction – If True, use reduce_associative, which can reduce compilation time for large pytrees (requires JAX >= 0.6.0). Because this changes the order in which values are added, the result may differ slightly in floating point from the default. Default is False.
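The caveat about floating-point values arises because floating-point addition is not associative. The sketch below uses NumPy float32 scalars and two hypothetical helpers, `sum_left_to_right` and `sum_pairwise`, to add the same four per-leaf sums in two different orders. The pairwise (tree-shaped) order is only one possible associative order, chosen for illustration; it is not a statement of the exact order reduce_associative uses.

```python
import numpy as np


def sum_left_to_right(xs):
    # Sequential fold: ((x0 + x1) + x2) + x3, rounding to float32 at each step.
    total = np.float32(0.0)
    for x in xs:
        total = np.float32(total + x)
    return total


def sum_pairwise(xs):
    # Recursively add the two halves: (x0 + x1) + (x2 + x3).
    if len(xs) == 1:
        return xs[0]
    mid = len(xs) // 2
    return np.float32(sum_pairwise(xs[:mid]) + sum_pairwise(xs[mid:]))


# Hypothetical per-leaf sums chosen to expose rounding in float32.
leaf_sums = [np.float32(x) for x in (1e8, 1.0, -1e8, 1.0)]
sequential = sum_left_to_right(leaf_sums)
pairwise = sum_pairwise(leaf_sums)
```

Representable float32 values near 1e8 are 8 apart, so 1e8 + 1 rounds back to 1e8 and -1e8 + 1 rounds back to -1e8. The sequential order therefore yields 1.0, while the pairwise order yields 0.0.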

Returns:

A scalar: the sum of all elements across all leaves of tree.