Shortcuts

truncate_sequence_at_first_stop_token

torchtune.rlhf.truncate_sequence_at_first_stop_token(sequences: Tensor, stop_tokens: Tensor, fill_value: int = 0) Tuple[Tensor, Tensor][source]

在第一个停止标记后截断序列,并用fill_value填充。

Parameters:
  • sequences (torch.Tensor) – 形状为 [batch_size, sequence_length] 或 [sequence_length] 的张量。

  • stop_tokens (torch.Tensor) – 包含停止标记的张量。

  • fill_value (int) – 在第一个停止标记后用此值填充序列,通常为 pad_id

Returns:

一个与sequences形状相同的两个张量的元组:
  • padding_mask (torch.Tensor): 一个布尔张量,其中True表示该标记已被截断。

  • sequences (torch.Tensor) 一个包含截断和填充序列的张量。

Return type:

元组[torch.Tensor, torch.Tensor]

示例

>>> stop_token_ids = torch.tensor([2, 869])
>>> fill_value = 0
>>> sequences = torch.tensor(
>>>     [
>>>         [869, 30, 869],
>>>         [2, 30, 869],
>>>         [869, 30, 2],
>>>         [50, 30, 869],
>>>         [13, 30, 2],
>>>         [13, 30, 5],
>>>         [13, 2, 20],
>>>         [13, 2, 2],
>>>         [2, 2, 2],
>>>     ]
>>> )
>>> eos_mask, truncated_sequences = rlhf.truncate_sequence_at_first_stop_token(
>>>     sequences, stop_token_ids, fill_value
>>> )
>>> eos_mask
>>> torch.tensor([
>>>         [False, True, True],
>>>         [False, True, True],
>>>         [False, True, True],
>>>         [False, False, False],
>>>         [False, False, False],
>>>         [False, False, False],
>>>         [False, False, True],
>>>         [False, False, True],
>>>         [False, True, True],
>>>     ]
>>> )
>>> truncated_sequences
>>> torch.tensor([
>>>         [869, 0, 0],
>>>         [2, 0, 0],
>>>         [869, 0, 0],
>>>         [50, 30, 869],
>>>         [13, 30, 2],
>>>         [13, 30, 5],
>>>         [13, 2, 0],
>>>         [13, 2, 0],
>>>         [2, 0, 0],
>>>     ]
>>> )