optax.losses.kl_divergence
- optax.losses.kl_divergence(log_predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) → Array
Computes the Kullback-Leibler divergence (relative entropy) loss.
Measures the information gained if the target probability distribution were used instead of the predicted probability distribution.
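For reference, the standard definition of relative entropy is shown below. Here p stands for `targets` and log q for `log_predictions`, and the sum runs along `axis`. This is the textbook quantity the description refers to, under the usual convention 0 · log 0 = 0. It is not a transcription of the library source. It is non-negative and equals zero exactly when p = q.

```latex
D_{\mathrm{KL}}(p \,\|\, q) = \sum_{i=1}^{\mathrm{dim}} p_i \left( \log p_i - \log q_i \right) \;\ge\; 0
```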
- Parameters:
log_predictions – Log-probabilities of the predicted distribution, with shape […, dim]. Passing them in log-space (for example, the output of a log-softmax) avoids underflow.
targets – Probabilities of the target distribution, with shape […, dim]. Expected to be strictly positive.
axis – Axis or axes along which to compute the divergence.
where – Boolean mask selecting the elements to include in the computation.
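The following sketch shows why `log_predictions` should already be in log-space. It uses only `jax.numpy` and `jax.nn`. The logits and targets are illustrative values, and `kl` is a hypothetical local helper that applies the relative-entropy formula directly. It is not part of optax. In float32, exponentiating a very negative logit underflows to 0.0, so taking the log afterwards yields -inf and the divergence becomes infinite. Computing log-probabilities directly with a log-softmax keeps every term finite.

```python
import jax
import jax.numpy as jnp

# Illustrative logits: the last class is extremely unlikely under the prediction.
logits = jnp.array([0.0, -50.0, -200.0], dtype=jnp.float32)
targets = jnp.array([0.5, 0.3, 0.2], dtype=jnp.float32)

# Route 1: exponentiate first, then take the log.
# exp(-200) is below the smallest float32 value, so probs[2] becomes 0.0
# and its log becomes -inf.
probs = jax.nn.softmax(logits)
log_from_probs = jnp.log(probs)

# Route 2: compute log-probabilities directly; log-softmax stays finite.
log_probs = jax.nn.log_softmax(logits)


def kl(log_q, p):
    """Hypothetical helper: sum_i p_i * (log p_i - log q_i) over the last axis."""
    return jnp.sum(p * (jnp.log(p) - log_q), axis=-1)


kl_from_probs = kl(log_from_probs, targets)  # inf: the -inf log term dominates
kl_from_logs = kl(log_probs, targets)        # finite, about 54 nats here
```

The targets give 20% mass to a class whose predicted probability underflowed, so only the log-space route produces a usable loss and gradient.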
- Returns:
Kullback-Leibler divergence of the predicted distribution from the target distribution, with shape […] (the reduced axis or axes are removed).
References
Kullback and Leibler, On Information and Sufficiency, 1951
Changed in version 0.2.4: Added axis and where arguments.