optax.tree_utils.tree_vdot
- optax.tree_utils.tree_vdot(tree_x: Any, tree_y: Any) → jax.typing.ArrayLike
Compute the inner product between two pytrees.
- Parameters:
tree_x – first pytree to use.
tree_y – second pytree to use.
- Returns:
inner product between tree_x and tree_y, a scalar value.
Examples
>>> import jax.numpy as jnp
>>> import optax
>>> optax.tree_utils.tree_vdot(
...   {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])},
...   {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])},
... )
Array(0, dtype=int32)
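For real-valued leaves, the result can be read as the sum, over every leaf, of the ordinary dot product of the flattened arrays. Writing the leaves of the two pytrees as x^(ℓ) and y^(ℓ), and working through the example above (this assumes real inputs; conjugation for complex leaves is not covered here):

```latex
\langle x, y \rangle = \sum_{\ell} \sum_{i} x^{(\ell)}_i \, y^{(\ell)}_i
  = \underbrace{(1)(-1) + (2)(-1)}_{\texttt{'a'}:\ -3}
  + \underbrace{(1)(1) + (2)(1)}_{\texttt{'b'}:\ 3}
  = 0
```

The two leaves cancel exactly, which is why the example returns 0.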
Note
We upcast the values to the highest precision to avoid numerical issues.
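A minimal sketch of the same computation, assuming "highest precision" means passing `precision=jax.lax.Precision.HIGHEST` to `jnp.vdot` for each leaf pair. The helper names `_vdot_highest` and `tree_vdot_sketch` are hypothetical and this is not presented as optax's actual implementation; it only illustrates the sum-of-leaf-inner-products idea with an explicit precision setting.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical helper. Assumption: "highest precision" is modelled by
# passing precision=jax.lax.Precision.HIGHEST to jnp.vdot.
_vdot_highest = functools.partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)


def tree_vdot_sketch(tree_x, tree_y):
    """Hypothetical sketch: sum of per-leaf inner products of two pytrees."""
    # jnp.vdot flattens each pair of leaves and returns one scalar per leaf.
    leaf_vdots = jax.tree_util.tree_map(
        lambda a, b: _vdot_highest(jnp.asarray(a), jnp.asarray(b)), tree_x, tree_y
    )
    # Add the per-leaf scalars; the start value 0 covers empty pytrees.
    return sum(jax.tree_util.tree_leaves(leaf_vdots), 0)


x = {'a': jnp.array([1, 2]), 'b': jnp.array([1, 2])}
y = {'a': jnp.array([-1, -1]), 'b': jnp.array([1, 1])}
print(tree_vdot_sketch(x, y))
```

On the inputs from the example this sketch also yields 0. On CPU the `precision` argument typically has no visible effect; it matters mainly on accelerators whose default matmul precision is reduced.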