optax.losses.triplet_margin_loss
- optax.losses.triplet_margin_loss(anchors: jax.typing.ArrayLike, positives: jax.typing.ArrayLike, negatives: jax.typing.ArrayLike, axis: int = -1, norm_degree: jax.typing.ArrayLike = 2, margin: jax.typing.ArrayLike = 1.0, eps: jax.typing.ArrayLike = 1e-06) → Array
Returns the triplet loss for a batch of embeddings.
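Written out for the $i$-th triplet, the loss can be sketched as follows. Here $a_i$, $p_i$ and $n_i$ are the anchor, positive and negative rows, $q$ is `norm_degree`, $m$ is `margin`, $\epsilon$ is `eps`, and the sum over $j$ runs along `axis`. The placement of $\epsilon$ inside the absolute difference is our reading of the implementation and should be treated as an assumption; the hinge form follows from the description of `margin`.

```latex
d_q(x, y) = \Big( \sum_{j} \lvert x_j - y_j + \epsilon \rvert^{q} \Big)^{1/q},
\qquad
\ell_i = \max\!\big( d_q(a_i, p_i) - d_q(a_i, n_i) + m,\; 0 \big)
```

A triplet therefore contributes zero loss once its negative is at least $m$ farther from the anchor than its positive.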
Examples
>>> import jax.numpy as jnp, optax
>>> jnp.set_printoptions(precision=4)
>>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
>>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
>>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
>>> output = optax.losses.triplet_margin_loss(anchors, positives, negatives,
...                                           margin=1.0)
>>> print(output)
[0.1414 0.1414]
- Parameters:
anchors – An array of anchor embeddings, with shape [batch, feature_dim].
positives – An array of positive embeddings (similar to anchors), with shape [batch, feature_dim].
negatives – An array of negative embeddings (dissimilar to anchors), with shape [batch, feature_dim].
axis – The axis along which to compute the distances (default is -1).
norm_degree – The norm degree for distance calculation (default is 2 for Euclidean distance).
margin – The minimum margin by which the positive distance should be smaller than the negative distance.
eps – A small epsilon value to ensure numerical stability in the distance calculation.
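To make the roles of `axis`, `norm_degree`, `margin` and `eps` concrete, here is a minimal JAX sketch of the computation the parameters describe. The function name `triplet_margin_loss_sketch` is hypothetical and this is not Optax's source; in particular, adding `eps` to the difference before taking the absolute value is an assumption about the implementation. Use `optax.losses.triplet_margin_loss` itself in real code.

```python
import jax.numpy as jnp


def triplet_margin_loss_sketch(anchors, positives, negatives, axis=-1,
                               norm_degree=2, margin=1.0, eps=1e-6):
  """Hypothetical re-implementation for illustration only (not optax's code)."""

  def distance(x, y):
    # Assumption: eps is added to the difference before abs, so the norm stays
    # strictly positive (and differentiable) even when x == y.
    diff = jnp.abs(x - y + eps)
    # p-norm along `axis`: (sum_j |diff_j|^p)^(1/p).
    return jnp.power(jnp.power(diff, norm_degree).sum(axis=axis),
                     1.0 / norm_degree)

  pos = distance(anchors, positives)
  neg = distance(anchors, negatives)
  # Hinge: only triplets whose negative is not at least `margin` farther
  # from the anchor than the positive contribute a nonzero loss.
  return jnp.maximum(pos - neg + margin, 0.0)


anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])

l2 = triplet_margin_loss_sketch(anchors, positives, negatives)
l1 = triplet_margin_loss_sketch(anchors, positives, negatives, norm_degree=1)
no_margin = triplet_margin_loss_sketch(anchors, positives, negatives,
                                       margin=0.0)
```

With the default `norm_degree=2` the sketch reproduces the `[0.1414 0.1414]` values from the example. Switching to `norm_degree=1` (Manhattan distance) gives roughly 0.2 per example, and `margin=0.0` yields zero loss because each negative is already farther from its anchor than the positive.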
- Returns:
The triplet loss for each example in the batch, i.e. an array of shape [batch] when the inputs have shape [batch, feature_dim].
References
V. Balntas et al., Learning shallow convolutional feature descriptors with triplet losses, 2016.