# Shortcuts

# Source code for torch.ao.quantization.qconfig

from collections import namedtuple
from typing import Optional, Any, Union, Type

import torch
import torch.nn as nn
from torch.ao.quantization.fake_quantize import (
    FakeQuantize,
    FakeQuantizeBase,
    default_fake_quant,
    default_dynamic_fake_quant,
    default_per_channel_weight_fake_quant,
    default_weight_fake_quant,
    default_fused_act_fake_quant,
    default_fused_wt_fake_quant,
    FusedMovingAvgObsFakeQuantize,
    default_fused_per_channel_wt_fake_quant,
    default_embedding_fake_quant,
    default_embedding_fake_quant_4bit,
    fused_wt_fake_quant_range_neg_127_to_127,
    fused_per_channel_wt_fake_quant_range_neg_127_to_127,
)

from .observer import (
    _PartialWrapper,
    MinMaxObserver,
    HistogramObserver,
    MovingAverageMinMaxObserver,
    NoopObserver,
    PlaceholderObserver,
    ReuseInputObserver,
    default_debug_observer,
    default_dynamic_quant_observer,
    default_float_qparams_observer,
    default_float_qparams_observer_4bit,
    default_observer,
    default_per_channel_weight_observer,
    default_placeholder_observer,
    default_weight_observer,
    weight_observer_range_neg_127_to_127,
    per_channel_weight_observer_range_neg_127_to_127,
    default_reuse_input_observer,
    ObserverBase,
)
import warnings
import copy

# Public API of this module: the names exported by ``from ... import *``.
# Every entry is expected to be defined at module level in this file.
__all__ = [
    "QConfig",
    # TODO: deprecated, remove
    "QConfigDynamic",
    "default_qconfig",
    "default_debug_qconfig",
    "default_per_channel_qconfig",
    "default_dynamic_qconfig",
    "float16_dynamic_qconfig",
    "float16_static_qconfig",
    "per_channel_dynamic_qconfig",
    "float_qparams_weight_only_qconfig",
    "float_qparams_weight_only_qconfig_4bit",
    "default_quint8_weight_qconfig",
    "default_qat_qconfig",
    "default_dynamic_qat_qconfig",
    "default_weight_only_qconfig",
    "default_activation_only_qconfig",
    "default_qat_qconfig_v2",
    "default_reuse_input_qconfig",
    "default_symmetric_qnnpack_qconfig",
    "default_per_channel_symmetric_qnnpack_qconfig",
    "default_symmetric_qnnpack_qat_qconfig",
    "default_per_channel_symmetric_qnnpack_qat_qconfig",
    "default_embedding_qat_qconfig",
    "default_embedding_qat_qconfig_4bit",
    "get_default_qconfig",
    "get_default_qat_qconfig",
    "get_default_qconfig_dict",
    "get_default_qat_qconfig_dict",
    "QConfigAny",
    "qconfig_equals",

]

class QConfig(namedtuple('QConfig', ['activation', 'weight'])):
    """
    Describes how to quantize a layer or a part of the network by providing
    settings (observer classes) for activations and weights respectively.

    Note that QConfig needs to contain observer **classes** (like
    MinMaxObserver) or a callable that returns instances on invocation, not
    the concrete observer instances themselves. The quantization preparation
    function will instantiate observers multiple times for each of the layers.

    Observer classes usually have reasonable default arguments, but they can
    be overridden with the `with_args` method (which behaves like
    functools.partial)::

      my_qconfig = QConfig(
          activation=MinMaxObserver.with_args(dtype=torch.qint8),
          weight=default_observer.with_args(dtype=torch.qint8))

    Args:
        activation: observer class (or factory callable) used for activations.
        weight: observer class (or factory callable) used for weights.

    Raises:
        ValueError: if ``activation`` or ``weight`` is an ``nn.Module``
            instance rather than a class/factory.
    """

    def __new__(cls, activation, weight):
        # Catch a common mistake: passing an already-constructed observer
        # (an nn.Module instance) instead of the class or a factory.
        if isinstance(activation, nn.Module) or isinstance(weight, nn.Module):
            raise ValueError(
                "QConfig received observer instance, please pass observer class instead. "
                + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed"
            )
        return super().__new__(cls, activation, weight)
class QConfigDynamic(namedtuple('QConfigDynamic', ['activation', 'weight'])):
    """
    Describes how to dynamically quantize a layer or a part of the network by
    providing settings (observer classes) for weights.

    It is like QConfig, but for dynamic quantization.

    Note that QConfigDynamic needs to contain observer **classes** (like
    MinMaxObserver) or a callable that returns instances on invocation, not
    the concrete observer instances themselves. The quantization function
    will instantiate observers multiple times for each of the layers.

    Observer classes usually have reasonable default arguments, but they can
    be overridden with the `with_args` method (which behaves like
    functools.partial)::

      my_qconfig = QConfigDynamic(weight=default_observer.with_args(dtype=torch.qint8))

    Args:
        activation: observer class (or factory callable) for activations;
            defaults to ``torch.nn.Identity``.
        weight: observer class (or factory callable) for weights;
            defaults to ``torch.nn.Identity``.

    Raises:
        ValueError: if ``weight`` is an ``nn.Module`` instance rather than a
            class/factory. Only ``weight`` is checked.
    """

    def __new__(cls, activation=torch.nn.Identity, weight=torch.nn.Identity):
        # Catch a common mistake: passing an already-constructed observer
        # (an nn.Module instance) instead of the class or a factory.
        if isinstance(weight, nn.Module):
            raise ValueError(
                "QConfigDynamic received observer instance, please pass observer class instead. "
                + "Use MyObserver.with_args(x=1) to override arguments to constructor if needed"
            )
        # stacklevel=2 attributes the warning to the code constructing the
        # QConfigDynamic rather than to this __new__ method.
        warnings.warn(
            "QConfigDynamic is going to be deprecated in PyTorch 1.12, please use QConfig instead",
            stacklevel=2,
        )
        return super().__new__(cls, activation, weight)


default_qconfig = QConfig(activation=default_observer,
                          weight=default_weight_observer)
"""
Default qconfig configuration.
"""

# NOTE(review): SOURCE is truncated here in the middle of the next statement,
# which begins `default_debug_qconfig = QConfig(weight=<`. The remainder is not
# visible, so it is intentionally not reconstructed -- restore it from the
# upstream module.
# 优云智算