Shortcuts

torch.ao.quantization.quantize 的源代码

import copy
import itertools
import warnings

import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
from torch.ao.nn.intrinsic import _FusedModule

from torch.ao.quantization.quantization_mappings import (
    get_default_dynamic_quant_module_mappings,
    get_default_static_quant_module_mappings,
    get_default_static_quant_reference_module_mappings,
    get_default_qat_module_mappings,
    get_default_qconfig_propagation_list,
    no_observer_set,
    _has_special_act_post_process,
    _get_special_act_post_process,
)
from .utils import get_qparam_dict, has_no_children_ignoring_parametrizations
from torch.ao.quantization.stubs import DeQuantStub, QuantWrapper
from torch.ao.quantization.qconfig import (
    _add_module_to_qconfig_obs_ctr,
    default_dynamic_qconfig,
    float16_dynamic_qconfig,
    float_qparams_weight_only_qconfig,
    float_qparams_weight_only_qconfig_4bit,
    _activation_is_memoryless)
from torch.nn.utils.parametrize import type_before_parametrizations
from torch.ao.quantization.observer import _is_activation_post_process

# TODO: remove this public alias once backward compatibility (BC) is no longer required to avoid a SEV
from torch.ao.quantization.observer import (   # noqa: F401
    _is_activation_post_process as is_activation_post_process
)

# Public API of this module (the names exported by ``from ... import *``).
# Kept as a plain list literal so static analysers can read it.
__all__ = [
    "get_default_custom_config_dict",
    "propagate_qconfig_",
    "add_quant_dequant",
    "prepare",
    "quantize",
    "quantize_dynamic",
    "prepare_qat",
    "quantize_qat",
    "convert",
    "swap_module",
]

# Default custom-module mappings for eager-mode quantization.
#   'float_to_observed_custom_module_class':   float module class -> observed
#                                              (quantizable) module class
#   'observed_to_quantized_custom_module_class': observed module class ->
#                                              quantized module class
# NOTE(review): presumably consumed by prepare()/convert() when the caller
# passes no custom config dict; those functions are not visible here.
_DEFAULT_CUSTOM_CONFIG_DICT = {
    'float_to_observed_custom_module_class': {
        nn.LSTM: nn.quantizable.LSTM,
        nn.MultiheadAttention: nn.quantizable.MultiheadAttention,
    },
    'observed_to_quantized_custom_module_class': {
        nn.quantizable.LSTM: nn.quantized.LSTM,
        nn.quantizable.MultiheadAttention: nn.quantized.MultiheadAttention,
    }
}

def get_default_custom_config_dict():
    r"""Return the default custom config dict for eager-mode quantization.

    The result maps ``'float_to_observed_custom_module_class'`` and
    ``'observed_to_quantized_custom_module_class'`` to dicts that map a module
    class to its replacement class.

    A fresh deep copy is returned on every call. Callers may therefore add or
    remove entries without mutating the module-level defaults that later
    callers see. ``copy.deepcopy`` treats classes as atomic, so the module
    classes used as keys and values are the original class objects.
    """
    return copy.deepcopy(_DEFAULT_CUSTOM_CONFIG_DICT)

def _propagate_qconfig_helper(module, qconfig_dict,
                              qconfig_parent=None, prefix='', prepare_custom_config_dict=None):
    r"""Recursive helper for `propagate_qconfig_`.

    Resolves the qconfig for ``module``, validates it with
    ``_assert_valid_qconfig``, and stores the result of
    ``_add_module_to_qconfig_obs_ctr`` as ``module.qconfig``. It then recurses
    into the children, passing that stored qconfig down as their parent
    qconfig.

    The qconfig for ``module`` is chosen with increasing precedence:

    1. ``qconfig_parent`` (inherited from the parent module);
    2. ``qconfig_dict[type_before_parametrizations(module)]``;
    3. ``qconfig_dict[prefix]`` (the module's fully qualified name);
    4. an existing ``module.qconfig`` attribute, if one is already set
       (even when it is ``None``).

    Args:
        module: input module
        qconfig_dict: dictionary mapping fully qualified submodule names
            and/or module types to a quantization configuration
        qconfig_parent: quantization config of the parent module; the fallback
            when nothing more specific applies to the current module
        prefix: fully qualified name of the current module ('' for the root),
            used as a key into ``qconfig_dict``
        prepare_custom_config_dict: dictionary for custom module handling.
            A child is skipped, together with its whole subtree, when:

            * its fully qualified name is listed under
              ``"non_traceable_module_name"``, or
            * its class is listed under ``"non_traceable_module_class"``.

            Skipped modules get no ``qconfig`` attribute.
            See :func:`~torch.ao.quantization.prepare_fx`.

    Return:
        None, the module is modified in place with ``qconfig`` attached.
    """
    module_qconfig = qconfig_dict.get(type_before_parametrizations(module), qconfig_parent)
    module_qconfig = qconfig_dict.get(prefix, module_qconfig)
    module_qconfig = getattr(module, 'qconfig', module_qconfig)

    torch.ao.quantization.qconfig._assert_valid_qconfig(module_qconfig, module)

    qconfig_with_device_check = _add_module_to_qconfig_obs_ctr(module_qconfig, module)
    module.qconfig = qconfig_with_device_check

    # Read the non-traceable lists once instead of once per child.
    if prepare_custom_config_dict is None:
        non_traceable_names = ()
        non_traceable_classes = ()
    else:
        non_traceable_names = prepare_custom_config_dict.get("non_traceable_module_name", [])
        non_traceable_classes = prepare_custom_config_dict.get("non_traceable_module_class", [])

    for name, child in module.named_children():
        module_prefix = prefix + '.' + name if prefix else name
        # Do not propagate qconfig into non-traceable children.
        # - Names are matched against the fully qualified name. That is
        #   identical to `name` for direct children of the root, and matches
        #   how `prefix` keys `qconfig_dict`.
        # - Classes are matched by their pre-parametrization type, so
        #   parametrized modules still match.
        if (module_prefix in non_traceable_names
                or type_before_parametrizations(child) in non_traceable_classes):
            continue
        # Forward prepare_custom_config_dict so the non-traceable check also
        # applies below the first level of the hierarchy.
        _propagate_qconfig_helper(
            child, qconfig_dict, qconfig_with_device_check, module_prefix,
            prepare_custom_config_dict,
        )

[docs]def propagate_qconfig_(module, qconfig_dict=None, prepare_custom_config_dict=None): r"""通过模块层次结构传播 qconfig 并在每个叶子模块上附加 `qconfig` 属性 参数: module: 输入模块 qconfig_dict: 从子模块名称或类型映射到量化配置的字典,qconfig 适用于给定模块的所有子模块,除非子模块已经指定了 qconfig prepare_custom_config_dict: 自定义模块处理的配置字典 参见