Shortcuts

prepare_qat_fx

class torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=None, backend_config=None)[源代码]

准备一个用于量化感知训练的模型

Parameters
  • 模型 (*) – torch.nn.Module 模型

  • qconfig_mapping (*) – 见 prepare_fx()

  • example_inputs (*) – 见 prepare_fx()

  • prepare_custom_config (*) – 见 prepare_fx()

  • backend_config (*) – 见 prepare_fx()

Returns

一个带有伪量化模块的GraphModule(由qconfig_mapping和backend_config配置),准备进行量化感知训练

Return type

图模块

示例:

import torch
from torch.ao.quantization import get_default_qat_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_qat_fx

class Submodule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
    def forward(self, x):
        x = self.linear(x)
        return x

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
        self.sub = Submodule()

    def forward(self, x):
        x = self.linear(x)
        x = self.sub(x) + x
        return x

# 初始化一个浮点模型
float_model = M().train()
# (可选,但首选) 从预训练模型加载权重
# float_model.load_weights(...)

# 定义用于量化感知训练的训练循环
def train_loop(model, train_data):
    model.train()
    for image, target in data_loader:
        ...

# qconfig 是我们为特定操作符插入观察者的配置
# qconfig = get_default_qconfig("fbgemm")
# 自定义 qconfig 的示例:
# qconfig = torch.ao.quantization.QConfig(
#    activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)),
#    weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)))
# `activation` 和 `weight` 是观察者模块的构造函数

# qconfig_mapping 是量化配置的集合,用户可以通过 qconfig_mapping 为模型中的每个操作符
# (torch op 调用、功能调用、模块调用) 设置 qconfig
# 以下调用将获取最适合 "fbgemm" 后端模型的 qconfig_mapping
qconfig_mapping = get_default_qat_qconfig("fbgemm")

# 我们可以通过不同的方式自定义 qconfig_mapping,请查看 :func:`~torch.ao.quantization.prepare_fx` 的文档字符串
# 以了解不同的配置方式

# example_inputs 是一个输入元组,用于推断模型中输出的类型
# 目前它没有被使用,但请确保 model(*example_inputs) 可以运行
example_inputs = (torch.randn(1, 3, 224, 224),)

# TODO: 在拆分 fbgemm 和 qnnpack 的后端配置后添加 backend_config
# 例如 backend_config = get_default_backend_config("fbgemm")
# `prepare_qat_fx` 根据 qconfig_mapping 和 backend_config 在模型中插入观察者,
# 如果 qconfig_mapping 中的操作符配置在 backend_config 中受支持(意味着它被目标硬件支持),
# 我们将根据 qconfig_mapping 插入 fake_quantize 模块,否则 qconfig_mapping 中的配置将被忽略
# 请参阅 :func:`~torch.ao.quantization.prepare_fx` 以详细了解 qconfig_mapping 如何与 backend_config 交互
prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs)
# 运行训练
train_loop(prepared_model, train_loop)
优云智算