Shortcuts

padded_collate_sft

torchtune.data.padded_collate_sft(batch: List[Dict[str, List[int]]], padding_idx: int = 0, ignore_idx: int = - 100) Dict[str, Tensor][source]

将一批序列填充到批次中最长序列的长度,并将整数列表转换为张量。

Parameters:
  • batch (List[Dict[str, List[int]]]) – 包含输入、标签对的字典列表。

  • padding_idx (int) – 输入ID的填充索引。默认为0。

  • ignore_idx (int) – 标签的填充索引。默认为-100。

Returns:

整理后的输入和标签张量。

Return type:

字典[str, torch.Tensor]

示例

>>> token_pairs = [
>>>    {"tokens": [1, 2, 3], "labels": [4, 5, 6]},
>>>    {"tokens": [7,], "labels": [10,]},
>>> ]
>>> collated = padded_collate(
>>>    batch=token_pairs,
>>>    padding_idx=padding_idx,
>>>    ignore_idx=ignore_idx,
>>> )
>>> collated["tokens"]
>>> tensor([[1, 2, 3], [7, 0, 0]])
>>> collated["labels"]
>>> tensor([[4, 5, 6], [10, -100, -100]])