optax.losses.cosine_distance

optax.losses.cosine_distance(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, epsilon: jax.typing.ArrayLike = 0.0, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) → Array

Computes the cosine distance between targets and predictions.

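The cosine distance measures the dissimilarity of two vectors as the complement of their cosine similarity: 1 - cos(θ), where θ is the angle between the vectors. It ranges from 0 (same direction) through 1 (orthogonal) to 2 (opposite directions).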
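The formula can be sketched in plain JAX for vectors along the last axis. This assumes `epsilon` acts as a lower bound on each vector's norm, as its parameter description suggests. The helper `cosine_distance_sketch` is hypothetical and is not Optax's implementation. It ignores `axis` and `where`.

```python
import jax.numpy as jnp


def cosine_distance_sketch(predictions, targets, epsilon=0.0):
    """Hypothetical helper: 1 - cos(theta) between vectors on the last axis."""
    # Norm of each vector, clipped from below by epsilon (assumed semantics).
    pred_norm = jnp.maximum(jnp.linalg.norm(predictions, axis=-1), epsilon)
    targ_norm = jnp.maximum(jnp.linalg.norm(targets, axis=-1), epsilon)
    # Cosine similarity: dot product divided by the product of the norms.
    cos_sim = jnp.sum(predictions * targets, axis=-1) / (pred_norm * targ_norm)
    # Cosine distance is the complement of the similarity.
    return 1.0 - cos_sim


a = jnp.array([1.0, 0.0])
print(cosine_distance_sketch(a, jnp.array([2.0, 0.0])))   # same direction: distance 0
print(cosine_distance_sketch(a, jnp.array([0.0, 3.0])))   # orthogonal: distance 1
print(cosine_distance_sketch(a, jnp.array([-1.0, 0.0])))  # opposite: distance 2
# With the default epsilon=0.0 a zero vector gives 0/0 = nan;
# a small positive epsilon keeps the denominator nonzero (here the result is 1).
print(cosine_distance_sketch(a, jnp.zeros(2), epsilon=1e-8))
```

Scaling either vector does not change the result, because only the angle between the vectors matters.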

Parameters:
  • predictions – The predicted vectors, with shape […, dim].

  • targets – Ground truth target vectors, with shape […, dim].

  • epsilon – Minimum norm for the terms in the denominator of the cosine similarity; a positive value guards against division by zero.

  • axis – Axis or axes along which the vectors lie and along which the distance is computed.

  • where – Mask of elements to include in the computation.

Returns:

Cosine distances, with shape […].
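A second hypothetical sketch in plain JAX illustrates the `axis` and `where` parameters and the reduced return shape. It makes two assumptions. First, `where` is a boolean mask broadcastable to the inputs. Second, the mask applies to every reduction (both norms and the dot product), in the style of `jnp.sum(..., where=...)`. Both are assumptions about the semantics. The helper `masked_cosine_distance_sketch` is not Optax's code.

```python
import jax.numpy as jnp


def masked_cosine_distance_sketch(predictions, targets, *, epsilon=0.0, axis=-1, where=None):
    """Hypothetical helper mirroring the documented epsilon/axis/where parameters."""

    def reduce(x):
        # Sum over `axis`, skipping masked-out elements; keepdims lets norms broadcast back.
        return jnp.sum(x, axis=axis, where=where, keepdims=True)

    pred_norm = jnp.maximum(jnp.sqrt(reduce(predictions ** 2)), epsilon)
    targ_norm = jnp.maximum(jnp.sqrt(reduce(targets ** 2)), epsilon)
    # Dot product of the unit vectors over the same (masked) axis; this drops the axis.
    cos_sim = jnp.sum((predictions / pred_norm) * (targets / targ_norm), axis=axis, where=where)
    return 1.0 - cos_sim


preds = jnp.array([[1.0, 0.0, 5.0],
                   [0.0, 1.0, 0.0]])
targs = jnp.array([[1.0, 0.0, -5.0],
                   [1.0, 0.0, 0.0]])
mask = jnp.array([True, True, False])  # ignore the third feature in every row

d_full = masked_cosine_distance_sketch(preds, targs)              # about [1.923, 1.0]
d = masked_cosine_distance_sketch(preds, targs, where=mask)       # [0.0, 1.0]
print(d.shape)  # the reduced axis is removed: shape (2,)
```

Masking the third feature makes the first row compare [1, 0] with [1, 0], so its distance drops from 50/26 ≈ 1.923 to 0.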

Changed in version 0.2.4: Added axis and where arguments.