optax.losses.cosine_similarity

optax.losses.cosine_similarity(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, epsilon: jax.typing.ArrayLike = 0.0, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) -> Array

Computes the cosine similarity between targets and predictions.

The cosine similarity is a measure of similarity between vectors, defined as the cosine of the angle between them. Equivalently, it is the inner product of the two vectors after each has been normalized to unit norm.
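Written out for a prediction vector p and a target vector t (symbols introduced here only for illustration), the definition reads as below. The second line is a reading of the epsilon parameter, assuming it clamps each norm from below; it is not taken from the optax source.

```latex
\operatorname{cos\_sim}(\mathbf{p}, \mathbf{t})
  = \cos\theta
  = \frac{\langle \mathbf{p}, \mathbf{t} \rangle}{\lVert \mathbf{p} \rVert_2 \, \lVert \mathbf{t} \rVert_2}
  = \left\langle \frac{\mathbf{p}}{\lVert \mathbf{p} \rVert_2}, \frac{\mathbf{t}}{\lVert \mathbf{t} \rVert_2} \right\rangle,
  \qquad \text{e.g. } \mathbf{p} = (1, 0),\ \mathbf{t} = (1, 1) \;\Rightarrow\; \operatorname{cos\_sim} = \tfrac{1}{\sqrt{2}}

\operatorname{cos\_sim}_{\epsilon}(\mathbf{p}, \mathbf{t})
  = \frac{\langle \mathbf{p}, \mathbf{t} \rangle}{\max(\lVert \mathbf{p} \rVert_2, \epsilon) \, \max(\lVert \mathbf{t} \rVert_2, \epsilon)}
```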

Parameters:
  • predictions – The predicted vectors, with shape […, dim].

  • targets – The ground-truth target vectors, with shape […, dim].

  • epsilon – Minimum norm for each of the two vector norms in the denominator of the cosine similarity.

  • axis – Axis or axes along which to compute the similarity, i.e. the axis or axes holding dim. Defaults to the last axis.

  • where – Optional mask selecting the elements to include in the computation.
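The parameters above can be illustrated with a small NumPy reference sketch. It is not optax's implementation: it assumes that epsilon clamps each norm from below and that where removes masked elements from every reduction (as where does in np.sum), and it uses NumPy instead of JAX so it runs without a JAX install. The helper name cosine_similarity_sketch is hypothetical.

```python
import numpy as np


def cosine_similarity_sketch(predictions, targets, epsilon=0.0, axis=-1, where=None):
    """Hypothetical NumPy sketch of the documented behaviour (not optax code).

    Assumptions: each norm in the denominator is clamped below at `epsilon`,
    and `where` excludes masked elements from every reduction.
    """
    a = np.asarray(predictions, dtype=float)
    b = np.asarray(targets, dtype=float)
    mask = True if where is None else np.asarray(where, dtype=bool)
    # Inner product along `axis`, summing only the elements selected by the mask.
    dot = np.sum(a * b, axis=axis, where=mask)
    # L2 norms over the same elements, each clamped below at epsilon (assumption).
    a_norm = np.maximum(np.sqrt(np.sum(a * a, axis=axis, where=mask)), epsilon)
    b_norm = np.maximum(np.sqrt(np.sum(b * b, axis=axis, where=mask)), epsilon)
    return dot / (a_norm * b_norm)


# Three pairs of 2-D vectors: same direction, opposite direction, zero prediction.
p = np.array([[1.0, 0.0], [1.0, 1.0], [0.0, 0.0]])
t = np.array([[2.0, 0.0], [-1.0, -1.0], [1.0, 0.0]])
print(cosine_similarity_sketch(p, t, epsilon=1e-8))  # approximately [1.0, -1.0, 0.0]

# Masking out the last element leaves only the aligned first components.
masked = cosine_similarity_sketch(
    [1.0, 0.0, 5.0], [1.0, 0.0, -5.0], where=[True, True, False]
)
print(masked)  # 1.0
```

In this sketch, calling it on the third pair with the default epsilon=0.0 divides zero by zero and yields nan (NumPy also emits a RuntimeWarning); a small positive epsilon turns that case into 0.0.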

Returns:

Cosine similarity values, with shape […], i.e. the input shape with the reduced axis or axes removed.
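To make the return shape concrete, the NumPy snippet below performs the same reductions by hand and checks the resulting shapes. It is an illustration, not optax code; the random inputs are arbitrary, and it assumes a tuple axis reduces every listed axis, as in np.sum.

```python
import numpy as np

rng = np.random.default_rng(0)

# A batch of 4 pairs of 3-dimensional vectors: shape [4, 3].
predictions = rng.normal(size=(4, 3))
targets = rng.normal(size=(4, 3))

# Reducing over the default axis=-1 drops the dim axis: [4, 3] -> [4].
dot = np.sum(predictions * targets, axis=-1)
norms = np.linalg.norm(predictions, axis=-1) * np.linalg.norm(targets, axis=-1)
similarity = dot / norms
print(similarity.shape)  # (4,)

# With a tuple of axes, every listed axis is reduced: [2, 3, 5] over (1, 2) -> [2].
x = rng.normal(size=(2, 3, 5))
y = rng.normal(size=(2, 3, 5))
dot_xy = np.sum(x * y, axis=(1, 2))
norm_xy = np.sqrt(np.sum(x * x, axis=(1, 2))) * np.sqrt(np.sum(y * y, axis=(1, 2)))
similarity_xy = dot_xy / norm_xy
print(similarity_xy.shape)  # (2,)
```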

References

Cosine similarity, Wikipedia.

Changed in version 0.2.4: Added axis and where arguments.