torch.nn.attention 的源代码
""" 该模块包含改变 torch.nn.functional.scaled_dot_product_attention 行为的函数和类 """
import contextlib
from typing import List, Union
from warnings import warn
from torch.backends.cuda import (
can_use_efficient_attention,
can_use_flash_attention,
enable_flash_sdp,
enable_math_sdp,
enable_mem_efficient_sdp,
flash_sdp_enabled,
math_sdp_enabled,
mem_efficient_sdp_enabled,
SDPAParams,
)
__all__: List[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
# 注意: [SDPA 警告]
# 待办事项: 考虑无论子类如何都使用此功能
# 这只影响使用偏置子类的用户
# 如果设置为 True,如果用户没有使用融合内核,我们将警告用户
# 此外,它将针对所有无法运行融合内核的原因发出警告。
# 要将其设置为 True,请运行
# torch.nn.attention.WARN_FOR_UNFUSED_KERNELS = True
WARN_FOR_UNFUSED_KERNELS = False
from torch._C import _SDPBackend as SDPBackend
# Sphinx 文档的技巧:
# https://stackoverflow.com/questions/38765577/overriding-sphinx-autodoc-alias-of-for-import-of-private-class
SDPBackend = SDPBackend
r"""一个包含缩放点积注意力的不同后端的类,类似于枚举。
这个后端类设计用于 sdpa_kernel 上下文管理器。
以下枚举可用:
- ERROR: 尝试确定后端时发生错误。
- MATH: 缩放点积注意力的数学后端。
- FLASH_ATTENTION: 缩放点积注意力的闪存注意力后端。
- EFFICIENT_ATTENTION: 缩放点积注意力的高效注意力后端。
- CUDNN_ATTENTION: 缩放点积注意力的 cuDNN 后端。
更多详情请参见 :func:`torch.nn.attention.sdpa_kernel`。
.. 警告:: 此类处于测试阶段,可能会发生变化。
"""
SDPBackend.__module__ = __name__
SDPBackend.__name__ = "SDPBackend"
def _raise_kernel_warnings(params: SDPAParams) -> None:
"""
如果 WARN_FOR_UNFUSED_KERNELS 设置为 True,这将针对所有无法运行融合内核的原因发出警告。
如果使用子类
"""
if WARN_FOR_UNFUSED_KERNELS:
if not can_use_efficient_attention(params):
warn("无法使用高效注意力,因为:")
can_use_efficient_attention(params, True)
if not can_use_flash_attention(params):
warn("无法使用闪存注意力,因为:")
can_use_flash_attention(params, True)
[docs]@contextlib.contextmanager
def sdpa_kernel(backends: Union[List[SDPBackend], SDPBackend]):
r"""
上下文管理器,用于选择缩放点积注意力的后端。
.. 警告:: 此函数处于测试阶段,可能会发生变化。
参数:
backend (Union[List[SDPBackend], SDPBackend]): 缩放点积注意力的后端或后端列表。
示例:
.. 代码块:: python
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
# 仅启用闪存注意力后端
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
scaled_dot_product_attention(...)
# 启用数学或高效注意力后端
with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
scaled_dot_product_attention(...)
此上下文管理器可用于选择缩放点积注意力的后端。
退出上下文管理器时,标志的先前状态将被恢复,启用所有后端。
"""
assert isinstance(
backends, (list, SDPBackend)
), "后端必须是 SDPBackend 的实例或 SDPBackend 实例的列表"
if isinstance(backends, SDPBackend):
backends = [backends]
backends = set(backends)
previous_flash: bool = flash_sdp_enabled()
previous_mem_efficient: bool = mem_efficient_sdp_enabled()
previous_math: bool = math_sdp_enabled()
try:
enable_flash = SDPBackend.FLASH_ATTENTION in backends
enable_mem_efficient = SDPBackend.EFFICIENT_ATTENTION in backends
enable_math = SDPBackend.MATH in backends
enable_flash_sdp(enable_flash)
enable_mem_efficient_sdp(enable_mem_efficient)
enable_math_sdp(enable_math)
yield {}
finally:
enable_flash_sdp(previous_flash)
enable_mem_efficient_sdp(previous_mem_efficient)
enable_math_sdp(previous_math)