optax.losses.generalized_kl_divergence

optax.losses.generalized_kl_divergence(log_predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) → Array

Computes the generalized Kullback-Leibler divergence loss.

Measures the information gain achieved if the target probability distribution is used instead of the predicted probability distribution.

This function generalizes the standard Kullback-Leibler divergence to unnormalized probability distributions. Technically, this is the Bregman divergence generated by the convex function f(x) = x log x - x.
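Expanding the Bregman divergence D_f(t, p) = f(t) - f(p) - f'(p)(t - p) for f(x) = x log x - x, whose derivative is f'(x) = log x, gives the elementwise term. Here t stands for `targets` and p = exp(`log_predictions`); the sum over i assumes the reduction along `axis` is a sum, as for the standard KL divergence. When both t and p sum to one, the -t + p terms cancel and the standard KL divergence is recovered.

```latex
\begin{aligned}
D_f(t, p) &= f(t) - f(p) - f'(p)\,(t - p) \\
          &= (t \log t - t) - (p \log p - p) - \log p \,(t - p) \\
          &= t \log \frac{t}{p} - t + p,
\end{aligned}
\qquad
\mathrm{loss} = \sum_i \Big( t_i \log \frac{t_i}{p_i} - t_i + p_i \Big).
```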

Parameters:
  • log_predictions – Log-probabilities of the predicted distribution, with shape […, dim]. Passed in log space to avoid underflow.

  • targets – Probabilities of the target distribution, with shape […, dim]. Expected to be strictly positive.

  • axis โ€“ Axis or axes along which to compute.

  • where โ€“ Elements to include in the computation.
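A small NumPy sketch can make the parameters concrete. `generalized_kl_sketch` is a hypothetical helper written for illustration, not optax's implementation; it assumes the reduction over `axis` is a sum of the elementwise term t * (log t - log p) - t + p, and that `where` behaves like the boolean mask accepted by `numpy.sum`.

```python
import numpy as np


def generalized_kl_sketch(log_predictions, targets, axis=-1, where=None):
    """Hypothetical NumPy sketch of the generalized KL divergence.

    Not optax's implementation. Assumes the reduction over `axis` is a sum
    and that `where` acts like the boolean mask accepted by `np.sum`.
    """
    log_p = np.asarray(log_predictions, dtype=float)
    t = np.asarray(targets, dtype=float)
    # Elementwise Bregman term for f(x) = x log x - x:
    #   t * (log t - log p) - t + p, with p recovered as exp(log_p).
    # `targets` is assumed strictly positive, so np.log(t) is finite.
    elementwise = t * (np.log(t) - log_p) - t + np.exp(log_p)
    # A mask of True keeps every element when no `where` is given.
    mask = True if where is None else np.asarray(where, dtype=bool)
    return np.sum(elementwise, axis=axis, where=mask)


# Usage: a batch of 2 unnormalized distributions over dim=3.
targets = np.array([[1.0, 2.0, 0.5], [0.3, 0.3, 0.4]])
predictions = np.array([[2.0, 1.0, 0.5], [0.3, 0.3, 0.4]])
loss = generalized_kl_sketch(np.log(predictions), targets)
print(loss.shape)  # (2,): the trailing `dim` axis is reduced
```

The mask is passed straight to `np.sum`, so masked-out elements contribute nothing to the result, and a tuple `axis` reduces several axes at once.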

Returns:

Generalized Kullback-Leibler divergence of the predicted distribution from the target distribution, with shape […].

References

S. Boyd and L. Vandenberghe, Convex Optimization, 2004, p. 90.

L. M. Bregman, The relaxation method of finding the common point of convex sets and its application to the solution of problems in convex programming, 1967.

Added in version 0.2.4.