optax.tree_utils.tree_add_scale



optax.tree_utils.tree_add_scale(tree_x: Any, scalar: jax.typing.ArrayLike, tree_y: Any) -> Any

Add two trees, where the second tree is scaled by a scalar.

In infix notation, the function computes out = tree_x + scalar * tree_y, leaf by leaf.

Parameters:
  • tree_x – first pytree.

  • scalar – scalar value that multiplies each leaf of tree_y.

  • tree_y – second pytree, with the same structure as tree_x.

Returns:

a pytree with the same structure as tree_x and tree_y.
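To make the structural guarantee concrete, here is a minimal reference sketch of the same semantics written with `jax.tree_util.tree_map`. It is an assumption-based illustration, not the optax source: `add_scale_reference` is a hypothetical name, and the real function may differ in details such as dtype handling. It shows that the output keeps the nested tuple/list/dict layout of the inputs.

```python
import jax
import jax.numpy as jnp

def add_scale_reference(tree_x, scalar, tree_y):
    """Hypothetical sketch: out = tree_x + scalar * tree_y, leaf by leaf."""
    # tree_map pairs each leaf of tree_x with the leaf at the same position
    # in tree_y, so both trees are expected to share one structure.
    return jax.tree_util.tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)

# A nested pytree: a tuple holding a list of arrays and a dict.
tree_x = ([jnp.ones(3), jnp.zeros(2)], {"b": jnp.array(1.0)})
tree_y = ([jnp.full(3, 2.0), jnp.ones(2)], {"b": jnp.array(3.0)})

out = add_scale_reference(tree_x, -0.5, tree_y)
print(jax.tree_util.tree_structure(out) == jax.tree_util.tree_structure(tree_x))  # True
```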