optax.losses.cosine_distance


optax.losses.cosine_distance(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, epsilon: jax.typing.ArrayLike = 0.0, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) -> Array

Computes the cosine distance between targets and predictions.

The cosine distance implemented here measures the dissimilarity of two vectors as one minus their cosine similarity: 1 - cos(theta), where theta is the angle between the two vectors.
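
To make 1 - cos(theta) concrete, here is a minimal sketch in JAX for a single pair of vectors using only the default settings. The helper name `cosine_distance_1d` is hypothetical and is not part of optax. Rotating a unit target away from a fixed prediction shows the distance growing from 0 (same direction) through 1 (orthogonal) to 2 (opposite directions).

```python
import jax.numpy as jnp


def cosine_distance_1d(p, t):
    """Hypothetical helper: 1 - cos(theta) for two 1-D vectors (no epsilon, no mask)."""
    cos_theta = jnp.dot(p, t) / (jnp.linalg.norm(p) * jnp.linalg.norm(t))
    return 1.0 - cos_theta


p = jnp.array([1.0, 0.0])
for degrees in (0.0, 60.0, 90.0, 180.0):
    theta = jnp.deg2rad(degrees)
    # Unit target rotated by `theta` away from the prediction.
    t = jnp.array([jnp.cos(theta), jnp.sin(theta)])
    print(degrees, float(cosine_distance_1d(p, t)))
# Up to float32 rounding: 0.0, 0.5, 1.0, 2.0
```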

Parameters:
  • predictions – The predicted vectors, with shape […, dim].

  • targets – Ground truth target vectors, with shape […, dim].

  • epsilon – Minimum value used to clip the squared norms in the denominator, for numerical stability. The squared norms (not the norms) are clipped to be at least epsilon, so the effective minimum norm is sqrt(epsilon).

  • axis – Axis or axes along which to compute the distance. These axes are reduced and do not appear in the output.

  • where – Optional boolean mask selecting the elements to include in the computation. If None, all elements are included.
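
The sketch below illustrates how epsilon, axis and where interact, assuming masked-out elements are simply dropped from the dot product and from both squared-norm sums. The function name `cosine_distance_sketch` is hypothetical; it mirrors the documented semantics and is not the optax source.

```python
import jax.numpy as jnp


def cosine_distance_sketch(predictions, targets, *, epsilon=0.0, axis=-1, where=None):
    """Hypothetical sketch of the documented semantics; not optax's actual source."""
    a = jnp.asarray(predictions, dtype=jnp.float32)
    b = jnp.asarray(targets, dtype=jnp.float32)
    # Masked reductions (assumption): elements where `where` is False contribute nothing.
    dot = jnp.sum(a * b, axis=axis, where=where)
    a_sq = jnp.sum(a * a, axis=axis, where=where)
    b_sq = jnp.sum(b * b, axis=axis, where=where)
    # Clip each squared norm (not the norm) from below by epsilon,
    # so the effective minimum norm is sqrt(epsilon).
    denom = jnp.sqrt(jnp.maximum(a_sq, epsilon) * jnp.maximum(b_sq, epsilon))
    return 1.0 - dot / denom


zero = jnp.zeros(3)
ones = jnp.ones(3)

# epsilon=0.0: in this sketch a zero vector gives 0 / 0, i.e. nan.
print(cosine_distance_sketch(zero, ones))  # nan
# epsilon > 0: the denominator stays positive, the dot product is 0, so the distance is 1.
print(cosine_distance_sketch(zero, ones, epsilon=1e-8))  # 1.0

# where: ignore the last component, so the vectors compare as [1, 0] vs [1, 0].
mask = jnp.array([True, True, False])
masked = cosine_distance_sketch(
    jnp.array([1.0, 0.0, 5.0]), jnp.array([1.0, 0.0, -7.0]), where=mask
)
print(masked)  # 0.0

# axis: reduce over axis 0 of a (2, 3) array, leaving shape (3,).
x = jnp.arange(6.0).reshape(2, 3)
print(cosine_distance_sketch(x, x, axis=0).shape)  # (3,)
```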

Returns:

Cosine distances, with shape […].


Changed in version 0.2.4: Added axis and where arguments.