optax.tree_utils.tree_add

optax.tree_utils.tree_add(tree_x: Any, tree_y: Any, *other_trees: Any) → Any

Add two (or more) pytrees.

Parameters:
  • tree_x – first pytree.

  • tree_y – second pytree.

  • *other_trees – optional other trees to add.

Returns:

the sum of the two (or more) pytrees, computed leaf by leaf.

Changed in version 0.2.1: Added optional *other_trees argument.