Shortcuts

get_causal_mask_from_padding_mask

torchtune.generation.get_causal_mask_from_padding_mask(padding_mask: Tensor, target_seq_len: Optional[int] = None) Tensor[source]

将形状为[bsz, seq_len]的填充掩码转换为适合scaled_dot_product_attention()使用的形状为[bsz, seq_len, seq_len]的因果注意力掩码。如果提供了target_seq_len,这将返回一个形状为[bsz, seq_len, target_seq_len]的掩码。这在为静态KV缓存生成掩码时非常有用,因为缓存设置的最大长度比当前序列长。

Parameters:
  • padding_mask (torch.Tensor) – 布尔张量,其中False表示序列中的相应标记是填充标记,应在注意力中被屏蔽,形状为[bsz x seq_length]

  • target_seq_len (可选[int]) – 用于创建注意力掩码的目标序列长度。默认为 None。

Returns:

具有形状的布尔因果掩码
  • [bsz, seq_length, seq_length] 或

  • [bsz, seq_length, target_seq_len] 如果指定了 target_seq_len

Return type:

torch.Tensor

Raises:

AssertionError – 如果 target_seq_len > seq_len,即填充掩码的序列长度。

示例

>>> padding_mask = torch.tensor([[False, True, True, True]])
>>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5)
tensor([[[ True, False, False, False, False],
          [False,  True, False, False, False],
          [False,  True,  True, False, False],
          [False,  True,  True,  True, False]]])
])