padded_collate_dpo¶
- torchtune.data.padded_collate_dpo(batch: List[Dict[str, List[int]]], padding_idx: int = 0, ignore_idx: int = - 100) Tuple[Tensor, Tensor][source]¶
为直接偏好优化(DPO)填充一批序列。
此函数接收一批序列,其中每个序列表示为具有多个键值对的字典。每个键对应不同的序列组件,例如input_ids或labels。
- Parameters:
- Returns:
一个包含连接和填充的输入ID和标签的元组。
- Return type:
元组[torch.Tensor, torch.Tensor]
示例
>>> batch = [ >>> {'chosen_input_ids': [1, 2, 3], 'rejected_input_ids': [4, 5], >>> 'chosen_labels': [6, 7, 8], 'rejected_labels': [9, 10]}, >>> {'chosen_input_ids': [11, 12], 'rejected_input_ids': [13, 14, 15], >>> 'chosen_labels': [16, 17], 'rejected_labels': [18, 19, 20]}, >>> ] >>> padded_collate_dpo(batch) >>> (tensor([[ 1, 2, 3], >>> [11, 12, 0], >>> [ 4, 5, 0], >>> [13, 14, 15]]), >>> tensor([[ 6, 7, 8], >>> [16, 17, -100], >>> [ 9, 10, -100], >>> [18, 19, 20]]))