optax.losses.triplet_margin_loss

optax.losses.triplet_margin_loss(anchors: jax.typing.ArrayLike, positives: jax.typing.ArrayLike, negatives: jax.typing.ArrayLike, axis: int = -1, norm_degree: jax.typing.ArrayLike = 2, margin: jax.typing.ArrayLike = 1.0, eps: jax.typing.ArrayLike = 1e-06) -> Array

Returns the triplet loss for a batch of embeddings.
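For the i-th triplet of anchor a_i, positive p_i and negative n_i, the loss can be written as follows, where q is norm_degree, m is margin, and the norm is taken along axis. This is a sketch that leaves out the small eps term, whose exact placement is an implementation detail. The second line works through the first triplet of the Examples section and reproduces its printed value of 0.1414.

```latex
\ell_i = \max\bigl(\lVert a_i - p_i \rVert_q - \lVert a_i - n_i \rVert_q + m,\; 0\bigr)

\ell_1 = \max\bigl(\sqrt{0.1^2 + 0.1^2} - \sqrt{1^2 + 0^2} + 1,\; 0\bigr) \approx 0.1414 - 1 + 1 = 0.1414
```

The loss is zero once the negative is at least m farther from the anchor than the positive is.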

Examples

>>> import jax.numpy as jnp, optax
>>> jnp.set_printoptions(precision=4)
>>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
>>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
>>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives,
...                                           margin=1.0)
>>> print(output)
[0.1414 0.1414]

Parameters:
  • anchors – An array of anchor embeddings, with shape [batch, feature_dim].

  • positives – An array of positive embeddings (similar to the anchors), with shape [batch, feature_dim].

  • negatives – An array of negative embeddings (dissimilar to the anchors), with shape [batch, feature_dim].

  • axis – The axis along which to compute the distances (default is -1).

  • norm_degree – The order of the norm used to compute distances (default is 2, i.e. Euclidean distance).

  • margin – The minimum margin by which the positive distance should be smaller than the negative distance (default is 1.0).

  • eps – A small constant added for numerical stability in the distance calculation (default is 1e-6).
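To make the roles of axis, norm_degree, margin and eps concrete, here is a minimal pure-JAX re-implementation. It is a hypothetical sketch: the name triplet_margin_loss_sketch is illustrative and not part of optax, and it assumes eps is added to the summed powered differences before taking the root. Optax's actual eps placement may differ, so results can disagree with optax in the last few digits.

```python
import jax.numpy as jnp


def triplet_margin_loss_sketch(anchors, positives, negatives, axis=-1,
                               norm_degree=2, margin=1.0, eps=1e-6):
  """Hypothetical re-implementation for illustration; not optax's source.

  Assumes eps is added to the summed powered differences before the
  1/norm_degree root, which keeps the gradient finite at zero distance.
  """
  def distance(x, y):
    # p-norm of (x - y) along `axis`; abs() keeps odd degrees well defined.
    powered = jnp.abs(x - y) ** norm_degree
    return (jnp.sum(powered, axis=axis) + eps) ** (1.0 / norm_degree)

  d_pos = distance(anchors, positives)  # one distance per triplet
  d_neg = distance(anchors, negatives)  # one distance per triplet
  # Hinge: zero once the negative is at least `margin` farther than the positive.
  return jnp.maximum(d_pos - d_neg + margin, 0.0)


anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
loss = triplet_margin_loss_sketch(anchors, positives, negatives)
print(loss)
```

With the example inputs this should agree with the optax output shown in the Examples section to four decimal places. Setting margin=0.0 drives both losses to zero here, since each negative is already farther from its anchor than the positive.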

Returns:

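An array of per-triplet losses, with the axis dimension reduced (shape [batch] for inputs of shape [batch, feature_dim]). No reduction over the batch is applied.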

References

V. Balntas et al., Learning shallow convolutional feature descriptors with triplet losses, 2016.