torch.nn.attention.bias 的源代码
"""定义与scaled_dot_product_attention一起工作的偏置子类"""
from enum import auto, IntEnum
from typing import Optional
from warnings import warn
import torch
from torch.backends.cuda import (
can_use_efficient_attention,
can_use_flash_attention,
SDPAParams,
)
from torch.nn.attention import _raise_kernel_warnings
from torch.nn.attention._utils import (
_calculate_scale,
_input_requires_grad,
_postprocess_flash_output,
_validate_sdpa_input,
)
from torch.nn.functional import scaled_dot_product_attention
__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]
torch._dynamo.allow_in_graph(can_use_flash_attention)
torch._dynamo.allow_in_graph(can_use_efficient_attention)
torch._dynamo.allow_in_graph(SDPAParams)
[docs]class CausalVariant(IntEnum):
r"""
用于注意力机制中的因果变体的枚举。
定义了两种类型的因果偏置:
`UPPER_LEFT`:表示标准因果注意的上左三角偏置。
构造此偏置的等效pytorch代码为:
.. code-block:: python
torch.tril(torch.ones(size, dtype=torch.bool))
例如,对于`shape=(3,4)`,具体化的偏置张量为:
.. code-block:: text
[[1, 0, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 0]]
`LOWER_RIGHT`:表示下右三角偏置,包含的值对齐到矩阵的右下角。
构造此偏置的等效pytorch代码为:
.. code-block:: python
diagonal_offset = size[1] - size[0]
torch.tril(
torch.ones(size, dtype=torch.bool),
diagonal=diagonal_offset,
)
例如,对于`shape=(3,4)`,具体化的偏置张量为:
.. code-block:: text
[[1, 1, 0, 0],
[1, 1, 1, 0],
[1, 1, 1, 1]]
请注意,当查询和键/值张量的序列长度相等时,这些变体是等效的,因为三角矩阵是方形的。
.. warning:: 此枚举是原型,可能会更改。
"""
UPPER_LEFT = auto()
LOWER_RIGHT = auto()
[docs]class CausalBias(torch.Tensor):
"""
表示因果注意模式的偏置。有关偏置结构的概述,请参见:class:`CausalVariant`枚举。
此类用于定义因果(三角形)注意偏置。为了构造偏置,存在两个工厂函数::func:`causal_upper_left`和:func:`causal_lower_right`。
示例:
.. code-block:: python
from torch.nn.attention.bias import causal_lower_right
bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8
# 创建一个下右因果偏置
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
out = F.scaled_dot_product_attention(q, k, v, attn_bias)
.. warning:: 此类是原型,可能会更改。
"""
def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int):
"""
使用指定的变体和序列长度初始化CausalBias实例。
参数:
variant (CausalVariant): 要使用的因果偏置类型(UPPER_LEFT或LOWER_RIGHT)。
seq_len_q (int): 查询张量的序列长度。
seq_len_kv (int): 键/值张量的序列长度。
如果使用LOWER_RIGHT变体且seq_len_q > seq_len_kv,则会发出警告,因为它可能会产生NaN。
"""
assert isinstance(variant, CausalVariant)
self.variant = variant
self.seq_len_q = seq_len_q
self.seq_len_kv = seq_len_kv
if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT:
warn(
"Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!"
)
def _upper_left(self, device: torch.device) -> torch.Tensor:
"""上左因果偏置"""
return torch.tril(
torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool)
)
def _lower_right(self, device: torch.device) -> torch.Tensor:
"""下右因果偏置"""
diagonal_offset = self.seq_len_kv - self.seq_len_q
return torch.tril(
torch.ones(
self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool
),
diagonal=<span class="