准备外汇¶
- class torch.ao.quantization.quantize_fx.prepare_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=None, _equalization_config=None, backend_config=None)[源代码]¶
准备一个用于训练后量化的模型
- Parameters
模型 (*) – torch.nn.Module 模型
qconfig_mapping (*) – 用于配置模型如何量化的QConfigMapping对象,详情请参见
QConfigMappingexample_inputs (*) – 模型前向函数的示例输入, 位置参数的元组(关键字参数也可以作为位置参数传递)
prepare_custom_config (*) – 量化工具的自定义配置。 参见
PrepareCustomConfig了解更多详情_equalization_config (*) – 用于指定如何在模型上执行均衡的配置
backend_config (*) – 配置指定如何在后端中量化操作符,包括如何观察操作符、支持的融合模式、如何插入量化/反量化操作、支持的数据类型等。详情请参见
BackendConfig
- Returns
带有观察者(由 qconfig_mapping 配置)的 GraphModule,准备进行校准
- Return type
示例:
import torch from torch.ao.quantization import get_default_qconfig_mapping from torch.ao.quantization.quantize_fx import prepare_fx class Submodule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(5, 5) def forward(self, x): x = self.linear(x) return x class M(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(5, 5) self.sub = Submodule() def forward(self, x): x = self.linear(x) x = self.sub(x) + x return x # 初始化一个浮点模型 float_model = M().eval() # 定义校准函数 def calibrate(model, data_loader): model.eval() with torch.no_grad(): for image, target in data_loader: model(image) # qconfig 是我们为特定操作插入观察者的配置 # qconfig = get_default_qconfig("fbgemm") # 自定义 qconfig 的示例: # qconfig = torch.ao.quantization.QConfig( # activation=MinMaxObserver.with_args(dtype=torch.qint8), # weight=MinMaxObserver.with_args(dtype=torch.qint8)) # `activation` 和 `weight` 是观察者模块的构造函数 # qconfig_mapping 是量化配置的集合,用户可以通过 qconfig_mapping 为模型中的每个操作(torch op 调用、函数调用、模块调用)设置 qconfig # 以下调用将获取最适合目标为 "fbgemm" 后端的模型的 qconfig_mapping qconfig_mapping = get_default_qconfig_mapping("fbgemm") # 我们可以通过不同的方式自定义 qconfig_mapping。 # 例如,设置全局 qconfig,这意味着我们将为模型中的所有操作使用相同的 qconfig,这可以被其他设置覆盖 # qconfig_mapping = QConfigMapping().set_global(qconfig) # 例如,使用特定的 qconfig 量化 linear 子模块 # qconfig_mapping = QConfigMapping().set_module_name("linear", qconfig) # 例如,使用特定的 qconfig 量化所有 nn.Linear 模块 # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) # 有关更完整的列表,请参阅 :class:`torch.ao.quantization.QConfigMapping` 的文档字符串 # 参数 # example_inputs 是一个输入元组,用于推断模型中输出的类型 # 目前它没有被使用,但请确保 model(*example_inputs) 可以运行 example_inputs = (torch.randn(1, 3, 224, 224),) # TODO: 在为 fbgemm 和 qnnpack 拆分 backend_config 后添加 backend_config # 例如 backend_config = get_default_backend_config("fbgemm") # `prepare_fx` 根据 qconfig_mapping 和 backend_config 在模型中插入观察者。如果 qconfig_mapping 中的操作配置在 backend_config 中受支持(意味着它被目标硬件支持),我们将根据 qconfig_mapping 插入观察者模块,否则 qconfig_mapping 中的配置将被忽略 # # 示例: # 在 qconfig_mapping 中,用户将 linear 模块设置为使用 quint8 进行激活量化和 qint8 进行权重量化: # qconfig = torch.ao.quantization.QConfig( # observer=MinMaxObserver.with_args(dtype=torch.quint8), # weight=MinMaxObserver.with-args(dtype=torch.qint8)) # 注意:当前的 qconfig api 不支持设置输出观察者,但我们可能会扩展此功能以支持这些更细粒度的控制 # # qconfig_mapping = QConfigMapping().set_object_type(torch.nn.Linear, qconfig) # 在 backend config 中,linear 模块也支持此配置: # weighted_int8_dtype_config = DTypeConfig( # input_dtype=torch.quint8, # output_dtype=torch.quint8, # weight_dtype=torch.qint8, # bias_type=torch.float) # linear_pattern_config = BackendPatternConfig(torch.nn.Linear) \ # .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT) \ # .add_dtype_config(weighted_int8_dtype_config) \ # ... # backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config) # `prepare_fx` 将检查 qconfig_mapping 中用户请求的设置是否被 backend_config 支持,并在模型中插入观察者和伪量化模块 prepared_model = prepare_fx(float_model, qconfig_mapping, example_inputs) # 运行校准 calibrate(prepared_model, sample_inference_data)