optax.losses.poly_loss_cross_entropy
- optax.losses.poly_loss_cross_entropy(logits: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, *, epsilon: jax.typing.ArrayLike = 2.0, axis: int | tuple[int, ...] | None = -1, where: jax.typing.ArrayLike | None = None) -> Array
Computes PolyLoss between logits and labels.
PolyLoss is a loss function that decomposes commonly used classification loss functions into a series of weighted polynomial bases. It is inspired by the Taylor expansion of cross-entropy loss and focal loss in the bases of \((1 - P_t)^j\), where \(P_t\) is the predicted probability of the target class.
\[
L_{Poly} = \sum_{j = 1}^\infty \alpha_j \cdot (1 - P_t)^j \\
L_{Poly-N} = (\epsilon_1 + 1) \cdot (1 - P_t) + \ldots + (\epsilon_N + \frac{1}{N}) \cdot (1 - P_t)^N + \frac{1}{N + 1} \cdot (1 - P_t)^{N + 1} + \ldots \\
= - \log(P_t) + \sum_{j = 1}^N \epsilon_j \cdot (1 - P_t)^j
\]
This function provides a simplified version of \(L_{Poly-N}\) in which only the coefficient of the first polynomial term is changed.
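As a rough illustration of the simplified form described above, the sketch below computes cross-entropy plus `epsilon * (1 - P_t)` by hand in JAX. The helper name `poly1_cross_entropy_sketch` is hypothetical and the code is an assumption about how the formula can be evaluated along the last axis; it is not taken from the library's source.

```python
import jax
import jax.numpy as jnp


def poly1_cross_entropy_sketch(logits, labels, epsilon=2.0):
    """Hypothetical reference: -log(P_t) + epsilon * (1 - P_t) on the last axis."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # For one-hot labels this is -log(P_t); for soft labels it is the usual
    # cross-entropy between the label distribution and the softmax output.
    cross_entropy = -jnp.sum(labels * log_probs, axis=-1)
    # Probability mass assigned to the target, weighted by the labels.
    p_t = jnp.sum(labels * jnp.exp(log_probs), axis=-1)
    return cross_entropy + epsilon * (1.0 - p_t)


logits = jnp.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
labels = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
loss = poly1_cross_entropy_sketch(logits, labels, epsilon=2.0)
```

With `epsilon = 0` the extra term vanishes and the sketch reduces to plain softmax cross-entropy, which matches the last line of the expansion when all \(\epsilon_j\) are zero.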
- Parameters:
logits – Unnormalized log probabilities, with shape [..., num_classes].
labels – Valid probability distributions (non-negative, sum to 1), e.g. a one-hot encoding specifying the correct class for each input; must have a shape broadcastable to [..., num_classes].
epsilon – The coefficient of the first polynomial term. According to the paper, the following values are recommended:
  - For ImageNet 2D image classification, epsilon = 2.0.
  - For 2D instance segmentation and object detection, epsilon = -1.0.
  - It is also recommended to adjust this value based on the task, e.g. by using grid search.
axis – Axis or axes along which to compute.
where – Elements to include in the computation.
- Returns:
Poly loss between each prediction and the corresponding target distributions, with shape [...].
References
Leng et al., PolyLoss: A Polynomial Expansion Perspective of Classification Loss Functions, 2022
Changed in version 0.2.4: Added axis and where arguments.