Shortcuts

torch.nn.functional.scaled_dot_product_attention

torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) Tensor:

计算查询、键和值张量上的缩放点积注意力,如果传递了可选的注意力掩码,则使用该掩码,并在指定的概率大于0.0时应用dropout。可选的scale参数只能作为关键字参数指定。

# 高效的实现等同于以下内容:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

警告

此功能为测试版,可能会发生变化。

注意

目前有三种支持的缩放点积注意力实现:

在使用CUDA后端时,该函数可能会调用优化的内核以提高性能。对于所有其他后端,将使用PyTorch实现。

所有实现默认情况下都是启用的。缩放点积注意力尝试根据输入自动选择最优的实现。为了提供更细粒度的控制,可以使用以下函数来启用和禁用实现。上下文管理器是首选机制:

每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现, 请使用 torch.nn.attention.sdpa_kernel() 禁用 PyTorch C++ 实现。 如果无法使用融合实现,将会发出警告,并说明无法运行融合实现的原因。

由于浮点运算融合的性质,此函数的输出可能会因所选的后端内核而异。 C++ 实现支持 torch.float64,并且可以在需要更高精度时使用。 有关更多信息,请参阅 数值精度

注意

在某些情况下,当给定的张量位于CUDA设备上并使用CuDNN时,此操作符可能会选择一个非确定性算法以提高性能。如果这是不可取的,您可以尝试通过设置torch.backends.cudnn.deterministic = True来使操作具有确定性(可能会以性能为代价)。更多信息请参见可重复性

Parameters
  • 查询 (张量) – 查询张量;形状 (N,...,L,E)(N, ..., L, E).

  • (张量) – 键张量;形状 (N,...,S,E)(N, ..., S, E).

  • (张量) – 值张量;形状 (N,...,S,Ev)(N, ..., S, Ev).

  • attn_mask (可选的张量) – 注意力掩码;形状必须可广播到注意力权重的形状, 即 (N,...,L,S)(N,..., L, S). 支持两种类型的掩码。 一个布尔掩码,其中True值表示该元素应该参与注意力计算。 一个与查询、键、值类型相同的浮点掩码,它被添加到注意力分数中。

  • dropout_p (float) – 丢弃概率;如果大于 0.0,则应用丢弃

  • is_causal (bool) – 如果为真,假设上左因果注意力掩码,并在同时设置attn_mask和is_causal时报错。

  • scale (可选 python:float, 仅关键字) – 在softmax之前应用的缩放因子。如果为None,默认值设置为1E\frac{1}{\sqrt{E}}

Returns

注意力输出;形状 (N,...,L,Ev)(N, ..., L, Ev)

Return type

输出 (Tensor)

Shape legend:
  • N:批量大小...:任意数量的其他批次维度(可选)N: \text{批量大小} ... : \text{任意数量的其他批次维度(可选)}

  • S:源序列长度S: \text{源序列长度}

  • L:目标序列长度L: \text{目标序列长度}

  • E:查询和键的嵌入维度E: \text{查询和键的嵌入维度}

  • Ev:值的嵌入维度Ev: \text{值的嵌入维度}

示例

>>> # 可选地使用上下文管理器以确保运行其中一个融合内核
>>> query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
>>> with torch.backends.cuda.sdp_kernel(enable_math=False):
>>>     F.scaled_dot_product_attention(query,key,value)
优云智算