torch.ao.ns.fx.utils 的源代码
```python
import enum
import operator

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq

# Shorthand for the quantized op namespace used throughout this module.
toq = torch.ops.quantized

from typing import Tuple, Callable, Dict, Set, List, Optional, Union

from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.quantization import (
    ObserverBase,
    FakeQuantizeBase,
)
from torch.ao.quantization.utils import getattr_from_fqn
from torch.ao.quantization.observer import _is_activation_post_process

from .ns_types import NSNodeTargetType, NSResultsType


# TODO(future PR): consider deleting this enum and using the torch types
# directly. This may be tricky because it is not a one-to-one mapping.
class NodeInputOrOutputType(enum.Enum):
    """Coarse dtype category for a graph node's first input / output."""

    FP32 = enum.auto()  # torch.float
    INT8 = enum.auto()  # torch.qint8 or torch.quint8
    FP16 = enum.auto()  # torch.float16
    UNKNOWN = enum.auto()  # we cannot determine the input/output dtype
    # TODO(future PR): while these functions can support multiple dtypes,
    # for the purposes of numerical debugging we want to get the actual
    # dtype used in the model. We may need some form of dtype propagation
    # to estimate this.
    FP32_OR_INT8 = enum.auto()  # either torch.float, torch.quint8 or torch.qint8
    # TODO(future PRs): dynamic quantization, fake quantization, etc.


def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
    """Return (input_type, output_type) dtype categories for `node`.

    Dispatches on `node.op` and on membership of `node.target` in the
    io-type sets supplied via `node_type_to_io_type_map`. For ops whose
    dtype follows their input (FP32_OR_INT8), recurses into the node's
    first normalized input to inherit the producer's output type.

    NOTE(review): `logger_cls` is only passed through to the recursive
    call in the visible portion of this function — presumably used in the
    parts of the body not shown here; confirm against the full file.
    """

    # TODO(future PR): clean this up
    FUNS_IO_TYPE_FP32 = node_type_to_io_type_map["funs_io_type_fp32"]
    FUNS_IO_TYPE_FP16 = node_type_to_io_type_map["funs_io_type_fp16"]
    FUNS_IO_TYPE_INT8 = node_type_to_io_type_map["funs_io_type_int8"]
    FUNS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["funs_io_type_fp32_or_int8"]
    MODS_IO_TYPE_FP32 = node_type_to_io_type_map["mods_io_type_fp32"]
    MODS_IO_TYPE_INT8 = node_type_to_io_type_map["mods_io_type_int8"]
    MODS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]
    METHS_IO_TYPE_FP32_OR_INT8 = node_type_to_io_type_map["meths_io_type_fp32_or_int8"]

    if node.op == "call_function":
        if node.target in FUNS_IO_TYPE_FP32:
            return (NodeInputOrOutputType.FP32, NodeInputOrOutputType.FP32)
        if node.target in FUNS_IO_TYPE_FP16:
            return (NodeInputOrOutputType.FP16, NodeInputOrOutputType.FP16)
        elif node.target in FUNS_IO_TYPE_INT8:
            return (NodeInputOrOutputType.INT8, NodeInputOrOutputType.INT8)
        elif node.target in FUNS_IO_TYPE_FP32_OR_INT8:
            # This op's dtype follows its input: recurse into the producer
            # of the first argument and inherit its output type.
            # NOTE(review): `get_normalized_nth_input` is not defined or
            # imported in this chunk — presumably defined later in this
            # file; verify against the full module.
            first_arg = get_normalized_nth_input(node, gm, 0)
            assert isinstance(first_arg, Node)
            (
                _prev_node_input_type,
                prev_node_output_type,
            ) = get_node_first_input_and_output_type(
                first_arg, gm, logger_cls, node_type_to_io_type_map
            )
            return (prev_node_output_type, prev_node_output_type)
        else:
            return (NodeInputOrOutputType.UNKNOWN, NodeInputOrOutputType.UNKNOWN)
    elif node.op == "call_module":
        assert node.op == "call_module"
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        is_known_fp32_or_int8_input_module = any(
            isinstance(mod, target_type) for target_type in MODS_IO_TYPE_FP32_OR_INT8  # type: ignore[arg-type]
            # NOTE(review): the scraped page is truncated at this point;
            # the remainder of this function is not visible in this chunk.