torch.backends.mha 的源代码
# 配置选项以启用/禁用nn.functional.MHA和nn.TransformerEncoder的C++内核
# 以及nn.TransformerEncoder
import torch
_is_fastpath_enabled: bool = True
[docs]def get_fastpath_enabled() -> bool:
"""返回是否启用了TransformerEncoder和MultiHeadAttention的快速路径,
如果jit正在脚本化,则返回``True``。
..注意:
即使``get_fastpath_enabled``返回``True``,除非输入满足所有条件,
否则可能不会运行快速路径。
"""
if not torch.jit.is_scripting():
return _is_fastpath_enabled
return True
[docs]def set_fastpath_enabled(value: bool) -> None:
"""设置是否启用快速路径"""
global _is_fastpath_enabled
_is_fastpath_enabled = value