Source code for torch.ao.quantization.qconfig
from collections import namedtuple
from typing import Optional, Any, Union, Type
import torch
import torch.nn as nn
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FakeQuantizeBase,
    default_fake_quant,
    default_dynamic_fake_quant,
    default_per_channel_weight_fake_quant,
    default_weight_fake_quant,
    default_fused_act_fake_quant,
    default_fused_wt_fake_quant,
    FusedMovingAvgObsFakeQuantize,
    default_fused_per_channel_wt_fake_quant,
    default_embedding_fake_quant,
    default_embedding_fake_quant_4bit,
    fused_wt_fake_quant_range_neg_127_to_127,
    fused_per_channel_wt_fake_quant_range_neg_127_to_127,
)
from .observer import (
    _PartialWrapper,
    MinMaxObserver,
    HistogramObserver,
    MovingAverageMinMaxObserver,
    NoopObserver,
    PlaceholderObserver,
    ReuseInputObserver,
    default_debug_observer,
    default_dynamic_quant_observer,
    default_float_qparams_observer,
    default_float_qparams_observer_4bit,
    default_observer,
    default_per_channel_weight_observer,
    default_placeholder_observer,
    default_weight_observer,
    weight_observer_range_neg_127_to_127,
    per_channel_weight_observer_range_neg_127_to_127,
    default_reuse_input_observer,
    ObserverBase,
)
import warnings
import copy
__all__ = [
"QConfig",
# TODO: deprecated, remove
"QConfigDynamic",
"default_qconfig",
"default_debug_qconfig",
"default_per_channel_qconfig",
"default_dynamic_qconfig",
"float16_dynamic_qconfig",
"float16_static_qconfig",
"per_channel_dynamic_qconfig",
"float_qparams_weight_only_qconfig",
"float_qparams_weight_only_qconfig_4bit",
"default_quint8_weight_qconfig",
"default_qat_qconfig",
"default_dynamic_qat_qconfig",
"default_weight_only_qconfig",
"default_activation_only_qconfig",
"default_qat_qconfig_v2",
"default_reuse_input_qconfig",
"default_symmetric_qnnpack_qconfig",
"default_per_channel_symmetric_qnnpack_qconfig",
"default_symmetric_qnnpack_qat_qconfig",
"default_per_channel_symmetric_qnnpack_qat_qconfig",
"default_embedding_qat_qconfig",
"default_embedding_qat_qconfig_4bit",
"get_default_qconfig",
"get_default_qat_qconfig",
"get_default_qconfig_dict",
"get_default_qat_qconfig_dict",
"QConfigAny",
"qconfig_equals",
]
class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
    """
    Describes how to quantize a layer or a part of the network by providing
    settings (observer classes) for activations and weights respectively.

    Note that QConfig needs to contain observer **classes** (like MinMaxObserver) or a callable
    that returns instances on invocation, not the concrete observer instances themselves.
    The quantization preparation function will instantiate observers multiple times for each of the layers.

    Observer classes usually have reasonable default arguments, but they can be overridden with the
    `with_args` method (which behaves like functools.partial)::

      my_qconfig = QConfig(
          activation=MinMaxObserver.with_args(dtype=torch.qint8),
          weight=default_observer.with_args(dtype=torch.qint8))
    """
    def __new__(cls, activation, weight):
        # catch common mistakes
        if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
            raise ValueError("QConfig received observer instance, please pass observer class instead. " +
                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        return super().__new__(cls, activation, weight)
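
# Illustrative eager-mode sketch: a custom QConfig built with ``with_args`` is
# attached to a model, and ``prepare`` instantiates the observers per layer.
# ``SmallNet`` is a hypothetical model used only for illustration.
#
#     import torch
#     from torch import nn
#     from torch.ao.quantization import (
#         QConfig, MinMaxObserver, default_per_channel_weight_observer,
#         QuantStub, DeQuantStub, prepare, convert,
#     )
#
#     class SmallNet(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.quant = QuantStub()
#             self.fc = nn.Linear(8, 4)
#             self.dequant = DeQuantStub()
#
#         def forward(self, x):
#             return self.dequant(self.fc(self.quant(x)))
#
#     model = SmallNet().eval()
#     model.qconfig = QConfig(
#         activation=MinMaxObserver.with_args(dtype=torch.quint8),
#         weight=default_per_channel_weight_observer,
#     )
#     prepared = prepare(model)          # observer instances are created here
#     prepared(torch.randn(2, 8))        # calibration pass
#     quantized = convert(prepared)
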
class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])):
    """
    Describes how to dynamically quantize a layer or a part of the network by providing
    settings (observer classes) for weights.

    It's like QConfig, but for dynamic quantization.

    Note that QConfigDynamic needs to contain observer **classes** (like MinMaxObserver) or a callable
    that returns instances on invocation, not the concrete observer instances themselves.
    The quantization function will instantiate observers multiple times for each of the layers.

    Observer classes usually have reasonable default arguments, but they can be overridden with the
    `with_args` method (which behaves like functools.partial)::

      my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))
    """
    def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
        # catch common mistakes
        if isinstance(weight, nn.Module):
            raise ValueError("QConfigDynamic received observer instance, please pass observer class instead. " +
                             "Use MyObserver.with_args(x=1) to override arguments to constructor if needed")
        warnings.warn("QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead")
        return super().__new__(cls, activation, weight)
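
# Illustrative sketch: QConfigDynamic is deprecated in favor of QConfig, and
# dynamic quantization is usually driven through the convenience entry point
# ``torch.ao.quantization.quantize_dynamic``. A minimal example, assuming a
# float nn.Linear-based model:
#
#     import torch
#     from torch import nn
#
#     float_model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4)).eval()
#     quantized_model = torch.ao.quantization.quantize_dynamic(
#         float_model, {nn.Linear}, dtype=torch.qint8
#     )
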
default_qconfig = QConfig(activation=default_observer,
                          weight=default_weight_observer)
"""
Default qconfig configuration.
"""
default_debug_qconfig = QConfig(weight=default_weight_observer,
                                activation=default_debug_observer)