optax.losses.cosine_distance
- optax.losses.cosine_distance(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, epsilon: jax.typing.ArrayLike = 0.0, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) → Array
Computes the cosine distance between targets and predictions.
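Cosine distance measures the dissimilarity of two vectors as the complement of their cosine similarity: 1 - cos(theta), where theta is the angle between them. It ranges from 0 (same direction) through 1 (orthogonal) to 2 (opposite directions).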
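To make the formula concrete, here is a minimal sketch of 1 - cos(theta) over the last axis, written directly in `jax.numpy`. The function `cosine_distance_sketch` is hypothetical and omits `epsilon`, `axis`, and `where`; it is not the optax implementation, only an illustration of the quantity it computes.

```python
import jax.numpy as jnp


def cosine_distance_sketch(a, b):
    # Hypothetical helper, not part of optax.
    # Cosine similarity along the last axis: (a . b) / (|a| |b|).
    cos_sim = jnp.sum(a * b, axis=-1) / (
        jnp.linalg.norm(a, axis=-1) * jnp.linalg.norm(b, axis=-1)
    )
    # Cosine distance is the complement: 1 - cos(theta), in [0, 2].
    return 1.0 - cos_sim


a = jnp.array([1.0, 0.0])
print(cosine_distance_sketch(a, jnp.array([2.0, 0.0])))   # 0.0 (same direction)
print(cosine_distance_sketch(a, jnp.array([0.0, 3.0])))   # 1.0 (orthogonal)
print(cosine_distance_sketch(a, jnp.array([-5.0, 0.0])))  # 2.0 (opposite)
```

Only the direction of each vector matters: scaling `[1, 0]` to `[2, 0]` leaves the distance at 0.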
- Parameters:
predictions – The predicted vectors, with shape […, dim].
targets – Ground truth target vectors, with shape […, dim].
epsilon – Minimum norm for terms in the denominator of the cosine similarity.
axis – Axis or axes along which the vectors lie and the distance is computed.
where – Elements to include in the computation.
- Returns:
Cosine distances, with shape […].
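The sketch below illustrates the documented shapes ([…, dim] in, […] out) and the role of `epsilon`. It assumes each norm is clamped from below by `epsilon`, which is how the parameter description reads; where exactly optax applies the clamp may differ. `cosine_distance_eps` is a hypothetical helper, not the library function.

```python
import jax.numpy as jnp


def cosine_distance_eps(predictions, targets, epsilon=0.0):
    # Hypothetical sketch: clamp each norm to at least `epsilon` before dividing.
    pred_norm = jnp.maximum(jnp.linalg.norm(predictions, axis=-1), epsilon)
    targ_norm = jnp.maximum(jnp.linalg.norm(targets, axis=-1), epsilon)
    dot = jnp.sum(predictions * targets, axis=-1)
    return 1.0 - dot / (pred_norm * targ_norm)


# A batch of shape [2, 3, dim=4] reduces to distances of shape [2, 3].
preds = jnp.ones((2, 3, 4))
targs = jnp.ones((2, 3, 4))
print(cosine_distance_eps(preds, targs).shape)  # (2, 3)

# In this sketch a zero vector gives 0 / 0 = nan when epsilon is 0 ...
zero = jnp.zeros(4)
print(cosine_distance_eps(zero, jnp.ones(4)))                # nan
# ... while a positive epsilon keeps the denominator nonzero.
print(cosine_distance_eps(zero, jnp.ones(4), epsilon=1e-8))  # 1.0
```

Because the default is `epsilon=0.0`, passing a small positive value is a reasonable precaution when some inputs may be all zeros.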
References
Cosine distance, Wikipedia.
Changed in version 0.2.4: Added axis and where arguments.
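As an illustration of `axis` and `where`, the sketch below assumes `where` behaves like the `where` argument of `jax.numpy` reductions, so masked-out elements are simply dropped from the dot product and both norms, and that a tuple `axis` treats several axes as one flattened vector. `cosine_distance_masked` is a hypothetical helper; optax's handling of these arguments may differ in detail.

```python
import jax.numpy as jnp


def cosine_distance_masked(a, b, axis=-1, where=None):
    # Hypothetical sketch: masked-out elements are excluded from every sum.
    dot = jnp.sum(a * b, axis=axis, where=where)
    a_norm = jnp.sqrt(jnp.sum(a * a, axis=axis, where=where))
    b_norm = jnp.sqrt(jnp.sum(b * b, axis=axis, where=where))
    return 1.0 - dot / (a_norm * b_norm)


a = jnp.array([1.0, 0.0, 7.0])
b = jnp.array([2.0, 0.0, -7.0])
mask = jnp.array([True, True, False])  # ignore the last component

print(cosine_distance_masked(a, b))              # roughly 1.91, all components used
print(cosine_distance_masked(a, b, where=mask))  # 0.0: only [1, 0] vs [2, 0]

# A tuple axis reduces over both axes, treating the matrix as one vector.
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(cosine_distance_masked(x, 2.0 * x, axis=(0, 1)).shape)  # ()
```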