optax.losses.cosine_similarity


optax.losses.cosine_similarity(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, epsilon: jax.typing.ArrayLike = 0.0, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) → Array

Computes the cosine similarity between targets and predictions.

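The cosine similarity is a measure of similarity between vectors, defined as the cosine of the angle between them. Equivalently, it is the inner product of the two vectors after each has been normalized to unit norm. Its values lie in [-1, 1]: 1 when the vectors point in the same direction, 0 when they are orthogonal, and -1 when they point in opposite directions.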
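In symbols, for a prediction vector p and a target vector t (names chosen here only for illustration), the definition reads as follows, ignoring the epsilon clipping described under Parameters:

```latex
\cos\theta
  = \frac{\langle \mathbf{p}, \mathbf{t} \rangle}{\lVert \mathbf{p} \rVert \, \lVert \mathbf{t} \rVert}
  = \left\langle \frac{\mathbf{p}}{\lVert \mathbf{p} \rVert}, \frac{\mathbf{t}}{\lVert \mathbf{t} \rVert} \right\rangle
% Worked example: p = (1, 0), t = (1, 1)
% <p, t> = 1, ||p|| = 1, ||t|| = sqrt(2), so cos(theta) = 1 / sqrt(2) ≈ 0.707
```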

Parameters:
  • predictions – The predicted vectors, with shape […, dim].

  • targets – Ground truth target vectors, with shape […, dim].

  • epsilon – Minimum value used to clip the squared norms in the denominator, for numerical stability. The squared norms (not the norms) are clipped to be at least epsilon, so the effective minimum norm is sqrt(epsilon).

  • axis – Axis or axes treated as the vector dimension, i.e. along which the inner product and the norms are computed.

  • where – Boolean mask selecting the elements to include in the computation.
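The sketch below illustrates how the parameters above could interact, written in jax.numpy. It assumes that the inner product and both squared norms are reduced over the same axis with the same where mask, and that epsilon clips the squared norms as described. The name cosine_similarity_sketch is hypothetical; this is an illustration, not the optax source.

```python
import jax.numpy as jnp


def cosine_similarity_sketch(predictions, targets, *, epsilon=0.0, axis=-1, where=None):
    """Hypothetical re-derivation of cosine similarity (not the optax source)."""
    # Squared norms along the vector axis; `where` drops masked-out elements from each sum.
    pred_norm2 = jnp.sum(jnp.square(predictions), axis=axis, where=where)
    targ_norm2 = jnp.sum(jnp.square(targets), axis=axis, where=where)
    # Clip the *squared* norms at epsilon, so the smallest possible norm is sqrt(epsilon).
    pred_norm = jnp.sqrt(jnp.maximum(pred_norm2, epsilon))
    targ_norm = jnp.sqrt(jnp.maximum(targ_norm2, epsilon))
    # Inner product over the same axis and mask.
    dot = jnp.sum(predictions * targets, axis=axis, where=where)
    return dot / (pred_norm * targ_norm)


# A zero-norm prediction shows why epsilon exists.
zero = jnp.zeros(3)
unit = jnp.array([1.0, 0.0, 0.0])
print(cosine_similarity_sketch(zero, unit))                 # NaN: numerator and denominator are both 0
print(cosine_similarity_sketch(zero, unit, epsilon=1e-12))  # 0.0: denominator clipped to 1e-6
```

With the default epsilon of 0.0, an all-zero vector yields NaN. A small positive epsilon keeps the result finite at the cost of slightly shrinking the similarity of vectors whose norm is below sqrt(epsilon).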

Returns:

Cosine similarity values, with shape […], i.e. the input shape with the reduced axis or axes removed.

References

Cosine similarity, Wikipedia.

Changed in version 0.2.4: Added axis and where arguments.