Source code for torch.ao.quantization.quantize
import copy
import itertools
import warnings
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.nn.intrinsic import _FusedModule
from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    no_observer_set,
    _has_special_act_post_process,
    _get_special_act_post_process,
)
from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.ao.quantization.qconfig import (
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    _activation_is_memoryless)
from torch.nn.utils.parametrize import type_before_parametrizations
from torch.ao.quantization.observer import _is_activation_post_process
# TODO: remove this once it is no longer needed to avoid a SEV
from torch.ao.quantization.observer import ( # noqa: F401
    _is_activation_post_process as is_activation_post_process
)
__all__ = [
"get_default_custom_config_dict",
"propagate_qconfig_",
"add_quant_dequant",
"prepare",
"quantize",
"quantize_dynamic",
"prepare_qat",
"quantize_qat",
"convert",
"swap_module",
]
_DEFAULT_CUSTOM_CONFIG_DICT = {
    'float_to_observed_custom_module_class': {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    'observed_to_quantized_custom_module_class': {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    }
}
def get_default_custom_config_dict():
r"""定义默认的自定义配置字典。
"""
return _DEFAULT_CUSTOM_CONFIG_DICT
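
# A minimal usage sketch, assuming only the default mappings defined above:
# in eager-mode quantization, nn.LSTM is swapped for its observed counterpart
# nn.quantizable.LSTM during prepare, which convert then swaps for
# nn.quantized.LSTM.
#
#     >>> custom_config = get_default_custom_config_dict()
#     >>> observed_cls = custom_config['float_to_observed_custom_module_class'][nn.LSTM]
#     >>> observed_cls is nn.quantizable.LSTM
#     True
#     >>> custom_config['observed_to_quantized_custom_module_class'][observed_cls] is nn.quantized.LSTM
#     True
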
def _propagate_qconfig_helper(module, qconfig_dict,
                              qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
    r"""This is a helper function for `propagate_qconfig_`

    Args:
        module: input module
        qconfig_dict: dictionary that maps from the name of a submodule to its
            quantization configuration
        qconfig_parent: quantization config of the parent module; we fall back
            to this config when no config is specified for the current module
        prefix: corresponding prefix of the current module, used as a key in
            qconfig_dict
        prepare_custom_config_dict: dictionary for custom handling of modules;
            see docs for :func:`~torch.ao.quantization.prepare_fx`

    Return:
        None, the module is modified in place with qconfig attached
    """
    module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)

    module.qconfig = qconfig_with_device_check
    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        # do not propagate qconfig to the child if the child is not traceable
        if prepare_custom_config_dict is None or not (
            name in prepare_custom_config_dict.get("non_traceable_module_name", [])
            or type(child) in prepare_custom_config_dict.get("non_traceable_module_class", [])
        ):
            _propagate_qconfig_helper(
                child, qconfig_dict, qconfig_with_device_check, module_prefix
            )
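
# A minimal sketch of the lookup order implemented above, assuming a toy model:
# a type-keyed entry in qconfig_dict is overridden by a name (prefix) entry,
# which is in turn overridden by a `qconfig` attribute already set on the module.
#
#     >>> model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
#     >>> qconfig_dict = {'': torch.ao.quantization.default_qconfig,
#     ...                 '1': None}  # leave the ReLU unconfigured by name
#     >>> propagate_qconfig_(model, qconfig_dict)
#     >>> model[0].qconfig is not None
#     True
#     >>> model[1].qconfig is None
#     True
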
def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None):
    r"""Propagate qconfig through the module hierarchy and assign a `qconfig`
    attribute on each leaf module

    Args:
        module: input module
        qconfig_dict: dictionary that maps from the name or type of a submodule
            to its quantization configuration; the qconfig applies to all
            submodules of the given module unless a qconfig is already
            specified for the submodule
        prepare_custom_config_dict: dictionary for custom handling of modules;
            see docs for :func:`~torch.ao.quantization.prepare_fx`