optax.losses.kl_divergence_with_log_targets


optax.losses.kl_divergence_with_log_targets(log_predictions: jax.typing.ArrayLike, log_targets: jax.typing.ArrayLike, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) -> Array

Computes the Kullback-Leibler divergence (relative entropy) loss.

Variant of kl_divergence in which the targets, like the predictions, are given in log-space.
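Written out, with p the target distribution (log_targets = log p) and q the predicted distribution (log_predictions = log q) along axis, the quantity is the standard KL divergence expressed entirely through log-probabilities. This restates the textbook formula for reference; it is not copied from the optax source:

```latex
\mathrm{KL}(p \,\|\, q)
  = \sum_{i} p_i \left(\log p_i - \log q_i\right)
  = \sum_{i} \exp(\log p_i)\,\left(\log p_i - \log q_i\right)
```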

Parameters:
  • log_predictions – Log-probabilities of the predicted distribution, with shape […, dim]. Given in log-space to avoid underflow.

  • log_targets – Log-probabilities of the target distribution, with shape […, dim].

  • axis – Axis or axes along which the divergence is summed (the distribution axis). Defaults to the last axis.

  • where – Optional boolean mask, broadcastable to the inputs, selecting which elements contribute to the sum.
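The parameters can be illustrated with a small reference sketch in jax.numpy. The function name kl_log_targets_sketch is hypothetical, and its body simply evaluates the formula pointwise and reduces it with jnp.sum(..., axis=axis, where=where). Treat it as an illustration of how axis and where interact, not as the library's implementation.

```python
import jax
import jax.numpy as jnp


def kl_log_targets_sketch(log_predictions, log_targets, axis=-1, where=None):
    """Hypothetical reference: KL(p || q) from log p (targets) and log q (predictions)."""
    # Pointwise term p_i * (log p_i - log q_i); p_i is recovered with exp.
    pointwise = jnp.exp(log_targets) * (log_targets - log_predictions)
    # Sum over the distribution axis; entries where the mask is False are skipped.
    return jnp.sum(pointwise, axis=axis, where=where)


# Batch of two 3-class distributions, converted to log-probabilities.
log_q = jax.nn.log_softmax(jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]]))  # predictions
log_p = jax.nn.log_softmax(jnp.array([[3.0, 2.0, 1.0], [0.0, 0.0, 0.0]]))  # targets

kl = kl_log_targets_sketch(log_q, log_p)  # shape (2,); kl[0] ≈ 1.1504, kl[1] == 0

# A mask of shape (3,) broadcasts across the batch and drops the last class.
mask = jnp.array([True, True, False])
masked = kl_log_targets_sketch(log_q, log_p, where=mask)  # masked[0] ≈ 1.3305
```

Masking out a class of a single distribution yields a partial sum rather than a proper divergence, so in practice a mask is more useful for excluding padded positions than for dropping real classes.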

Returns:

Kullback-Leibler divergence of the predicted distribution from the target distribution, with shape […].

Changed in version 0.2.4: Added axis and where arguments.