optax.tree_utils.tree_vdot

optax.tree_utils.tree_vdot(tree_x: Any, tree_y: Any) → jax.typing.ArrayLike

Compute the inner product between two pytrees.

Parameters:
  • tree_x – first pytree to use.

  • tree_y – second pytree to use.

Returns:

inner product between tree_x and tree_y, a scalar value.
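As a rough illustration of how two pytrees reduce to one scalar, the sketch below pairs up matching leaves, takes `jnp.vdot` of each pair, and sums the per-leaf results. `tree_vdot_sketch` is a hypothetical helper, not part of optax, and passing `jax.lax.Precision.HIGHEST` is an assumption meant to mirror the precision note; treat it as an explanation of the idea, not as the library's source.

```python
import operator

import jax
import jax.numpy as jnp


def tree_vdot_sketch(tree_x, tree_y):
  """Hypothetical pytree inner product (illustration only, not optax code)."""
  # Inner product of each pair of matching leaves. jnp.vdot flattens both
  # inputs (and conjugates the first one for complex dtypes). The HIGHEST
  # precision setting is an assumption mirroring the note on precision.
  leaf_vdots = jax.tree_util.tree_map(
      lambda x, y: jnp.vdot(x, y, precision=jax.lax.Precision.HIGHEST),
      tree_x,
      tree_y,
  )
  # Sum the per-leaf scalars into a single scalar, starting from 0.
  return jax.tree_util.tree_reduce(operator.add, leaf_vdots, 0)


x = {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])}
y = {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])}
print(tree_vdot_sketch(x, y))  # prints 0
```

Both pytrees must share the same structure; `jax.tree_util.tree_map` raises an error when they do not.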

Examples

>>> import jax.numpy as jnp
>>> import optax
>>> optax.tree_utils.tree_vdot(
...   {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])},
...   {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])},
... )
Array(0, dtype=int32)

Note

The per-leaf dot products are computed at the highest available precision (jax.lax.Precision.HIGHEST) to avoid numerical issues.