optax.losses.convex_kl_divergence
optax.losses.convex_kl_divergence(log_predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) → Array