prepare_qat_fx¶
- class torch.ao.quantization.quantize_fx.prepare_qat_fx(model, qconfig_mapping, example_inputs, prepare_custom_config=None, backend_config=None)[源代码]¶
- 准备一个用于量化感知训练的模型 - Parameters
- 模型 (*) – torch.nn.Module 模型 
- qconfig_mapping (*) – 见 - prepare_fx()
- example_inputs (*) – 见 - prepare_fx()
- prepare_custom_config (*) – 见 - prepare_fx()
- backend_config (*) – 见 - prepare_fx()
 
- Returns
- 一个带有伪量化模块的GraphModule(由qconfig_mapping和backend_config配置),准备进行量化感知训练 
- Return type
 - 示例: - import torch from torch.ao.quantization import get_default_qat_qconfig_mapping from torch.ao.quantization.quantize_fx import prepare_qat_fx class Submodule(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(5, 5) def forward(self, x): x = self.linear(x) return x class M(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(5, 5) self.sub = Submodule() def forward(self, x): x = self.linear(x) x = self.sub(x) + x return x # 初始化一个浮点模型 float_model = M().train() # (可选,但首选) 从预训练模型加载权重 # float_model.load_weights(...) # 定义用于量化感知训练的训练循环 def train_loop(model, train_data): model.train() for image, target in data_loader: ... # qconfig 是我们为特定操作符插入观察者的配置 # qconfig = get_default_qconfig("fbgemm") # 自定义 qconfig 的示例: # qconfig = torch.ao.quantization.QConfig( # activation=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8)), # weight=FakeQuantize.with_args(observer=MinMaxObserver.with_args(dtype=torch.qint8))) # `activation` 和 `weight` 是观察者模块的构造函数 # qconfig_mapping 是量化配置的集合,用户可以通过 qconfig_mapping 为模型中的每个操作符 # (torch op 调用、功能调用、模块调用) 设置 qconfig # 以下调用将获取最适合 "fbgemm" 后端模型的 qconfig_mapping qconfig_mapping = get_default_qat_qconfig("fbgemm") # 我们可以通过不同的方式自定义 qconfig_mapping,请查看 :func:`~torch.ao.quantization.prepare_fx` 的文档字符串 # 以了解不同的配置方式 # example_inputs 是一个输入元组,用于推断模型中输出的类型 # 目前它没有被使用,但请确保 model(*example_inputs) 可以运行 example_inputs = (torch.randn(1, 3, 224, 224),) # TODO: 在拆分 fbgemm 和 qnnpack 的后端配置后添加 backend_config # 例如 backend_config = get_default_backend_config("fbgemm") # `prepare_qat_fx` 根据 qconfig_mapping 和 backend_config 在模型中插入观察者, # 如果 qconfig_mapping 中的操作符配置在 backend_config 中受支持(意味着它被目标硬件支持), # 我们将根据 qconfig_mapping 插入 fake_quantize 模块,否则 qconfig_mapping 中的配置将被忽略 # 请参阅 :func:`~torch.ao.quantization.prepare_fx` 以详细了解 qconfig_mapping 如何与 backend_config 交互 prepared_model = prepare_qat_fx(float_model, qconfig_mapping, example_inputs) # 运行训练 train_loop(prepared_model, train_loop) 
