optax.losses.ntxent
- optax.losses.ntxent(embeddings: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, temperature: jax.typing.ArrayLike = 0.07) -> Array
Normalized temperature scaled cross entropy loss (NT-Xent).
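For reference, Chen et al. (2020) define the loss for a positive pair (i, j) in a batch of 2N embeddings z as follows. Here sim(u, v) is cosine similarity and τ corresponds to the `temperature` argument:

```latex
\ell_{i,j} = -\log \frac{\exp\left(\operatorname{sim}(z_i, z_j)/\tau\right)}
  {\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\left(\operatorname{sim}(z_i, z_k)/\tau\right)},
\qquad
\operatorname{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert}
```

The paper's setting has exactly two views per example. This function generalizes it through `labels` and returns the mean of ℓ over all positive pairs. This page does not specify how the denominator treats extra same-label embeddings when a label appears more than twice.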
Examples
>>> import jax
>>> import optax
>>> import jax.numpy as jnp
>>>
>>> key = jax.random.key(42)
>>> key1, key2, key3 = jax.random.split(key, 3)
>>> x = jax.random.normal(key1, shape=(4,2))
>>> labels = jnp.array([0, 0, 1, 1])
>>>
>>> print("input:", x)
input: [[ 0.07592554 -0.48634264]
 [ 1.2903206   0.5196119 ]
 [ 0.30040437  0.31034866]
 [ 0.5761609  -0.8074621 ]]
>>> print("labels:", labels)
labels: [0 0 1 1]
>>>
>>> w = jax.random.normal(key2, shape=(2,1)) # params
>>> b = jax.random.normal(key3, shape=(1,)) # params
>>> out = x @ w + b # model
>>>
>>> print("Embeddings:", out)
Embeddings: [[0.08969027]
 [1.6291292 ]
 [0.8622629 ]
 [0.13612625]]
>>> loss = optax.ntxent(out, labels)
>>> print("loss:", loss)
loss: 1.0986123
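The printed loss can be checked by hand, assuming similarities are cosine similarities as in Chen et al. In this example `feature_length` is 1 and every embedding value is positive, so each embedding normalizes to 1 and every pairwise similarity equals 1. Each anchor therefore has one positive and two negatives, all with logit 1/τ:

```latex
\ell_{i,j} = -\log \frac{e^{1/\tau}}{e^{1/\tau} + 2\,e^{1/\tau}} = \log 3 \approx 1.0986123
```

Every positive pair gives the same term, so the average over the four positive pairs is also log 3, independent of τ.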
- Parameters:
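embeddings – batch of embeddings, with shape [batch, feature_length].
labels – labels that group the embeddings into positive pairs: embeddings sharing a label are positives, and embeddings with different labels are negatives. For example, in a batch of 4 embeddings where the first two and the last two are positive pairs, labels should be [0, 0, 1, 1]. Shape [batch].
temperature – temperature scaling parameter; the pairwise similarities are divided by it before the softmax (default 0.07).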
- Returns:
A scalar NT-Xent loss, averaged over all positive pairs.
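To make the roles of `labels` and the averaging concrete, the following sketch recomputes the loss from scratch with `jax.numpy`. It is a hypothetical illustration: `ntxent_sketch` is not part of Optax, and it is not Optax's source. The sketch assumes cosine similarity between embeddings. It treats same-label pairs (excluding an embedding paired with itself) as positives and different-label pairs as negatives. For each positive pair, the denominator holds that pair's term plus all of the anchor's negatives. With exactly two embeddings per label, this coincides with the SimCLR formulation.

```python
import jax.numpy as jnp


def ntxent_sketch(embeddings, labels, temperature=0.07):
    """Hypothetical from-scratch NT-Xent sketch; not optax's implementation."""
    # L2-normalize rows so that dot products are cosine similarities.
    z = embeddings / jnp.linalg.norm(embeddings, axis=-1, keepdims=True)
    logits = (z @ z.T) / temperature  # [batch, batch] scaled similarities

    # Positive pairs share a label (excluding i == j); negatives do not.
    same = labels[:, None] == labels[None, :]
    eye = jnp.eye(labels.shape[0], dtype=bool)
    positives = same & ~eye
    negatives = ~same

    # Assumed denominator for pair (i, j): its own term plus all of anchor i's
    # negatives. No max-subtraction, so very small temperatures may overflow.
    exp_logits = jnp.exp(logits)
    neg_sum = jnp.sum(jnp.where(negatives, exp_logits, 0.0), axis=1, keepdims=True)
    log_prob = logits - jnp.log(exp_logits + neg_sum)

    # Average -log p(j | i) over every positive pair.
    return -jnp.sum(jnp.where(positives, log_prob, 0.0)) / jnp.sum(positives)


# Embedding values printed in the example above (feature_length = 1).
emb = jnp.array([[0.08969027], [1.6291292], [0.8622629], [0.13612625]])
labels = jnp.array([0, 0, 1, 1])
print(ntxent_sketch(emb, labels))  # approximately 1.0986 (= ln 3)
```

A production version would subtract a per-row maximum from the logits before exponentiating. The sketch omits that step to keep the correspondence with the formula visible.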
References
T. Chen et al., A Simple Framework for Contrastive Learning of Visual Representations, 2020.
kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss
Added in version 0.2.3.