optax.losses.cosine_distance
- optax.losses.cosine_distance(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, epsilon: jax.typing.ArrayLike = 0.0, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) -> Array
Computes the cosine distance between targets and predictions.
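The cosine distance implemented here measures the dissimilarity of two vectors as one minus their cosine similarity: 1 - cos(theta), where theta is the angle between the vectors. It ranges from 0 (same direction) to 2 (opposite directions).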
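As a rough illustration of the definition, the following sketch computes 1 - cos(theta) directly from the dot product and the norms along the last axis. The function name `cosine_distance_reference` is hypothetical and is not part of optax; the sketch ignores the `epsilon`, `axis`, and `where` arguments and is not the library's implementation.

```python
import jax.numpy as jnp


def cosine_distance_reference(u, v):
    """Hypothetical helper: 1 - cos(theta) along the last axis."""
    u = jnp.asarray(u, dtype=jnp.float32)
    v = jnp.asarray(v, dtype=jnp.float32)
    # cos(theta) = <u, v> / (||u|| * ||v||)
    dot = jnp.sum(u * v, axis=-1)
    norms = jnp.linalg.norm(u, axis=-1) * jnp.linalg.norm(v, axis=-1)
    return 1.0 - dot / norms


# Vectors pointing the same way: distance close to 0.
same = cosine_distance_reference([1.0, 2.0], [2.0, 4.0])
# Orthogonal vectors: distance close to 1.
orthogonal = cosine_distance_reference([1.0, 0.0], [0.0, 3.0])
# Vectors pointing in opposite directions: distance close to 2.
opposite = cosine_distance_reference([1.0, 2.0], [-1.0, -2.0])
```

Only the direction of each vector matters: rescaling either input by a positive factor leaves the distance unchanged.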
- Parameters:
predictions – The predicted vectors, with shape [..., dim].
targets – Ground truth target vectors, with shape [..., dim].
epsilon – Minimum value used to clip the squared norms in the denominator, for numerical stability. The squared norms (not the norms) are clipped to be at least epsilon, so the effective minimum norm is sqrt(epsilon).
axis – Axis or axes along which to compute.
where – Elements to include in the computation.
- Returns:
Cosine distances, with shape [...].
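The parameter descriptions above can be sketched in plain `jax.numpy`. This is a minimal sketch written from those descriptions only, not optax's source: the name `cosine_distance_sketch` is hypothetical, and the exact way the library applies `epsilon` and `where` is an assumption.

```python
import jax.numpy as jnp


def cosine_distance_sketch(predictions, targets, *, epsilon=0.0, axis=-1, where=None):
    """Hypothetical re-implementation based on the documented parameters."""
    predictions = jnp.asarray(predictions, dtype=jnp.float32)
    targets = jnp.asarray(targets, dtype=jnp.float32)
    # Squared norms along `axis`, counting only elements selected by `where`.
    p_norm2 = jnp.sum(predictions**2, axis=axis, where=where, keepdims=True)
    t_norm2 = jnp.sum(targets**2, axis=axis, where=where, keepdims=True)
    # Clip the squared norms (not the norms) to at least `epsilon`,
    # so the effective minimum norm is sqrt(epsilon).
    p_unit = predictions / jnp.sqrt(jnp.maximum(p_norm2, epsilon))
    t_unit = targets / jnp.sqrt(jnp.maximum(t_norm2, epsilon))
    similarity = jnp.sum(p_unit * t_unit, axis=axis, where=where)
    return 1.0 - similarity


zero = jnp.array([[0.0, 0.0]])
target = jnp.array([[1.0, 0.0]])

# With epsilon=0 a zero vector leads to 0/0, so the result is NaN.
unstable = cosine_distance_sketch(zero, target)
# A small positive epsilon keeps the denominator away from zero.
stable = cosine_distance_sketch(zero, target, epsilon=1e-8)

# `where` excludes the third component, so only [1, 0] vs [1, 0] is compared.
mask = jnp.array([[True, True, False]])
masked = cosine_distance_sketch(
    jnp.array([[1.0, 0.0, 5.0]]),
    jnp.array([[1.0, 0.0, -7.0]]),
    where=mask,
)
```

In this sketch the output drops the reduced axis, so inputs of shape [..., dim] give a result of shape [...], matching the documented return shape for the default `axis=-1`.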
References
Cosine distance, Wikipedia.
Changed in version 0.2.4: Added axis and where arguments.