融合模块¶
- class torch.ao.quantization.fuse_modules.fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=<function fuse_known_modules>, fuse_custom_config_dict=None)[源代码]¶
将模块列表合并为一个模块。
仅融合以下模块序列: conv, bn conv, bn, relu conv, relu linear, relu bn, relu 所有其他序列保持不变。 对于这些序列,将列表中的第一个项目替换为融合模块,并将剩余的模块替换为恒等映射。
- Parameters
model – 包含要融合模块的模型
modules_to_fuse – 要融合的模块名称列表。如果只有一个模块列表要融合,也可以是一个字符串列表。
inplace – 布尔值,指定是否在模型上就地进行融合,默认情况下返回一个新模型
fuser_func – 一个函数,接受一个模块列表并输出相同长度的融合模块列表。例如, fuser_func([convModule, BNModule]) 返回列表 [ConvBNModule, nn.Identity()] 默认为 torch.ao.quantization.fuse_known_modules
fuse_custom_config_dict – 融合的自定义配置
# fuse_custom_config_dict 的示例 fuse_custom_config_dict = { # 额外的 fuser_method 映射 "additional_fuser_method_mapping": { (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn }, }
- Returns
带有融合模块的模型。如果 inplace=True,则会创建一个新副本。
示例:
>>> m = M().eval() >>> # m 是一个包含以下子模块的模块 >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']] >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) >>> output = fused_m(input) >>> m = M().eval() >>> # 或者提供一个单独的模块列表进行融合 >>> modules_to_fuse = ['conv1', 'bn1', 'relu1'] >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse) >>> output = fused_m(input)