torch.ao.quantization.stubs 的源代码
from torch import nn
[docs]class QuantStub(nn.Module):
r"""量化存根模块,在校准之前,这与观察者相同,
它将在 `convert` 中被交换为 `nnq.Quantize`。
参数:
qconfig: 张量的量化配置,
如果未提供 qconfig,我们将从父模块获取 qconfig
"""
def __init__(self, qconfig=None):
super().__init__()
if qconfig:
self.qconfig = qconfig
def forward(self, x):
return x
[docs]class DeQuantStub(nn.Module):
r"""反量化存根模块,在校准之前,这与恒等映射相同,
它将在 `convert` 中被交换为 `nnq.DeQuantize`。
参数:
qconfig: 张量的量化配置,
如果未提供 qconfig,我们将从父模块获取 qconfig
"""
def __init__(self, qconfig=None):
super().__init__()
if qconfig:
self.qconfig = qconfig
def forward(self, x):
return x
[docs]class QuantWrapper(nn.Module):
r"""一个包装类,包装输入模块,添加 QuantStub 和
DeQuantStub 并在调用模块时围绕调用 quant 和 dequant 模块。
这是由 `quantization` 实用函数用于添加 quant 和
dequant 模块,在 `convert` 函数之前,`QuantStub` 只是一个观察者,
它观察输入张量,在 `convert` 之后,`QuantStub`
将被交换为 `nnq.Quantize`,它执行实际的量化。同样
适用于 `DeQuantStub`。
"""
quant: QuantStub
dequant: DeQuantStub
module: nn.Module
def __init__(self, module):
super().__init__()
qconfig = getattr(module, "qconfig", None)
self.add_module('quant', QuantStub(qconfig))
self.add_module('dequant', DeQuantStub(qconfig))
self.add_module('module', module)
self.train(module.training)
def forward(self, X):
X = self.quant(X)
X = self.module(X)
return self.dequant(X)