Shortcuts

torch.nn.attention.bias 的源代码

"""定义与scaled_dot_product_attention一起工作的偏置子类"""
from enum import auto, IntEnum
from typing import Optional
from warnings import warn

import torch
from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    SDPAParams,
)
from torch.nn.attention import _raise_kernel_warnings
from torch.nn.attention._utils import (
    _calculate_scale,
    _input_requires_grad,
    _postprocess_flash_output,
    _validate_sdpa_input,
)
from torch.nn.functional import scaled_dot_product_attention

__all__ = ["causal_upper_left", "causal_lower_right", "CausalVariant", "CausalBias"]


torch._dynamo.allow_in_graph(can_use_flash_attention)
torch._dynamo.allow_in_graph(can_use_efficient_attention)
torch._dynamo.allow_in_graph(SDPAParams)


[docs]class CausalVariant(IntEnum): r""" 用于注意力机制中的因果变体的枚举。 定义了两种类型的因果偏置: `UPPER_LEFT`:表示标准因果注意的上左三角偏置。 构造此偏置的等效pytorch代码为: .. code-block:: python torch.tril(torch.ones(size, dtype=torch.bool)) 例如,对于`shape=(3,4)`,具体化的偏置张量为: .. code-block:: text [[1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0]] `LOWER_RIGHT`:表示下右三角偏置,包含的值对齐到矩阵的右下角。 构造此偏置的等效pytorch代码为: .. code-block:: python diagonal_offset = size[1] - size[0] torch.tril( torch.ones(size, dtype=torch.bool), diagonal=diagonal_offset, ) 例如,对于`shape=(3,4)`,具体化的偏置张量为: .. code-block:: text [[1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]] 请注意,当查询和键/值张量的序列长度相等时,这些变体是等效的,因为三角矩阵是方形的。 .. warning:: 此枚举是原型,可能会更改。 """ UPPER_LEFT = auto() LOWER_RIGHT = auto()
[docs]class CausalBias(torch.Tensor): """ 表示因果注意模式的偏置。有关偏置结构的概述,请参见:class:`CausalVariant`枚举。 此类用于定义因果(三角形)注意偏置。为了构造偏置,存在两个工厂函数::func:`causal_upper_left`和:func:`causal_lower_right`。 示例: .. code-block:: python from torch.nn.attention.bias import causal_lower_right bsz, num_heads, seqlen_q, seqlen_kv, head_dim = 32, 8, 4, 12, 8 # 创建一个下右因果偏置 attn_bias = causal_lower_right(seqlen_q, seqlen_kv) q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16) k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16) v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16) out = F.scaled_dot_product_attention(q, k, v, attn_bias) .. warning:: 此类是原型,可能会更改。 """ def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int): """ 使用指定的变体和序列长度初始化CausalBias实例。 参数: variant (CausalVariant): 要使用的因果偏置类型(UPPER_LEFT或LOWER_RIGHT)。 seq_len_q (int): 查询张量的序列长度。 seq_len_kv (int): 键/值张量的序列长度。 如果使用LOWER_RIGHT变体且seq_len_q > seq_len_kv,则会发出警告,因为它可能会产生NaN。 """ assert isinstance(variant, CausalVariant) self.variant = variant self.seq_len_q = seq_len_q self.seq_len_kv = seq_len_kv if seq_len_q > seq_len_kv and variant == CausalVariant.LOWER_RIGHT: warn( "Lower right causal bias will produce NaNs in the output when seq_len_q > seq_len_kv!" ) def _upper_left(self, device: torch.device) -> torch.Tensor: """上左因果偏置""" return torch.tril( torch.ones(self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool) ) def _lower_right(self, device: torch.device) -> torch.Tensor: """下右因果偏置""" diagonal_offset = self.seq_len_kv - self.seq_len_q return torch.tril( torch.ones( self.seq_len_q, self.seq_len_kv, device=device, dtype=torch.bool ), diagonal=<span class="
优云智算