Shortcuts

torch.ao.quantization.stubs 的源代码


from torch import nn

[docs]class QuantStub(nn.Module): r"""量化存根模块,在校准之前,这与观察者相同, 它将在 `convert` 中被交换为 `nnq.Quantize`。 参数: qconfig: 张量的量化配置, 如果未提供 qconfig,我们将从父模块获取 qconfig """ def __init__(self, qconfig=None): super().__init__() if qconfig: self.qconfig = qconfig def forward(self, x): return x
[docs]class DeQuantStub(nn.Module): r"""反量化存根模块,在校准之前,这与恒等映射相同, 它将在 `convert` 中被交换为 `nnq.DeQuantize`。 参数: qconfig: 张量的量化配置, 如果未提供 qconfig,我们将从父模块获取 qconfig """ def __init__(self, qconfig=None): super().__init__() if qconfig: self.qconfig = qconfig def forward(self, x): return x
[docs]class QuantWrapper(nn.Module): r"""一个包装类,包装输入模块,添加 QuantStub 和 DeQuantStub 并在调用模块时围绕调用 quant 和 dequant 模块。 这是由 `quantization` 实用函数用于添加 quant 和 dequant 模块,在 `convert` 函数之前,`QuantStub` 只是一个观察者, 它观察输入张量,在 `convert` 之后,`QuantStub` 将被交换为 `nnq.Quantize`,它执行实际的量化。同样 适用于 `DeQuantStub`。 """ quant: QuantStub dequant: DeQuantStub module: nn.Module def __init__(self, module): super().__init__() qconfig = getattr(module, "qconfig", None) self.add_module('quant', QuantStub(qconfig)) self.add_module('dequant', DeQuantStub(qconfig)) self.add_module('module', module) self.train(module.training) def forward(self, X): X = self.quant(X) X = self.module(X) return self.dequant(X)
优云智算