optax.tree_utils.tree_sub

optax.tree_utils.tree_sub(tree_x: Any, tree_y: Any) → Any

Subtract two pytrees leaf-wise, computing tree_x - tree_y.

Parameters:
  • tree_x – first pytree, from which tree_y is subtracted.

  • tree_y – second pytree, with the same tree structure as tree_x.

Returns:

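the difference tree_x - tree_y: a pytree with the same structure, where each leaf is the corresponding leaf of tree_x minus the corresponding leaf of tree_y.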