torch.nn.functional.scaled_dot_product_attention¶
- torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) Tensor:¶
计算查询、键和值张量上的缩放点积注意力,如果传递了可选的注意力掩码,则使用该掩码,并在指定的概率大于0.0时应用dropout。可选的scale参数只能作为关键字参数指定。
# 高效的实现等同于以下内容: def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale attn_bias = torch.zeros(L, S, dtype=query.dtype) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias += attn_mask attn_weight = query @ key.transpose(-2, -1) * scale_factor attn_weight += attn_bias attn_weight = torch.softmax(attn_weight, dim=-1) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value
警告
此功能为测试版,可能会发生变化。
注意
目前有三种支持的缩放点积注意力实现:
一个在C++中定义的PyTorch实现,与上述公式相匹配
在使用CUDA后端时,该函数可能会调用优化的内核以提高性能。对于所有其他后端,将使用PyTorch实现。
所有实现默认情况下都是启用的。缩放点积注意力尝试根据输入自动选择最优的实现。为了提供更细粒度的控制,可以使用以下函数来启用和禁用实现。上下文管理器是首选机制:
torch.nn.attention.sdpa_kernel(): 一个用于启用或禁用任何实现的环境管理器。torch.backends.cuda.enable_flash_sdp(): 全局启用或禁用FlashAttention。torch.backends.cuda.enable_mem_efficient_sdp(): 全局启用或禁用内存高效注意力。torch.backends.cuda.enable_math_sdp(): 全局启用或禁用 PyTorch C++ 实现。
每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现, 请使用
torch.nn.attention.sdpa_kernel()禁用 PyTorch C++ 实现。 如果无法使用融合实现,将会发出警告,并说明无法运行融合实现的原因。由于浮点运算融合的性质,此函数的输出可能会因所选的后端内核而异。 C++ 实现支持 torch.float64,并且可以在需要更高精度时使用。 有关更多信息,请参阅 数值精度
注意
在某些情况下,当给定的张量位于CUDA设备上并使用CuDNN时,此操作符可能会选择一个非确定性算法以提高性能。如果这是不可取的,您可以尝试通过设置
torch.backends.cudnn.deterministic = True来使操作具有确定性(可能会以性能为代价)。更多信息请参见可重复性。- Parameters
查询 (张量) – 查询张量;形状 .
键 (张量) – 键张量;形状 .
值 (张量) – 值张量;形状 .
attn_mask (可选的张量) – 注意力掩码;形状必须可广播到注意力权重的形状, 即 . 支持两种类型的掩码。 一个布尔掩码,其中True值表示该元素应该参与注意力计算。 一个与查询、键、值类型相同的浮点掩码,它被添加到注意力分数中。
dropout_p (float) – 丢弃概率;如果大于 0.0,则应用丢弃
is_causal (bool) – 如果为真,假设上左因果注意力掩码,并在同时设置attn_mask和is_causal时报错。
scale (可选 python:float, 仅关键字) – 在softmax之前应用的缩放因子。如果为None,默认值设置为。
- Returns
注意力输出;形状 。
- Return type
输出 (Tensor)
- Shape legend:
示例
>>> # 可选地使用上下文管理器以确保运行其中一个融合内核 >>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda") >>> with torch.backends.cuda.sdp_kernel(enable_math=False): >>> F.scaled_dot_product_attention(query,key,value)