optax.losses.make_fenchel_young_loss

optax.losses.make_fenchel_young_loss(max_fun: MaxFun)

Creates a Fenchel-Young loss from a max function.

Parameters:

max_fun – the max function on which the Fenchel-Young loss is built.

Returns:

A Fenchel-Young loss function with the same signature.
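
The construction can be illustrated with a minimal sketch in plain JAX. This is not optax's source: make_fy_loss_sketch is a hypothetical name, and the sketch assumes the returned loss takes (scores, targets) and computes max_fun(scores) - <targets, scores> along the last axis, omitting the Omega(targets) term of Blondel et al. because it does not depend on the scores.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def make_fy_loss_sketch(max_fun):
  """Hypothetical helper that builds a Fenchel-Young loss from max_fun.

  Assumption: loss = max_fun(scores) - <targets, scores>, evaluated over
  the last axis; the Omega(targets) term is omitted.
  """
  # Apply max_fun to each vector along the last axis: (..., n) -> (...).
  max_last = jnp.vectorize(max_fun, signature="(n)->()")
  # Inner product along the last axis: (..., n), (..., n) -> (...).
  vdot_last = jnp.vectorize(jnp.vdot, signature="(n),(n)->()")

  def fy_loss(scores, targets):
    return max_last(scores) - vdot_last(targets, scores)

  return fy_loss


fy_loss = make_fy_loss_sketch(logsumexp)
scores = jnp.array([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
targets = jnp.array([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
losses = fy_loss(scores, targets)  # one loss per row, shape (2,)
```

Wrapping both max_fun and the inner product with jnp.vectorize and a core signature over the last axis is what lets such a loss broadcast over any number of leading dimensions.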

Examples

Given a max function such as log-sum-exp, you can construct a Fenchel-Young loss as follows:

>>> import optax
>>> from jax.scipy.special import logsumexp
>>> fy_loss = optax.losses.make_fenchel_young_loss(max_fun=logsumexp)
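
As a sanity check on the log-sum-exp example: with one-hot targets, the Fenchel-Young loss generated by log-sum-exp coincides with the softmax cross-entropy, because the negative-entropy regularizer whose conjugate is log-sum-exp equals zero at one-hot vectors. The sketch below computes both quantities in plain JAX without calling optax, assuming the loss takes the form logsumexp(scores) - <targets, scores> over the last axis.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

scores = jnp.array([[2.0, 1.0, 0.1], [0.5, 2.5, -1.0]])
labels = jnp.array([0, 1])
# One-hot targets: vertices of the probability simplex.
targets = jax.nn.one_hot(labels, scores.shape[-1])

# Fenchel-Young loss built on logsumexp, written out per row:
# logsumexp(scores) - <targets, scores>.
fy = logsumexp(scores, axis=-1) - jnp.sum(targets * scores, axis=-1)

# Softmax cross-entropy per row: -sum(targets * log_softmax(scores)).
xent = -jnp.sum(targets * jax.nn.log_softmax(scores, axis=-1), axis=-1)
```

The two agree because log_softmax(s) = s - logsumexp(s) and each one-hot row sums to 1.
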
Reference:

Blondel et al. Learning with Fenchel-Young Losses, 2020
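
For reference, Blondel et al. define the Fenchel-Young loss generated by a regularizer Omega as follows. Reading theta as the scores, y as the targets, and the conjugate Omega* as max_fun, the max function plays the role of Omega*; for instance, when Omega is the negative Shannon entropy restricted to the probability simplex, Omega* is log-sum-exp. The Omega(y) term does not depend on theta, so it leaves gradients with respect to the scores unchanged.

```latex
L_\Omega(\theta; y) = \Omega^*(\theta) + \Omega(y) - \langle \theta, y \rangle,
\qquad
\Omega^*(\theta) = \sup_{\mu \in \operatorname{dom}\Omega} \langle \theta, \mu \rangle - \Omega(\mu)
```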

Warning

The resulting loss accepts inputs with an arbitrary number of leading dimensions and operates over the last dimension only, so it returns one loss value per leading index. The jaxopt version of this function instead flattens its inputs into a single 1D vector.
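
To make the difference concrete, the sketch below contrasts the two behaviors on a batch of shape (2, 3, 4), using log-sum-exp as the max function. It assumes the loss form logsumexp(scores) - <targets, scores>; the flattened variant only imitates the jaxopt behavior described in the warning and is not jaxopt code.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# A batch of 2 x 3 score vectors, each with 4 entries: shape (2, 3, 4).
scores = jnp.arange(24.0).reshape(2, 3, 4) / 10.0
# Uniform targets on every length-4 vector (each sums to 1).
targets = jnp.full((2, 3, 4), 0.25)

# Last-axis behavior: one loss per leading index, shape (2, 3).
per_vector = logsumexp(scores, axis=-1) - jnp.sum(targets * scores, axis=-1)

# Flattening behavior: all 24 entries treated as one vector, a single scalar.
flat = logsumexp(scores.ravel()) - jnp.vdot(targets.ravel(), scores.ravel())
```

Note that the two results are not interchangeable: flattening mixes entries from different vectors inside one log-sum-exp, so it is not simply the sum of the per-vector losses.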