optax.losses.kl_divergence_with_log_targets
- optax.losses.kl_divergence_with_log_targets(log_predictions: jax.typing.ArrayLike, log_targets: jax.typing.ArrayLike, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) -> Array
Computes the Kullback-Leibler divergence (relative entropy) loss.
Variant of kl_div_loss in which the targets are also given in log-space.
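Writing p for the target distribution (log_targets = log p) and q for the predicted distribution (log_predictions = log q), the quantity described above is the standard KL divergence of q from p, summed over the distribution axis. The second form below is a sketch of how it can be evaluated directly from log-space inputs, assuming p is recovered by exponentiating log_targets:

```latex
\mathrm{KL}(p \,\|\, q)
  = \sum_{i} p_i \left( \log p_i - \log q_i \right)
  = \sum_{i} e^{\log p_i} \left( \log p_i - \log q_i \right)
```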
- Parameters:
log_predictions – Log-probabilities of the predicted distribution, with shape […, dim]. Expected in log-space to avoid underflow.
log_targets – Log-probabilities of the target distribution, with shape […, dim]. Expected in log-space.
axis – Axis or axes along which to compute the divergence.
where – Boolean mask selecting the elements to include in the computation.
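The parameters above can be illustrated with a minimal reference sketch in jax.numpy. It assumes the function computes exp(log_targets) * (log_targets - log_predictions) pointwise and then sums over axis with where as a mask; the name kl_with_log_targets_reference is hypothetical and not part of optax.

```python
import jax.numpy as jnp


def kl_with_log_targets_reference(log_predictions, log_targets, axis=-1, where=None):
    """Hypothetical reference sketch, not the optax implementation.

    Computes KL(targets || predictions) from log-space inputs.
    """
    # Pointwise term p * (log p - log q), recovering p = exp(log p).
    pointwise = jnp.exp(log_targets) * (log_targets - log_predictions)
    # Reduce over the distribution axis (or axes); elements where
    # `where` is False are excluded from the sum.
    return jnp.sum(pointwise, axis=axis, where=where)


# Two examples over a 3-way distribution.
log_q = jnp.log(jnp.array([[0.25, 0.25, 0.5], [0.1, 0.2, 0.7]]))  # predictions
log_p = jnp.log(jnp.array([[0.25, 0.25, 0.5], [0.7, 0.2, 0.1]]))  # targets

# Row 0 has identical distributions, so its divergence is 0.
kl = kl_with_log_targets_reference(log_q, log_p)

# Mask out the last class of the second example only.
mask = jnp.array([[True, True, True], [True, True, False]])
kl_masked = kl_with_log_targets_reference(log_q, log_p, where=mask)
```

For the second row, the full sum is 0.7 ln 7 + 0 + 0.1 ln(1/7) = 0.6 ln 7, while the masked sum keeps only the first term, 0.7 ln 7.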
- Returns:
Kullback-Leibler divergence of the predicted distribution from the target distribution, with shape […].
Changed in version 0.2.4: Added axis and where arguments.