optax.losses.make_fenchel_young_loss
- optax.losses.make_fenchel_young_loss(max_fun: MaxFun)
Creates a Fenchel-Young loss from a max function.
- Parameters:
max_fun – the max function on which the Fenchel-Young loss is built.
- Returns:
A Fenchel-Young loss function fy_loss(scores, targets, *args, **kwargs); any extra arguments are passed on to max_fun.
Examples
Given a max function such as log-sum-exp (logsumexp), you can construct a Fenchel-Young loss as follows:
>>> import optax
>>> from jax.scipy.special import logsumexp
>>> fy_loss = optax.losses.make_fenchel_young_loss(max_fun=logsumexp)
- Reference:
Blondel et al. Learning with Fenchel-Young Losses, 2020
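The construction behind this function, as given by Blondel et al., is summarized below for a regularizer Ω with convex conjugate Ω*. Reading max_fun as Ω* is our interpretation of this docstring. Because Ω(y) does not depend on the scores θ, an implementation may omit that term without changing gradients with respect to θ.

```latex
\begin{align*}
% General Fenchel-Young loss generated by a regularizer \Omega
L_\Omega(\theta; y) &= \Omega^*(\theta) + \Omega(y) - \langle \theta, y \rangle \;\ge\; 0,
  \qquad \Omega^*(\theta) = \sup_{\mu \in \operatorname{dom}\Omega} \langle \theta, \mu \rangle - \Omega(\mu) \\
% \Omega = negative Shannon entropy on the simplex: \Omega^* = logsumexp, \Omega(e_k) = 0
L_\Omega(\theta; e_k) &= \log \sum_j \exp(\theta_j) - \theta_k
  \qquad \text{(softmax cross-entropy)} \\
% \Omega = indicator of the simplex (zero on it): \Omega^* = \max
L_\Omega(\theta; e_k) &= \max_j \theta_j - \theta_k
  \qquad \text{(perceptron loss)}
\end{align*}
```

Here e_k is the one-hot vector for class k; nonnegativity of the loss is the Fenchel-Young inequality.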
Warning
The resulting loss accepts inputs with an arbitrary number of leading dimensions and computes the Fenchel-Young loss over the last dimension only, returning one loss value per leading index. The jaxopt version of this function instead flattens its inputs into a single 1D vector.
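To illustrate the warning, here is a minimal pure-JAX sketch contrasting a loss that works over the last dimension with a flattened variant in the spirit of the jaxopt behavior. Both helpers, fy_last_dim_sketch and fy_flattened_sketch, are hypothetical and hard-code logsumexp as the max function; this is not optax's source, only a way to show how the output shapes differ.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def fy_last_dim_sketch(scores, targets):
    """Hypothetical helper: Fenchel-Young loss (logsumexp) over the last axis.

    jnp.vectorize maps the 1D computation over every leading index, so the
    result has shape scores.shape[:-1].
    """
    def fy_1d(s, t):
        return logsumexp(s) - jnp.vdot(t, s)

    return jnp.vectorize(fy_1d, signature="(n),(n)->()")(scores, targets)


def fy_flattened_sketch(scores, targets):
    """Hypothetical helper: flattens both inputs into one 1D vector first."""
    s, t = scores.ravel(), targets.ravel()
    return logsumexp(s) - jnp.vdot(t, s)


# A batch of 4 examples with 3 classes each (arbitrary values).
scores = jnp.arange(12.0).reshape(4, 3)
targets = jnp.eye(3)[jnp.array([0, 1, 2, 0])]  # one-hot labels, shape (4, 3)

per_example = fy_last_dim_sketch(scores, targets)  # shape (4,): one loss per example
single = fy_flattened_sketch(scores, targets)      # shape (): one scalar for the batch
```

With batched one-hot targets, the flattened variant treats the whole batch as a single 12-way problem, which is generally not what a per-example classification loss intends.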