optax.tree_utils.tree_scale

optax.tree_utils.tree_scale(scalar: jax.typing.ArrayLike, tree: Any) -> Any

Multiply a tree by a scalar.

In infix notation, the function computes out = scalar * tree; that is, every leaf of tree is multiplied by scalar.
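The leaf-wise behaviour can be sketched with `jax.tree_util.tree_map`. This is a minimal illustration, not optax's source: `scale_tree_sketch` is a hypothetical helper name, and the `params` dict is made-up example data.

```python
import jax
import jax.numpy as jnp


def scale_tree_sketch(scalar, tree):
    # Hypothetical helper (not part of optax): multiply every leaf of
    # `tree` by `scalar`, i.e. out = scalar * tree applied leaf by leaf.
    return jax.tree_util.tree_map(lambda leaf: scalar * leaf, tree)


# Illustrative pytree: a dict holding a vector leaf and a scalar leaf.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
scaled = scale_tree_sketch(0.5, params)
# scaled["w"] -> [0.5, 1.0], scaled["b"] -> 1.5
```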

Parameters:
  • scalar – scalar value.

  • tree – pytree.

Returns:

a pytree with the same structure as tree.