optax.losses.ntxent

optax.losses.ntxent(embeddings: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, temperature: jax.typing.ArrayLike = 0.07) -> Array

Normalized temperature-scaled cross-entropy loss (NT-Xent).
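
As a rough guide to what the loss measures, the per-pair term can be written as follows. This is a sketch that assumes the formulation used by the pytorch-metric-learning NTXentLoss listed in the references; the symbols are introduced here for illustration and do not appear in the docstring: z_i is the i-th embedding, y_i its label, s_{i,j} the cosine similarity of embeddings i and j, \tau the temperature, and P the set of ordered positive pairs.

```latex
s_{i,j} = \frac{z_i^\top z_j}{\lVert z_i \rVert \, \lVert z_j \rVert},
\qquad
\ell_{i,j} = -\log \frac{\exp(s_{i,j}/\tau)}
                        {\exp(s_{i,j}/\tau) + \sum_{k:\, y_k \neq y_i} \exp(s_{i,k}/\tau)},
\qquad
\mathcal{L} = \frac{1}{|P|} \sum_{(i,j) \in P} \ell_{i,j},
\quad
P = \{(i,j) : i \neq j,\ y_i = y_j\}
```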

Examples

>>> import jax
>>> import optax
>>> import jax.numpy as jnp
>>>
>>> key = jax.random.key(42)
>>> key1, key2, key3 = jax.random.split(key, 3)
>>> x = jax.random.normal(key1, shape=(4,2))
>>> labels = jnp.array([0, 0, 1, 1])
>>>
>>> print("input:", x)
input: [[ 0.07592554 -0.48634264]
 [ 1.2903206   0.5196119 ]
 [ 0.30040437  0.31034866]
 [ 0.5761609  -0.8074621 ]]
>>> print("labels:", labels)
labels: [0 0 1 1]
>>>
>>> w = jax.random.normal(key2, shape=(2,1)) # params
>>> b = jax.random.normal(key3, shape=(1,)) # params
>>> out = x @ w + b # model
>>>
>>> print("Embeddings:", out)
Embeddings: [[0.08969027]
 [1.6291292 ]
 [0.8622629 ]
 [0.13612625]]
>>> loss = optax.losses.ntxent(out, labels)
>>> print("loss:", loss)
loss: 1.0986123

Parameters:

  • embeddings – batch of embeddings, with shape [batch, feature_length].

  • labels – labels identifying groups of positive pairs, with shape [batch]. For example, if a batch has 4 embeddings and the first two and the last two are positive pairs, the labels should be [0, 0, 1, 1].

  • temperature – temperature scaling parameter.

Returns:

A scalar loss: the NT-Xent value averaged over all positive pairs.
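
To make the averaging over positive pairs concrete, here is a minimal sketch in JAX. It is not the optax source: `ntxent_sketch` is a hypothetical helper, it assumes the pytorch-metric-learning-style formulation (each positive pair is contrasted against the anchor's negatives only), and it omits the max-shift a production implementation would likely use for numerical stability.

```python
import jax.numpy as jnp


def ntxent_sketch(embeddings, labels, temperature=0.07):
    """Illustrative NT-Xent loss (hypothetical helper, not the optax code)."""
    # Cosine similarity between every pair of rows, divided by the temperature.
    norms = jnp.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / jnp.maximum(norms, 1e-8)
    sim = unit @ unit.T / temperature

    # Positive pairs share a label (self-pairs excluded); negatives differ.
    same = labels[:, None] == labels[None, :]
    pos = same & ~jnp.eye(labels.shape[0], dtype=bool)
    neg = ~same

    # For each anchor i, sum exp(sim) over its negatives.
    neg_sum = jnp.sum(jnp.where(neg, jnp.exp(sim), 0.0), axis=1, keepdims=True)

    # For each positive pair (i, j): log(exp(s_ij) / (exp(s_ij) + neg_sum_i)).
    # No max-shift is applied here; this is fine for a small illustration only.
    log_prob = sim - jnp.log(jnp.exp(sim) + neg_sum)

    # Average the negated log-probabilities over all positive pairs.
    return -jnp.sum(jnp.where(pos, log_prob, 0.0)) / jnp.sum(pos)


# One-dimensional, all-positive embeddings: every cosine similarity is 1.
emb = jnp.array([[0.5], [1.5], [0.8], [0.1]])
labels = jnp.array([0, 0, 1, 1])
print("loss:", ntxent_sketch(emb, labels))
```

With one-dimensional embeddings of the same sign, every cosine similarity equals 1, so each anchor sees one positive and two equally similar negatives and the sketch evaluates to ln 3 ≈ 1.0986 regardless of the temperature.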

References

T. Chen et al., A Simple Framework for Contrastive Learning of Visual Representations, 2020

kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss

Added in version 0.2.3.