torch.fx.node 的源代码
# mypy: ignore-errors
# 节点表示在我们的操作符图中一个值的定义。
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
from ._compatibility import compatibility
from .immutable_collections import immutable_dict, immutable_list
import torch
import builtins
import types
import inspect
import warnings
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
from .._ops import ops as _ops
if TYPE_CHECKING:
from .graph import Graph
# Public names exported by this module.
__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"]
# Leaf (non-container) value types that may appear directly as an argument.
# NOTE: the member order here determines the order of ``base_types`` below.
BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
                          torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload]
# Tuple of the concrete classes listed in BaseArgumentTypes, taken from the
# Union's ``__args__``; presumably used for isinstance() checks elsewhere in
# this module -- confirm against callers.
base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]
# A call target: either a callable or a string.
Target = Union[Callable[..., Any], str]
# Any value that can be passed as an argument (None is allowed via Optional).
# Containers are typed loosely because mypy cannot express the recursion.
Argument = Optional[Union[
    Tuple[Any, ...],  # really Argument, but mypy cannot express recursive types
    List[Any],  # really Argument
    Dict[str, Any],  # really Argument
    slice,  # Slice[Argument, Argument, Argument], but slice is not a generic type in typing
    range,
    'Node',  # string forward reference; Node is presumably defined later in this module (it is listed in __all__)
    BaseArgumentTypes
]]
# Callables that change global grad-mode / autocast state. Judging by the name,
# their calls must be preserved in pre-dispatch graphs -- NOTE(review): confirm
# against the passes that consult this set. They are also merged into
# _side_effectful_functions below.
_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
    torch._C._set_grad_enabled,
    torch.amp._enter_autocast,
    torch.amp._exit_autocast,
}
# Registry of callables treated as side-effectful: assertions, in-place copy,
# symbolic range constraints, profiler record-function enter/exit, inductor
# in-place ops, plus everything in the pre-dispatch set above. It is mutable
# and grows at runtime through has_side_effect(). Presumably consulted so that
# calls to these targets are not dropped even when their results are unused
# -- confirm against callers.
_side_effectful_functions: Set[Callable] = {
    torch._assert,
    torch._assert_async,
    _ops.aten._assert_async.msg,
    _ops.aten._assert_scalar.default,
    _ops.aten.copy_.default,
    _ops.aten.sym_constrain_range.default,
    _ops.aten.sym_constrain_range_for_size.default,
    _ops.profiler._record_function_enter,
    _ops.profiler._record_function_enter_new,
    _ops.profiler._record_function_exit,
    _ops.inductor.accumulate_grad_.default,
    _ops.inductor.resize_storage_bytes_.default,
} | _side_effectful_need_to_be_preserved_pre_dispatch
@compatibility(is_backward_compatible=False)
def has_side_effect(fn: Callable) -> Callable:
    """Register ``fn`` as a side-effectful function and return it unchanged.

    ``fn`` is added to the module-level ``_side_effectful_functions`` set.
    Because the same object is returned, this works both as a plain call and
    as a decorator::

        @has_side_effect
        def my_op(...): ...

    Args:
        fn: the callable to mark as having side effects.

    Returns:
        ``fn`` itself (identity). The previous ``-> None`` annotation was
        wrong; it went unnoticed only because this file disables mypy.
    """
    _side_effectful_functions.add(fn)
    return fn
# Workaround (WAR) for PyTorch 1.5, where some callables have __module__ set to None; already fixed on the main branch.
def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
name = orig_method.__name__
module = orig_method.__module__
if module is not None:
return module
for guess in [torch, torch.nn.functional]:
if getattr(guess, name, None) is orig_method:
return guess.__name__
raise RuntimeError(f'无法找到模块为 {orig_method}')
# 从CPython typing模块借用
# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
def _type_repr(obj):
"""返回对象的repr(),特殊处理类型(内部帮助函数)。
如果obj是类型,我们返回一个比默认的type.__repr__更短的版本,基于模块和限定名称,这通常足以唯一标识一个类型。对于其他所有内容,我们回退到repr(obj)。
"""
if isinstance(obj, type):
if obj.__module__ <span class="