Source code for torch.ao.quantization.qconfig_mapping
from __future__ import annotations
from collections import OrderedDict
from typing import Any, Callable, Dict, Tuple, Union, List
import torch
from .fake_quantize import (
    default_weight_fake_quant,
    FixedQParamsFakeQuantize,
)
from .observer import (
    _PartialWrapper,
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    default_placeholder_observer,
    default_weight_observer,
)
from .qconfig import (
    default_reuse_input_qconfig,
    default_symmetric_qnnpack_qconfig,
    default_symmetric_qnnpack_qat_qconfig,
    get_default_qconfig,
    get_default_qat_qconfig,
    QConfig,
    QConfigAny,
    default_quint8_weight_qconfig
)

__all__ = [
    "get_default_qconfig_mapping",
    "get_default_qat_qconfig_mapping",
    "QConfigMapping",
]
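
# Illustrative sketch, not part of the original module: how the public API
# exported above is typically consumed with FX graph mode quantization. The
# helper name, the toy model, and the input shape are assumptions made for
# demonstration only.
def _example_default_qconfig_mapping_usage():
    from torch.ao.quantization import get_default_qconfig_mapping
    from torch.ao.quantization.quantize_fx import convert_fx, prepare_fx

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 3, 32, 32),)
    qconfig_mapping = get_default_qconfig_mapping("x86")
    prepared = prepare_fx(model, qconfig_mapping, example_inputs)
    prepared(*example_inputs)  # calibrate with representative data
    return convert_fx(prepared)
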
# TODO: replace all usages with these constants
_GLOBAL_DICT_KEY = ""
_OBJECT_TYPE_DICT_KEY = "object_type"
_MODULE_NAME_REGEX_DICT_KEY = "module_name_regex"
_MODULE_NAME_DICT_KEY = "module_name"
_MODULE_NAME_OBJECT_TYPE_ORDER_DICT_KEY = "module_name_object_type_order"
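
# Illustrative sketch, not part of the original module: the keys above mirror
# the dictionary layout consumed by QConfigMapping.from_dict() and produced by
# QConfigMapping.to_dict(). Below is a minimal example of that layout, assuming
# the default fbgemm qconfig; the module names and regex are hypothetical.
def _example_qconfig_dict():
    qconfig = get_default_qconfig("fbgemm")
    return {
        "": qconfig,                                     # _GLOBAL_DICT_KEY
        "object_type": [(torch.nn.Linear, qconfig)],     # _OBJECT_TYPE_DICT_KEY
        "module_name_regex": [("backbone.*", qconfig)],  # _MODULE_NAME_REGEX_DICT_KEY
        "module_name": [("head.fc", qconfig)],           # _MODULE_NAME_DICT_KEY
    }
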
# TODO: derive this map from the BackendConfig
_FIXED_QPARAMS_OP_TO_OBSERVER: Dict[Union[Callable, str], _PartialWrapper] = {
    torch.nn.Hardsigmoid: default_fixed_qparams_range_0to1_observer,
    torch.nn.functional.hardsigmoid: default_fixed_qparams_range_0to1_observer,
    "hardsigmoid": default_fixed_qparams_range_0to1_observer,
    "hardsigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Sigmoid: default_fixed_qparams_range_0to1_observer,
    torch.sigmoid: default_fixed_qparams_range_0to1_observer,
    "sigmoid": default_fixed_qparams_range_0to1_observer,
    "sigmoid_": default_fixed_qparams_range_0to1_observer,
    torch.nn.Softmax: default_fixed_qparams_range_0to1_observer,
    torch.nn.Tanh: default_fixed_qparams_range_neg1to1_observer,
    torch.tanh: default_fixed_qparams_range_neg1to1_observer,
    "tanh": default_fixed_qparams_range_neg1to1_observer,
    "tanh_": default_fixed_qparams_range_neg1to1_observer,
}
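
# Illustrative sketch, not part of the original module: the ops above have known
# output ranges (sigmoid/hardsigmoid/softmax in [0, 1], tanh in [-1, 1]), so
# their activations use observers with fixed quantization parameters rather
# than parameters estimated from calibration data. The helper name is assumed.
def _example_fixed_qparams_lookup():
    observer_ctr = _FIXED_QPARAMS_OP_TO_OBSERVER[torch.sigmoid]
    observer = observer_ctr()            # instantiate the partially applied observer
    return observer.calculate_qparams()  # fixed (scale, zero_point), independent of data
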
def _get_default_qconfig_mapping(is_qat: bool, backend: str, version: int) -> QConfigMapping:
    """
    Return the default QConfigMapping for the given quantization type and backend.
    """
    if is_qat:
        qconfig = get_default_qat_qconfig(backend, version)
    else:
        qconfig = get_default_qconfig(backend, version)
    default_weight = default_weight_fake_quant if is_qat else default_weight_observer

    # default_per_channel_weight_observer is not currently compatible with the fbgemm backend,
    # so we have to change the weight observer to default_weight_observer or another
    # supported per-tensor observer.
    # See https://github.com/pytorch/pytorch/issues/47535
    if backend in ("fbgemm", "x86"):
        qconfig_transpose = QConfig(activation=qconfig.activation, weight=default_weight)
    else:
        qconfig_transpose = qconfig

    # Currently layernorm only supports float weights.
    # We have to add this, otherwise there will be an extra quantize-dequantize pair.
    qconfig_layernorm = QConfig(activation=qconfig.activation, weight=default_placeholder_observer)

    qconfig_mapping = QConfigMapping() \
        .set_global(qconfig) \
        .set_object_type("reshape", default_reuse_input_qconfig) \
        .set_object_type(torch.nn.ConvTranspose1d, qconfig_transpose) \
        .set_object_type(torch.nn.ConvTranspose2d, qconfig_transpose) \
        .set_object_type(torch.nn.ConvTranspose3d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.conv_transpose1d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.conv_transpose2d, qconfig_transpose) \
        .set_object_type(torch.nn.functional.conv_transpose3d, qconfig_transpose) \