Shortcuts

torch.fx.node 的源代码

# mypy: ignore-errors

# Nodes represent a definition of a value in our graph of operators.
from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
from ._compatibility import compatibility
from .immutable_collections import immutable_dict, immutable_list
import torch
import builtins
import types
import inspect
import warnings
from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair
from .._ops import ops as _ops

if TYPE_CHECKING:
    from .graph import Graph

# Public names exported by this module.
__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"]

# Union of the basic (non-collection) value types that may appear directly as
# argument values; it is one member of the ``Argument`` alias below.
BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype,
                          torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload]
# The members of ``BaseArgumentTypes`` as a plain tuple, read from the Union's
# ``__args__`` (presumably for runtime isinstance checks; the consumers are
# not visible in this chunk).
base_types = BaseArgumentTypes.__args__  # type: ignore[attr-defined]

# What a node targets: either a callable or a string.
Target = Union[Callable[..., Any], str]

# Type of a value that may be passed as an argument. Container element types
# are spelled ``Any`` because mypy cannot express the recursive definition.
Argument = Optional[Union[
    Tuple[Any, ...],  # actually Argument, but mypy cannot represent recursive types
    List[Any],  # actually Argument
    Dict[str, Any],  # actually Argument
    slice,  # Slice[Argument, Argument, Argument], but slice is not a generic type in typing
    range,
    'Node',
    BaseArgumentTypes
]]

# Grad-mode and autocast enter/exit callables treated as side-effectful.
# The name indicates they must be preserved in pre-dispatch graphs
# (the consumer is not visible here — TODO confirm).
_side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = {
    torch._C._set_grad_enabled,
    torch.amp._enter_autocast,
    torch.amp._exit_autocast,
}

# Callables treated as having side effects: assertions, in-place copy,
# symbolic range constraints, profiler record-function hooks, and some
# inductor ops. The pre-dispatch set above is unioned in as well.
# ``has_side_effect`` adds further entries to this set at runtime.
_side_effectful_functions: Set[Callable] = {
    torch._assert,
    torch._assert_async,
    _ops.aten._assert_async.msg,
    _ops.aten._assert_scalar.default,
    _ops.aten.copy_.default,
    _ops.aten.sym_constrain_range.default,
    _ops.aten.sym_constrain_range_for_size.default,
    _ops.profiler._record_function_enter,
    _ops.profiler._record_function_enter_new,
    _ops.profiler._record_function_exit,
    _ops.inductor.accumulate_grad_.default,
    _ops.inductor.resize_storage_bytes_.default,
} | _side_effectful_need_to_be_preserved_pre_dispatch


@compatibility(is_backward_compatible=False)
def has_side_effect(fn: Callable) -> Callable:
    """Register ``fn`` as a side-effectful callable and return it unchanged.

    ``fn`` is added to the module-level ``_side_effectful_functions`` set.
    Because ``fn`` itself is returned, this can be used as a decorator.

    Args:
        fn: The callable to mark as having side effects.

    Returns:
        The same ``fn`` object, unmodified.
    """
    _side_effectful_functions.add(fn)
    # Return the argument so decorator usage keeps the original callable bound.
    return fn


# Workaround (WAR) for 1.5; this has been fixed on the main branch.
def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
    """Return the module name under which ``orig_method`` is defined.

    The callable's own ``__module__`` is used when it is not ``None``.
    Otherwise, ``torch`` and then ``torch.nn.functional`` are searched for an
    attribute with the same ``__name__`` that is the identical object. The
    name of the first namespace that matches is returned.

    Raises:
        RuntimeError: if ``__module__`` is ``None`` and neither namespace
            exposes ``orig_method`` under its name.
    """
    # Read ``__name__`` before ``__module__`` so objects lacking a name still
    # fail up front, exactly as before.
    method_name = orig_method.__name__
    owner = orig_method.__module__
    if owner is None:
        namespaces = (torch, torch.nn.functional)
        owner = next(
            (ns.__name__
             for ns in namespaces
             if getattr(ns, method_name, None) is orig_method),
            None,
        )
        if owner is None:
            raise RuntimeError(f'无法找到模块为 {orig_method}')
    return owner

# 从CPython typing模块借用
# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156
def _type_repr(obj):
    """返回对象的repr(),特殊处理类型(内部帮助函数)。
    如果obj是类型,我们返回一个比默认的type.__repr__更短的版本,基于模块和限定名称,这通常足以唯一标识一个类型。对于其他所有内容,我们回退到repr(obj)。
    """
    if isinstance(obj, type):
        if obj.__module__ <span class="