
Source code for torch.ao.quantization.fake_quantize

"""实现用于执行伪量化的模块。"""

import torch
from torch.nn import Module
from torch.ao.quantization.observer import (
    MovingAverageMinMaxObserver,
    HistogramObserver,
    MovingAveragePerChannelMinMaxObserver,
    FixedQParamsObserver,
    default_fixed_qparams_range_0to1_observer,
    default_fixed_qparams_range_neg1to1_observer,
    _with_args,
)
import re
from abc import ABC, abstractmethod
from typing import Any, Tuple

__all__ = [
    "FakeQuantizeBase",
    "FakeQuantize",
    "FixedQParamsFakeQuantize",
    "FusedMovingAvgObsFakeQuantize",
    "disable_fake_quant",
    "disable_observer",
    "enable_fake_quant",
    "enable_observer",
    "default_fake_quant",
    "default_weight_fake_quant",
    "default_dynamic_fake_quant",
    "default_fixed_qparams_range_neg1to1_fake_quant",
    "default_fixed_qparams_range_0to1_fake_quant",
    "default_symmetric_fixed_qparams_fake_quant",
    "default_affine_fixed_qparams_fake_quant",
    "default_per_channel_weight_fake_quant",
    "default_embedding_fake_quant",
    "default_embedding_fake_quant_4bit",
    "default_histogram_fake_quant",
    "default_fused_act_fake_quant",
    "default_fused_wt_fake_quant",
    "default_fused_per_channel_wt_fake_quant",
    "fused_wt_fake_quant_range_neg_127_to_127",
    "fused_per_channel_wt_fake_quant_range_neg_127_to_127",
]

def _is_per_channel(qscheme: 'torch.qscheme') -> bool:
    return qscheme in [torch.per_channel_symmetric, torch.per_channel_affine, torch.per_channel_affine_float_qparams]

def _is_per_tensor(qscheme: 'torch.qscheme') -> bool:
    return qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]

def _is_symmetric_quant(qscheme: 'torch.qscheme') -> bool:
    return qscheme in [torch.per_tensor_symmetric, torch.per_channel_symmetric]

def _is_float_qparams(qscheme: 'torch.qscheme') -> bool:
    return qscheme in [torch.per_channel_affine_float_qparams, ]
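# Illustrative note (not part of the original source): the four predicates above simply
# partition the built-in qschemes into per-channel vs. per-tensor, symmetric, and
# float-qparams groups. A minimal doctest-style sketch:
#
#     >>> _is_per_tensor(torch.per_tensor_affine)
#     True
#     >>> _is_per_channel(torch.per_channel_symmetric)
#     True
#     >>> _is_symmetric_quant(torch.per_channel_symmetric)
#     True
#     >>> _is_float_qparams(torch.per_channel_affine_float_qparams)
#     True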

class FakeQuantizeBase(ABC, Module):
    r"""Base fake quantize module.

    Any fake quantize implementation should derive from this class.

    Concrete fake quantize modules should follow the same API. In forward, they will update
    the statistics of the observed Tensor and fake quantize the input. They should also provide a
    `calculate_qparams` function that computes the quantization parameters given
    the collected statistics.
    """

    fake_quant_enabled: torch.Tensor
    observer_enabled: torch.Tensor

    def __init__(self):
        """Set up fake_quant_enabled and observer_enabled."""
        super().__init__()
        # fake_quant_enabled and observer_enabled are buffers to support their
        # replication in DDP. Data type is uint8 because NCCL does not support
        # bool tensors.
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def calculate_qparams(self, **kwargs):
        pass

    @torch.jit.export
    def enable_fake_quant(self, enabled: bool = True) -> None:
        self.fake_quant_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_fake_quant(self):
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled: bool = True) -> None:
        self.observer_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_observer(self):
        self.enable_observer(False)

    @classmethod
    def with_args(cls, **kwargs):
        fake_quant_constructor = _with_args(cls, **kwargs)
        # need to assign the correct module to fake_quantize
        # constructors to satisfy public api checks
        fake_quant_constructor.__module__ = "torch.ao.quantization.fake_quantize"
        return fake_quant_constructor
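# Illustrative usage sketch (not part of the original source). FakeQuantize, a concrete
# subclass declared in __all__ and defined later in this module, follows the API above:
# forward() updates the observer statistics and fake-quantizes its input,
# calculate_qparams() derives scale/zero_point from those statistics, and the exported
# toggles freeze or bypass either stage.
#
#     >>> fq = FakeQuantize(observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255)
#     >>> y = fq(torch.randn(4, 8))          # stats updated, input fake-quantized
#     >>> scale, zero_point = fq.calculate_qparams()
#     >>> fq.disable_observer()              # freeze the collected statistics
#     >>> fq.disable_fake_quant()            # pass inputs through unchanged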