Shortcuts

torch.nn.functional.ctc_loss

torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[源代码]

应用连接时序分类损失。

详情请参见 CTCLoss

注意

在某些情况下,当给定的张量位于CUDA设备上并使用CuDNN时,此操作符可能会选择一个非确定性算法以提高性能。如果这是不可取的,您可以尝试通过设置torch.backends.cudnn.deterministic = True来使操作具有确定性(可能会以性能为代价)。更多信息请参见可重复性

注意

当在CUDA设备上给定张量时,此操作可能会产生不确定的梯度。更多信息请参见可重复性

Parameters
  • log_probs (Tensor) – (T,N,C)(T, N, C)(T,C)(T, C) 其中 C = 字母表中字符的数量,包括空白, T = 输入长度, 和 N = 批次大小. 输出的对数概率 (例如,通过 torch.nn.functional.log_softmax() 获得).

  • 目标 (张量) – (N,S)(N, S)(sum(target_lengths)). 目标不能为空白。在第二种形式中,假设目标是连接在一起的。

  • input_lengths (Tensor) – (N)(N)()(). 输入的长度(每个长度必须为 T\leq T

  • target_lengths (Tensor) – (N)(N)()(). 目标的长度

  • 空白 (int, 可选) – 空白标签。默认值 00

  • reduction (str, 可选) – 指定应用于输出的reduction方式: 'none' | 'mean' | 'sum''none':不进行reduction, 'mean':输出损失将被目标长度除,然后对批次取平均值,'sum':输出将被求和。默认值:'mean'

  • zero_infinity (布尔值, 可选) – 是否将无限损失和相关的梯度归零。 默认值: False 无限损失主要发生在输入太短而无法与目标对齐时。

Return type

张量

示例:

>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
>>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
>>> input_lengths = torch.full((16,), 50, dtype=torch.long)
>>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
>>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths)
>>> loss.backward()
优云智算