Shortcuts

torch.ao.quantization.quantize_fx 的源代码

from typing import Any, Dict, Optional, Tuple, Union
import warnings

import torch
import copy
from torch.fx import GraphModule
from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY
from .fx.tracer import QuantizationTracer
from .fx.tracer import (  # noqa: F401
    Scope,
    ScopeContextManager
)
from .fx.fuse import fuse  # noqa: F401
from .fx.prepare import prepare  # noqa: F401
from .fx.convert import convert
from .backend_config import (  # noqa: F401
    BackendConfig,
    get_tensorrt_backend_config,
)
from .fx.graph_module import ObservedGraphModule  # noqa: F401
from .fx.custom_config import (
    ConvertCustomConfig,
    FuseCustomConfig,
    PrepareCustomConfig,
)
from .fx.utils import get_custom_module_class_keys  # noqa: F401
from .fx.utils import get_skipped_module_name_and_classes
from .qconfig_mapping import QConfigMapping

def attach_preserved_attrs_to_model(
    model: Union[GraphModule, torch.nn.Module],
    preserved_attrs: Dict[str, Any],
) -> None:
    """ 将保留的属性存储到模型的.meta中,以便在深度复制期间保留这些属性
    """
    model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs)  # type: ignore[operator, index, assignment]
    # 在模型中设置保留的属性,以便用户可以像在调用fx图模式量化之前一样调用model.attr
    for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items():  # type: ignore[index, union-attr]
        setattr(model, attr_name, attr)

def _check_is_graph_module(model: torch.nn.Module) -> None:
    if not isinstance(model, GraphModule):
        raise ValueError(
            "输入模型必须是GraphModule,"
            + "获取的类型:"
            + str(type(model))
            + " 请确保遵循教程。"
        )

def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None:
    """ 如果图中的节点不存在meta字段,则将其附加到所有节点上,
    meta字段是一个存储节点一些元信息的字段,例如节点的输出数据类型和形状信息,
    只有在程序被make_fx捕获时才会存在(用于quantize_pt2e流程),如果
    程序被torch.fx符号追踪捕获,则此字段可能不存在,因此我们在这里添加它,
    以避免在各处检查此字段
    """
    for node in model.graph.nodes:
        if not hasattr(node, "meta"):
            node.meta = {}

def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
    r""" 将FloatFunctional替换为FXFloatFunctional
    """
    modules_to_swap = []
    for name, module in model.named_children():
        if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
            modules_to_swap.append(name)
        else:
            _swap_ff_with_fxff(module)

    for name in modules_to_swap:
        del model._modules[name]
        model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()


def _fuse_fx(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
    r""" 内部辅助函数,用于在量化准备中融合模块

    参数:
        model: 来自符号追踪的GraphModule对象 (torch.fx.symbolic_trace)
    """
    _check_is_graph_module(model)
    return fuse(
        model,