Shortcuts

torch.nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)[源代码]

打包一个包含可变长度填充序列的张量。

input 可以是大小为 T x B x * 的,其中 T 是最长序列的长度(等于 lengths[0]),B 是批次大小,而 * 是任意数量的维度(包括 0)。如果 batch_firstTrue,则期望 B x T x *input

对于未排序的序列,使用 enforce_sorted = False。如果 enforce_sortedTrue,序列应按长度降序排列,即 input[:,0] 应为最长序列,而 input[:,B-1] 为最短 序列。enforce_sorted = True 仅在导出为 ONNX 时需要。

注意

此函数接受至少具有两个维度的任何输入。您可以将其应用于打包标签,并使用RNN的输出与它们一起直接计算损失。可以通过访问PackedSequence对象的.data属性来检索张量。

Parameters
  • 输入 (张量) – 填充后的可变长度序列批次。

  • 长度 (张量列表(整数)) – 每个批次元素的序列长度列表(如果作为张量提供,则必须在CPU上)。

  • batch_first (bool, 可选) – 如果True,输入应为B x T x * 格式。

  • enforce_sorted (布尔值, 可选) – 如果True,输入应包含按长度递减顺序排序的序列。如果False,输入将被无条件排序。默认值:True

Returns

一个 PackedSequence 对象

Return type

打包序列