Shortcuts

torch.ao.ns.fx.utils 的源代码

```python
import enum
import operator

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq

toq = torch.ops.quantized
from typing import Tuple, Callable, Dict, Set, List, Optional, Union

from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.quantization import (
    ObserverBase,
    FakeQuantizeBase,
)
from torch.ao.quantization.utils import getattr_from_fqn
from torch.ao.quantization.observer import _is_activation_post_process

from .ns_types import NSNodeTargetType, NSResultsType

# TODO(future PR): consider deleting this enum and using the torch types
# directly. This might be tricky because it is not a one to one mapping.
class NodeInputOrOutputType(enum.Enum):
    """Dtype category assigned to the input or the output of a graph node."""

    # torch.float
    FP32 = enum.auto()
    # torch.qint8 or torch.quint8
    INT8 = enum.auto()
    # torch.float16
    FP16 = enum.auto()
    # the input/output dtype could not be determined
    UNKNOWN = enum.auto()
    # TODO(future PR): while these functions can support multiple dtypes,
    #   for the purposes of numerical debugging we want to get the actual
    #   dtype used in the model. We will likely need some kind of dtype
    #   propagation to estimate this.
    # either torch.float, or torch.quint8, or torch.qint8
    FP32_OR_INT8 = enum.auto()
    # TODO(future PRs): dynamic quant, fake quant, etc


def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:

    # TODO(future PR): 清理这个
    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    if node.op == "call_function":
        if node.target in FUNS_IO_TYPE_FP32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        if node.target in FUNS_IO_TYPE_FP16:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        elif node.target in FUNS_IO_TYPE_INT8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)

    elif node.op == "call_module":
        assert node.op == "call_module"
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        is_known_fp32_or_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
       
优云智算