Shortcuts

padded_collate_dpo

torchtune.data.padded_collate_dpo(batch: List[Dict[str, List[int]]], padding_idx: int = 0, ignore_idx: int = - 100) Tuple[Tensor, Tensor][source]

为直接偏好优化(DPO)填充一批序列。

此函数接收一批序列,其中每个序列表示为具有多个键值对的字典。每个键对应不同的序列组件,例如input_ids或labels。

Parameters:
  • batch (List[Dict[str, List[int]]]) – 一个字典列表,其中每个字典表示一个具有多个组件的序列,必须包含‘chosen_input_ids’、‘chosen_labels’、‘rejected_input_ids’和‘rejected_labels’。

  • padding_idx (int) – 输入ID的填充索引。默认为0。

  • ignore_idx (int) – 标签的填充索引。默认为-100。

Returns:

一个包含连接和填充的输入ID和标签的元组。

Return type:

元组[torch.Tensor, torch.Tensor]

示例

>>> batch = [
>>>    {'chosen_input_ids': [1, 2, 3], 'rejected_input_ids': [4, 5],
>>>      'chosen_labels': [6, 7, 8], 'rejected_labels': [9, 10]},
>>>    {'chosen_input_ids': [11, 12], 'rejected_input_ids': [13, 14, 15],
>>>      'chosen_labels': [16, 17], 'rejected_labels': [18, 19, 20]},
>>> ]
>>> padded_collate_dpo(batch)
>>> (tensor([[ 1,  2,  3],
>>>          [11, 12,  0],
>>>          [ 4,  5,  0],
>>>          [13, 14, 15]]),
>>>  tensor([[ 6,  7,  8],
>>>          [16, 17, -100],
>>>          [ 9, 10, -100],
>>>          [18, 19, 20]]))