Shortcuts

torch.nn.attention 的源代码

""" 该模块包含改变 torch.nn.functional.scaled_dot_product_attention 行为的函数和类 """
import contextlib
from typing import List, Union
from warnings import warn

from torch.backends.cuda import (
    can_use_efficient_attention,
    can_use_flash_attention,
    enable_flash_sdp,
    enable_math_sdp,
    enable_mem_efficient_sdp,
    flash_sdp_enabled,
    math_sdp_enabled,
    mem_efficient_sdp_enabled,
    SDPAParams,
)

__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]

# 注意: [SDPA 警告]
# 待办事项: 考虑无论子类如何都使用此功能
# 这只影响使用偏置子类的用户
# 如果设置为 True,如果用户没有使用融合内核,我们将警告用户
# 此外,它将针对所有无法运行融合内核的原因发出警告。
# 要将其设置为 True,请运行
# torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True
WARN_FOR_UNFUSED_KERNELS = False


from torch._C import _SDPBackend as SDPBackend

# Sphinx 文档的技巧:
# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class
SDPBackend = SDPBackend
r"""一个包含缩放点积注意力的不同后端的类,类似于枚举。
    这个后端类设计用于 sdpa_kernel 上下文管理器。

    以下枚举可用:
        - ERROR: 尝试确定后端时发生错误。
        - MATH: 缩放点积注意力的数学后端。
        - FLASH_ATTENTION: 缩放点积注意力的闪存注意力后端。
        - EFFICIENT_ATTENTION: 缩放点积注意力的高效注意力后端。
        - CUDNN_ATTENTION: 缩放点积注意力的 cuDNN 后端。

    更多详情请参见 :func:`torch.nn.attention.sdpa_kernel`。

    .. 警告:: 此类处于测试阶段,可能会发生变化。
"""
SDPBackend.__module__ = __name__
SDPBackend.__name__ = "SDPBackend"


def _raise_kernel_warnings(params: SDPAParams) -> None:
    """
    如果 WARN_FOR_UNFUSED_KERNELS 设置为 True,这将针对所有无法运行融合内核的原因发出警告。
    如果使用子类
    """
    if WARN_FOR_UNFUSED_KERNELS:
        if not can_use_efficient_attention(params):
            warn("无法使用高效注意力,因为:")
            can_use_efficient_attention(params, True)
        if not can_use_flash_attention(params):
            warn("无法使用闪存注意力,因为:")
            can_use_flash_attention(params, True)


[docs]@contextlib.contextmanager def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]): r""" 上下文管理器,用于选择缩放点积注意力的后端。 .. 警告:: 此函数处于测试阶段,可能会发生变化。 参数: backend (Union[List[SDPBackend], SDPBackend]): 缩放点积注意力的后端或后端列表。 示例: .. 代码块:: python from torch.nn.functional import scaled_dot_product_attention from torch.nn.attention import SDPBackend, sdpa_kernel # 仅启用闪存注意力后端 with sdpa_kernel(SDPBackend.FLASH_ATTENTION): scaled_dot_product_attention(...) # 启用数学或高效注意力后端 with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]): scaled_dot_product_attention(...) 此上下文管理器可用于选择缩放点积注意力的后端。 退出上下文管理器时,标志的先前状态将被恢复,启用所有后端。 """ assert isinstance( backends, (list, SDPBackend) ), "后端必须是 SDPBackend 的实例或 SDPBackend 实例的列表" if isinstance(backends, SDPBackend): backends = [backends] backends = set(backends) previous_flash: bool = flash_sdp_enabled() previous_mem_efficient: bool = mem_efficient_sdp_enabled() previous_math: bool = math_sdp_enabled() try: enable_flash = SDPBackend.FLASH_ATTENTION in backends enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends enable_math = SDPBackend.MATH in backends enable_flash_sdp(enable_flash) enable_mem_efficient_sdp(enable_mem_efficient) enable_math_sdp(enable_math) yield {} finally: enable_flash_sdp(previous_flash) enable_mem_efficient_sdp(previous_mem_efficient) enable_math_sdp(previous_math)
优云智算