get_causal_mask_from_padding_mask¶
- torchtune.generation.get_causal_mask_from_padding_mask(padding_mask: Tensor, target_seq_len: Optional[int] = None) Tensor[source]¶
将形状为
[bsz, seq_len]的填充掩码转换为适合scaled_dot_product_attention()使用的形状为[bsz, seq_len, seq_len]的因果注意力掩码。如果提供了target_seq_len,这将返回一个形状为[bsz, seq_len, target_seq_len]的掩码。这在为静态KV缓存生成掩码时非常有用,因为缓存设置的最大长度比当前序列长。- Parameters:
padding_mask (torch.Tensor) – 布尔张量,其中False表示序列中的相应标记是填充标记,应在注意力中被屏蔽,形状为[bsz x seq_length]
target_seq_len (可选[int]) – 用于创建注意力掩码的目标序列长度。默认为 None。
- Returns:
- 具有形状的布尔因果掩码
[bsz, seq_length, seq_length] 或
[bsz, seq_length, target_seq_len] 如果指定了
target_seq_len。
- Return type:
- Raises:
AssertionError – 如果
target_seq_len > seq_len,即填充掩码的序列长度。
示例
>>> padding_mask = torch.tensor([[False, True, True, True]]) >>> get_causal_mask_from_padding_mask(padding_mask, target_seq_len=5) tensor([[[ True, False, False, False, False], [False, True, False, False, False], [False, True, True, False, False], [False, True, True, True, False]]]) ])