torch.backends.opt_einsum 的源代码
import sys
import warnings
from contextlib import contextmanager
from functools import lru_cache as _lru_cache
from typing import Any
from torch.backends import __allow_nonbracketed_mutation, ContextProp, PropModule
try:
import opt_einsum as _opt_einsum # type: ignore[import]
except ImportError:
_opt_einsum = None
[docs]@_lru_cache
def is_available() -> bool:
r"""返回一个布尔值,指示opt_einsum当前是否可用。"""
return _opt_einsum is not None
[docs]def get_opt_einsum() -> Any:
r"""如果opt_einsum当前可用,则返回opt_einsum包,否则返回None。"""
return _opt_einsum
def _set_enabled(_enabled: bool) -> None:
if not is_available() and _enabled:
raise ValueError(
f"opt_einsum不可用,因此将`enabled`设置为{_enabled}将不会带来计算einsum最佳路径的好处。torch.einsum将回退到从左到右的收缩。要启用此最佳路径计算,请安装opt-einsum。"
)
global enabled
enabled = _enabled
def _get_enabled() -> bool:
return enabled
def _set_strategy(_strategy: str) -> None:
if not is_available():
raise ValueError(
f"opt_einsum不可用,因此将`strategy`设置为{_strategy}将没有意义。torch.einsum将绕过路径计算,简单地从左到右收缩。请安装opt_einsum或取消设置`strategy`。"
)
if not enabled:
raise ValueError(
f"opt_einsum未启用,因此将`strategy`设置为{_strategy}将没有意义。torch.einsum将绕过路径计算,简单地从左到右收缩。请将`enabled`设置为`True`,或者取消设置`strategy`。"
)
if _strategy not in ["auto", "greedy", "optimal"]:
raise ValueError(
f"`strategy`必须是以下之一:[auto, greedy, optimal],但当前为{_strategy}"
)
global strategy
strategy = _strategy
def _get_strategy() -> str:
return strategy
def set_flags(_enabled=None, _strategy=None):
orig_flags = (enabled, None if not is_available() else strategy)
if _enabled is not None:
_set_enabled(_enabled)
if _strategy is not None:
_set_strategy(_strategy)
return orig_flags
@contextmanager
def flags(enabled=None, strategy=None):
with __allow_nonbracketed_mutation():
orig_flags = set_flags(enabled, strategy)
try:
yield
finally:
# 恢复之前的值
with __allow_nonbracketed_mutation():
set_flags(*orig_flags)
# 这里的魔法是为了让我们能够拦截类似这样的代码:
#
# torch.backends.opt_einsum.enabled = True
class OptEinsumModule(PropModule):
def __init__(self, m, name):
super().__init__(m, name)
global enabled
enabled = ContextProp(_get_enabled, _set_enabled)
global strategy
strategy = None
if is_available():
strategy = ContextProp(_get_strategy, _set_strategy)
# 这是sys.modules替换技巧,参见
# https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)
enabled = True if is_available() else False
strategy = "auto" if is_available() else None