optax.tree_utils.tree_scale

optax.tree_utils.tree_scale(scalar: jax.typing.ArrayLike, tree: Any) → Any

Multiply a tree by a scalar.

In infix notation, the function performs out = scalar * tree, so every leaf of tree is multiplied by scalar.

Parameters:
  scalar: scalar value to multiply by.
  tree: pytree whose leaves are scaled.

Returns:
  a pytree with the same structure as tree.
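The documented semantics (out = scalar * tree) amount to a leaf-wise multiplication that leaves the tree structure untouched. The following is a minimal sketch of that behavior, assuming the scaling is applied to each leaf independently. The helper `tree_scale_sketch` is hypothetical and relies only on `jax.tree_util.tree_map`. In real code you would call `optax.tree_utils.tree_scale(scalar, tree)` directly.

```python
import jax
import jax.numpy as jnp


def tree_scale_sketch(scalar, tree):
    """Hypothetical stand-in for optax.tree_utils.tree_scale.

    Multiplies every leaf of `tree` by `scalar` via tree_map, so the
    output pytree has the same structure as the input.
    """
    return jax.tree_util.tree_map(lambda leaf: scalar * leaf, tree)


# A small parameter-like pytree: a dict holding two arrays.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
scaled = tree_scale_sketch(0.5, params)

# Same dict structure as `params`; each leaf is halved.
print(scaled["w"], scaled["b"])
```

Because only the leaves change, the result can be passed anywhere the original tree was accepted, for example when scaling a gradient pytree before combining it with parameters.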