Shortcuts

padded_collate

torchtune.data.padded_collate(batch: List[Dict[str, List[int]]], *, pad_direction: str, keys_to_pad: List[str], padding_idx: Union[int, Dict[str, int]])[source]

一个通用的填充整理函数,它将从给定的pad_direction方向填充keys_to_pad条目,使其达到批次中每个条目的最大序列长度。

注意

此函数假设所有不在keys_to_pad中的批次元素不需要任何整理(参见下面的示例)。

Parameters:
  • batch (List[Dict[str, List[int]]]) – 包含输入的字典列表。

  • pad_direction (str) – 是从左侧还是右侧填充条目。如果 pad_direction="right",我们使用 torch.nn.utils.rnn.pad_sequence(),否则如果 pad_direction="left", 我们使用 torchtune.data.left_pad_sequence()

  • keys_to_pad (List[str]) – 要应用填充的批处理元素键。应该是批处理中键的子集。

  • padding_idx (Union[int, Dict[str, int]]) – 可以是一个单一的整数值,应用于所有keys_to_pad元素,或者是一个与keys_to_pad相同的键的映射,每个键都有对应的填充值。

Returns:

填充后的输入ID张量,形状为[batch_size, max_seq_len]。

Return type:

torch.Tensor

Raises:
  • ValueError – 如果 pad_direction 不是“left”或“right”之一。

  • ValueError – 如果 keys_to_pad 为空,或者不是列表,或者不是批次中键的子集。

  • ValueError – 如果 padding_idx 作为字典提供,但键与 keys_to_pad 不完全相同。

示例

>>> a = [1, 2, 3]
>>> b = [4, 5, 6, 7]
>>> c = [8, 9, 10, 11, 12]
>>> batch = [
>>>     {"tokens": a, "labels": 1},
>>>     {"tokens": b, "labels": 3},
>>>     {"tokens": c, "labels": 0},
>>> ]
>>> padded_collate(
>>>     batch,
>>>     pad_direction="left",
>>>     keys_to_pad=["tokens"],
>>>     padding_idx=-10
>>> )
{
    'labels': tensor([1, 3, 0]),
    'tokens': tensor([[-10, -10,   1,   2,   3],
                      [-10,   4,   5,   6,   7],
                      [  8,   9,  10,  11,  12]])
}