optax.losses.convex_kl_divergence

optax.losses.convex_kl_divergence

optax.losses.convex_kl_divergence(log_predictions: jax.typing.ArrayLike, targets: jax.typing.ArrayLike, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) → jax.Array [source]