optax.losses.cosine_similarity
- optax.losses.cosine_similarity(predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, *, epsilon: jax.typing.ArrayLike = 0.0, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) → Array
Computes the cosine similarity between targets and predictions.
Cosine similarity measures how similar two vectors are as the cosine of the angle between them. Equivalently, it is the inner product of the two vectors after each has been normalized to unit norm.
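Written out for a prediction vector p and a target vector t, the definition above reads as follows. This is a sketch, not taken from the optax source: it assumes that `epsilon` acts as a floor on each norm in the denominator, which is how the parameter description below reads. With ε = 0, or whenever both norms are at least ε, the expression reduces to cos θ, where θ is the angle between p and t.

```latex
\operatorname{cos\_sim}(\mathbf{p}, \mathbf{t})
  = \frac{\langle \mathbf{p}, \mathbf{t} \rangle}
         {\max\!\left(\lVert \mathbf{p} \rVert_2, \epsilon\right)\,
          \max\!\left(\lVert \mathbf{t} \rVert_2, \epsilon\right)}
  = \left\langle \frac{\mathbf{p}}{\lVert \mathbf{p} \rVert_2},
                 \frac{\mathbf{t}}{\lVert \mathbf{t} \rVert_2} \right\rangle
  = \cos\theta
  \quad \text{(when } \lVert \mathbf{p} \rVert_2, \lVert \mathbf{t} \rVert_2 \ge \epsilon\text{)}
```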
- Parameters:
predictions – The predicted vectors, with shape […, dim].
targets – Ground truth target vectors, with shape […, dim].
epsilon – Minimum norm for terms in the denominator of the cosine similarity.
axis – Axis or axes along which to compute.
where – Elements to include in the computation.
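The sketch below illustrates one plausible reading of the `axis` and `where` parameters. It is a NumPy re-implementation, not optax's code, and `masked_cosine_similarity` is a hypothetical name. It assumes that `where` is a boolean mask broadcastable to the inputs, that masked-out elements are dropped from both norms and from the inner product, and that `epsilon` floors each norm.

```python
import numpy as np


def masked_cosine_similarity(predictions, targets, axis=-1, where=None, epsilon=0.0):
    """Hypothetical sketch of cosine similarity with `axis` and `where`.

    Assumptions (not confirmed by the optax docs): `where` is a boolean mask
    broadcastable to the inputs, masked-out elements are excluded from the
    norms and the inner product, and `epsilon` floors each norm.
    """
    p = np.asarray(predictions, dtype=float)
    t = np.asarray(targets, dtype=float)
    if where is None:
        where = True  # NumPy's default: include every element
    # Norms over the reduced axis/axes, counting only selected elements.
    p_norm = np.sqrt(np.sum(p * p, axis=axis, where=where, keepdims=True))
    t_norm = np.sqrt(np.sum(t * t, axis=axis, where=where, keepdims=True))
    # Normalize, flooring each norm at epsilon.
    p_unit = p / np.maximum(p_norm, epsilon)
    t_unit = t / np.maximum(t_norm, epsilon)
    # Inner product of the unit vectors over the same selected elements.
    return np.sum(p_unit * t_unit, axis=axis, where=where)


p = np.array([[1.0, 0.0, 5.0], [0.0, 1.0, 5.0]])
t = np.array([[1.0, 0.0, -5.0], [1.0, 0.0, -5.0]])
mask = np.array([True, True, False])  # ignore the last feature in every row
print(masked_cosine_similarity(p, t, where=mask))  # one value per row
```

In this example, masking the last column changes the first row's result from -24/26 to 1.0, because the remaining features point the same way. Passing `axis=(0, 1)` instead treats the whole 2×3 array as a single vector.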
- Returns:
Cosine similarity measures, with shape […].
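The following NumPy sketch illustrates the shape contract described above for the default `axis=-1`: inputs of shape […, dim] reduce to […], one similarity per pair of vectors, with every value in [-1, 1]. It is an illustrative re-implementation rather than the optax function, and `cosine_similarity_ref` is a hypothetical name. As in the formula above, it assumes `epsilon` floors each norm. That floor is what turns a zero vector into a similarity of 0 instead of NaN.

```python
import numpy as np


def cosine_similarity_ref(predictions, targets, epsilon=0.0):
    """Hypothetical reference sketch; reduces over the last ("dim") axis.

    Assumption: `epsilon` floors each vector norm, per the parameter
    description; optax's exact handling may differ.
    """
    p = np.asarray(predictions, dtype=float)
    t = np.asarray(targets, dtype=float)
    # Per-vector norms along the last axis, floored at epsilon.
    p_norm = np.maximum(np.linalg.norm(p, axis=-1, keepdims=True), epsilon)
    t_norm = np.maximum(np.linalg.norm(t, axis=-1, keepdims=True), epsilon)
    # Sum over the last axis; all leading axes are kept as batch axes.
    return np.sum((p / p_norm) * (t / t_norm), axis=-1)


rng = np.random.default_rng(0)
preds = rng.normal(size=(4, 3, 8))  # shape [..., dim] with ... = (4, 3)
targs = rng.normal(size=(4, 3, 8))
sims = cosine_similarity_ref(preds, targs)
print(sims.shape)  # (4, 3)
```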
References
Cosine similarity, Wikipedia.
Changed in version 0.2.4: Added axis and where arguments.