optax.losses.kl_divergence


optax.losses.kl_divergence(log_predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) -> Array

Computes the Kullback-Leibler divergence (relative entropy) loss.

Measures the information gain achieved if the target probability distribution were used instead of the predicted probability distribution.
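For reference, the standard definition of this quantity can be written as below. The symbols are chosen here for illustration and do not appear in the original docstring: t_i stands for the target probabilities and p_i for the predicted probabilities, so log p_i corresponds to log_predictions.

```latex
D_{\mathrm{KL}}(t \,\|\, p) = \sum_{i} t_i \left( \log t_i - \log p_i \right)
```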

Parameters:
  • log_predictions – Probabilities of the predicted distribution, with shape […, dim]. Expected to be in log-space to avoid underflow.

  • targets – Probabilities of the target distribution, with shape […, dim]. Expected to be strictly positive.

  • axis – Axis or axes along which to compute the divergence.

  • where – Elements to include in the computation.

Returns:

Kullback-Leibler divergence of predicted distribution from target distribution with shape […].
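As a minimal sketch of how the arguments and the return shape fit together, the snippet below computes the same kind of quantity with plain jax.numpy. The helper kl_divergence_sketch and the sample arrays are hypothetical and are not the optax source. The helper assumes the usual definition, a sum over axis of targets * (log(targets) - log_predictions), and it treats entries with a zero target as contributing nothing.

```python
import jax
import jax.numpy as jnp


def kl_divergence_sketch(log_predictions, targets, axis=-1, where=None):
    """Hypothetical reference, not the optax source: sum of t * (log t - log p)."""
    # Replace zero targets by 1 before the log so that log(0) is never evaluated;
    # those entries are multiplied by a zero target and contribute nothing.
    safe_targets = jnp.where(targets > 0, targets, 1.0)
    terms = targets * (jnp.log(safe_targets) - log_predictions)
    # `where` selects which elements enter the sum; None means all of them.
    return jnp.sum(terms, axis=axis, where=where)


# Predictions must be passed in log-space, e.g. via log_softmax over logits.
logits = jnp.array([[2.0, 1.0, 0.1], [0.5, 0.5, 0.5]])
log_predictions = jax.nn.log_softmax(logits, axis=-1)  # shape [2, 3]
targets = jnp.array([[0.1, 0.2, 0.7], [1 / 3, 1 / 3, 1 / 3]])  # shape [2, 3]

# Reducing over the last axis leaves one value per row: shape [2].
loss = kl_divergence_sketch(log_predictions, targets)

# Exclude the last class of the first row from the sum.
mask = jnp.array([[True, True, False], [True, True, True]])
masked_loss = kl_divergence_sketch(log_predictions, targets, where=mask)
```

The second row pairs a uniform prediction with a uniform target, so its divergence is zero up to floating-point error. With optax installed, optax.losses.kl_divergence(log_predictions, targets) is expected to return matching per-row values under the assumptions stated above.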

References

Kullback and Leibler, On Information and Sufficiency, 1951

Changed in version 0.2.4: Added axis and where arguments.