torch.ao.quantization.quantize_fx 的源代码
from typing import Any, Dict, Optional, Tuple, Union
import warnings
import torch
import copy
from torch.fx import GraphModule
from torch.fx.graph_module import _USER_PRESERVED_ATTRIBUTES_KEY
from .fx.tracer import QuantizationTracer
from .fx.tracer import ( # noqa: F401
Scope,
ScopeContextManager
)
from .fx.fuse import fuse # noqa: F401
from .fx.prepare import prepare # noqa: F401
from .fx.convert import convert
from .backend_config import ( # noqa: F401
BackendConfig,
get_tensorrt_backend_config,
)
from .fx.graph_module import ObservedGraphModule # noqa: F401
from .fx.custom_config import (
ConvertCustomConfig,
FuseCustomConfig,
PrepareCustomConfig,
)
from .fx.utils import get_custom_module_class_keys # noqa: F401
from .fx.utils import get_skipped_module_name_and_classes
from .qconfig_mapping import QConfigMapping
def attach_preserved_attrs_to_model(
model: Union[GraphModule, torch.nn.Module],
preserved_attrs: Dict[str, Any],
) -> None:
""" 将保留的属性存储到模型的.meta中,以便在深度复制期间保留这些属性
"""
model.meta[_USER_PRESERVED_ATTRIBUTES_KEY] = copy.copy(preserved_attrs) # type: ignore[operator, index, assignment]
# 在模型中设置保留的属性,以便用户可以像在调用fx图模式量化之前一样调用model.attr
for attr_name, attr in model.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): # type: ignore[index, union-attr]
setattr(model, attr_name, attr)
def _check_is_graph_module(model: torch.nn.Module) -> None:
if not isinstance(model, GraphModule):
raise ValueError(
"输入模型必须是GraphModule,"
+ "获取的类型:"
+ str(type(model))
+ " 请确保遵循教程。"
)
def _attach_meta_to_node_if_not_exist(model: GraphModule) -> None:
""" 如果图中的节点不存在meta字段,则将其附加到所有节点上,
meta字段是一个存储节点一些元信息的字段,例如节点的输出数据类型和形状信息,
只有在程序被make_fx捕获时才会存在(用于quantize_pt2e流程),如果
程序被torch.fx符号追踪捕获,则此字段可能不存在,因此我们在这里添加它,
以避免在各处检查此字段
"""
for node in model.graph.nodes:
if not hasattr(node, "meta"):
node.meta = {}
def _swap_ff_with_fxff(model: torch.nn.Module) -> None:
r""" 将FloatFunctional替换为FXFloatFunctional
"""
modules_to_swap = []
for name, module in model.named_children():
if isinstance(module, torch.ao.nn.quantized.FloatFunctional):
modules_to_swap.append(name)
else:
_swap_ff_with_fxff(module)
for name in modules_to_swap:
del model._modules[name]
model._modules[name] = torch.ao.nn.quantized.FXFloatFunctional()
def _fuse_fx(
model: GraphModule,
is_qat: bool,
fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
r""" 内部辅助函数,用于在量化准备中融合模块
参数:
model: 来自符号追踪的GraphModule对象 (torch.fx.symbolic_trace)
"""
_check_is_graph_module(model)
return fuse(
model,