optax.losses.ctc_loss

optax.losses.ctc_loss(logits: jax.typing.ArrayLike, logit_paddings: jax.typing.ArrayLike, labels: jax.typing.ArrayLike, label_paddings: jax.typing.ArrayLike, *, blank_id: int = 0, log_epsilon: jax.typing.ArrayLike = -100000.0) → Array

Computes CTC loss.

See the docstring of ctc_loss_with_forward_probs for details.

Parameters:
  • logits – (B, T, K)-array containing the logits of each class, where B denotes the batch size, T denotes the maximum number of time frames in logits, and K denotes the number of classes, including a class for blanks.

  • logit_paddings – (B, T)-array. Padding indicators for logits. Each element must be either 1.0 or 0.0, and logit_paddings[b, t] == 1.0 denotes that logits[b, t, :] are padded values.

  • labels – (B, N)-array containing reference integer labels, where N denotes the maximum length of the label sequence.

  • label_paddings – (B, N)-array. Padding indicators for labels. Each element must be either 1.0 or 0.0, and label_paddings[b, n] == 1.0 denotes that labels[b, n] is a padded label. In the current implementation, labels must be right-padded, i.e. each row label_paddings[b, :] must be a run of zeros followed by a run of ones.

  • blank_id – Id for the blank token. logits[b, :, blank_id] are used as the logits of the blank symbol.

  • log_epsilon – Numerically-stable approximation of log(+0).
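
The padding arrays described above are often derived from per-sequence lengths. The following is a minimal sketch of one way to do that; paddings_from_lengths is a hypothetical helper written for illustration and is not part of optax.

```python
import jax.numpy as jnp


def paddings_from_lengths(lengths, max_len):
    """Hypothetical helper (not an optax function).

    Returns a (B, max_len) float array holding 0.0 at valid positions and
    1.0 at padded positions, i.e. a right-padded indicator.
    """
    positions = jnp.arange(max_len)[None, :]  # shape (1, max_len)
    return (positions >= lengths[:, None]).astype(jnp.float32)


# Two sequences: 5 and 3 valid time frames, 2 and 1 valid labels.
logit_paddings = paddings_from_lengths(jnp.array([5, 3]), max_len=5)
label_paddings = paddings_from_lengths(jnp.array([2, 1]), max_len=2)
```

Because every row is a run of zeros followed by a run of ones, the result satisfies the right-padding requirement stated for label_paddings.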

Returns:

(B,)-array containing loss values for each sequence in the batch.