optax.tree_utils.tree_add_scale
- optax.tree_utils.tree_add_scale(tree_x: Any, scalar: jax.typing.ArrayLike, tree_y: Any) → Any
Add two trees, where the second tree is scaled by a scalar.
In infix notation, the function performs
out = tree_x + scalar * tree_y.
- Parameters:
tree_x – first pytree.
scalar – scalar value by which tree_y is multiplied.
tree_y – second pytree, with the same structure as tree_x.
- Returns:
a pytree with the same structure as tree_x and tree_y.
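To make the leaf-wise formula out = tree_x + scalar * tree_y concrete, here is a minimal sketch that applies it to every matching pair of leaves with jax.tree_util.tree_map. The helper add_scale_reference is hypothetical and written only for illustration; it is not optax's source code. The params/grads dictionaries and the step size of 0.1 are likewise made-up example values.

```python
import jax
import jax.numpy as jnp


def add_scale_reference(tree_x, scalar, tree_y):
    """Hypothetical illustration of out = tree_x + scalar * tree_y.

    Maps over corresponding leaves of the two pytrees, so tree_y must have
    the same structure as tree_x. Not the optax implementation.
    """
    return jax.tree_util.tree_map(lambda x, y: x + scalar * y, tree_x, tree_y)


# Example pytrees (assumed values): parameters and their gradients.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, -0.2]), "b": jnp.array(1.0)}

# One plain gradient-descent step, params - 0.1 * grads, written as
# params + (-0.1) * grads.
new_params = add_scale_reference(params, -0.1, grads)
print(new_params)
```

Given the signature documented here, the equivalent library call would be optax.tree_utils.tree_add_scale(params, -0.1, grads). Passing a negative scalar is how a subtraction such as a gradient-descent update is expressed with this function.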