optax.losses.ctc_loss
- optax.losses.ctc_loss(logits: jax.typing.ArrayLike, logit_paddings: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, label_paddings: jax.typing.ArrayLike, *, blank_id: int = 0, log_epsilon: jax.typing.ArrayLike = -100000.0) -> Array
Computes the Connectionist Temporal Classification (CTC) loss for each sequence in a batch.
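As a reading aid, the standard CTC objective can be written in the notation of this page as shown below. This is a sketch of the textbook definition, not a statement of how this function is implemented; in particular, turning logits into per-frame probabilities with a softmax and counting only unpadded frames are assumptions consistent with the parameter descriptions.

$$
\mathcal{L}_b = -\log \sum_{\pi \in \mathcal{C}^{-1}(y_b)} \prod_{t=0}^{T_b - 1} p_{b,t}(\pi_t),
\qquad
p_{b,t}(k) = \frac{\exp(\mathrm{logits}[b,t,k])}{\sum_{k'=0}^{K-1} \exp(\mathrm{logits}[b,t,k'])}
$$

Here T_b is the number of frames with logit_paddings[b, t] == 0.0, y_b is the unpadded part of labels[b, :], π ranges over length-T_b paths of class ids, and the collapsing map \mathcal{C} merges consecutive repeated ids and then deletes blank_id. A target such as (4, 4) therefore needs a blank between the two 4s and requires at least three frames.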
See the docstring of ctc_loss_with_forward_probs for details.
- Parameters:
  - logits – (B, T, K)-array containing the logits of each class, where B denotes the batch size, T denotes the maximum number of time frames in logits, and K denotes the number of classes, including a class for blanks.
  - logit_paddings – (B, T)-array of padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.
  - labels – (B, N)-array containing the reference integer labels, where N denotes the maximum length of the label sequences.
  - label_paddings – (B, N)-array of padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.
  - blank_id – Id of the blank token. logits[b, :, blank_id] are used to compute the probabilities of the blank symbol.
  - log_epsilon – Numerically stable approximation of log(+0).
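Both padding arrays are usually derived from per-example lengths. The sketch below shows one way to build them with jax.numpy and to check the right-padding requirement on label_paddings. The helper names lengths_to_paddings and is_right_padded are hypothetical and are not part of optax.

```python
import jax
import jax.numpy as jnp


def lengths_to_paddings(lengths: jax.Array, max_len: int) -> jax.Array:
    """Hypothetical helper: builds a right-padded (B, max_len) padding array.

    Entry [b, t] is 0.0 for the first lengths[b] positions (real data)
    and 1.0 afterwards (padding).
    """
    positions = jnp.arange(max_len)[None, :]  # shape (1, max_len)
    return (positions >= lengths[:, None]).astype(jnp.float32)


def is_right_padded(paddings: jax.Array) -> jax.Array:
    """Hypothetical helper: True if every row is 0.0s followed by 1.0s."""
    only_zeros_and_ones = jnp.all((paddings == 0.0) | (paddings == 1.0))
    # A right-padded row never steps from 1.0 back to 0.0,
    # i.e. each row is non-decreasing.
    non_decreasing = jnp.all(paddings[:, 1:] >= paddings[:, :-1])
    return only_zeros_and_ones & non_decreasing


# Hypothetical batch: two utterances with 4 and 6 valid frames, padded to T = 6.
logit_paddings = lengths_to_paddings(jnp.array([4, 6]), max_len=6)
# Matching label sequences of lengths 2 and 3, padded to N = 3.
label_paddings = lengths_to_paddings(jnp.array([2, 3]), max_len=3)

print(logit_paddings)  # rows: [0 0 0 0 1 1] and [0 0 0 0 0 0]
print(bool(is_right_padded(label_paddings)))
```

Building paddings from lengths this way guarantees the right-padded layout by construction; the explicit check is mainly useful when paddings come from elsewhere.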
- Returns:
(B,)-array containing loss values for each sequence in the batch.