Shortcuts

准备外汇

class torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=None, _equalization_config=None, backend_config=None)[源代码]

准备一个用于训练后量化的模型

Parameters
  • 模型 (*) – torch.nn.Module 模型

  • qconfig_mapping (*) – 用于配置模型如何量化的QConfigMapping对象,详情请参见QConfigMapping

  • example_inputs (*) – 模型前向函数的示例输入, 位置参数的元组(关键字参数也可以作为位置参数传递)

  • prepare_custom_config (*) – 量化工具的自定义配置。 参见 PrepareCustomConfig 了解更多详情

  • _equalization_config (*) – 用于指定如何在模型上执行均衡的配置

  • backend_config (*) – 配置指定如何在后端中量化操作符,包括如何观察操作符、支持的融合模式、如何插入量化/反量化操作、支持的数据类型等。详情请参见 BackendConfig

Returns

带有观察者(由 qconfig_mapping 配置)的 GraphModule,准备进行校准

Return type

图模块

示例:

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx

class Submodule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
    def forward(self, x):
        x = self.linear(x)
        return x

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
        self.sub = Submodule()

    def forward(self, x):
        x = self.linear(x)
        x = self.sub(x) + x
        return x

# 初始化一个浮点模型
float_model = M().eval()

# 定义校准函数
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

# qconfig 是我们为特定操作插入观察者的配置
# qconfig = get_default_qconfig("fbgemm")
# 自定义 qconfig 的示例:
# qconfig = torch.ao.quantization.QConfig(
#    activation=MinMaxObserver.with_args(dtype=torch.qint8),
#    weight=MinMaxObserver.with_args(dtype=torch.qint8))
# `activation` 和 `weight` 是观察者模块的构造函数

# qconfig_mapping 是量化配置的集合,用户可以通过 qconfig_mapping 为模型中的每个操作(torch op 调用、函数调用、模块调用)设置 qconfig
# 以下调用将获取最适合目标为 "fbgemm" 后端的模型的 qconfig_mapping
qconfig_mapping = get_default_qconfig_mapping("fbgemm")

# 我们可以通过不同的方式自定义 qconfig_mapping。
# 例如,设置全局 qconfig,这意味着我们将为模型中的所有操作使用相同的 qconfig,这可以被其他设置覆盖
# qconfig_mapping = QConfigMapping().set_global(qconfig)
# 例如,使用特定的 qconfig 量化 linear 子模块
# qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig)
# 例如,使用特定的 qconfig 量化所有 nn.Linear 模块
# qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
# 有关更完整的列表,请参阅 :class:`torch.ao.quantization.QConfigMapping` 的文档字符串
# 参数

# example_inputs 是一个输入元组,用于推断模型中输出的类型
# 目前它没有被使用,但请确保 model(*example_inputs) 可以运行
example_inputs = (torch.randn(1, 3, 224, 224),)

# TODO: 在为 fbgemm 和 qnnpack 拆分 backend_config 后添加 backend_config
# 例如 backend_config = get_default_backend_config("fbgemm")
# `prepare_fx` 根据 qconfig_mapping 和 backend_config 在模型中插入观察者。如果 qconfig_mapping 中的操作配置在 backend_config 中受支持(意味着它被目标硬件支持),我们将根据 qconfig_mapping 插入观察者模块,否则 qconfig_mapping 中的配置将被忽略
#
# 示例:
# 在 qconfig_mapping 中,用户将 linear 模块设置为使用 quint8 进行激活量化和 qint8 进行权重量化:
# qconfig = torch.ao.quantization.QConfig(
#     observer=MinMaxObserver.with_args(dtype=torch.quint8),
#     weight=MinMaxObserver.with-args(dtype=torch.qint8))
# 注意:当前的 qconfig api 不支持设置输出观察者,但我们可能会扩展此功能以支持这些更细粒度的控制
#
# qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig)
# 在 backend config 中,linear 模块也支持此配置:
# weighted_int8_dtype_config = DTypeConfig(
#   input_dtype=torch.quint8,
#   output_dtype=torch.quint8,
#   weight_dtype=torch.qint8,
#   bias_type=torch.float)

# linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \
#    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \
#    .add_dtype_config(weighted_int8_dtype_config) \
#    ...

# backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
# `prepare_fx` 将检查 qconfig_mapping 中用户请求的设置是否被 backend_config 支持,并在模型中插入观察者和伪量化模块
prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs)
# 运行校准
calibrate(prepared_model, sample_inference_data)
优云智算