torch.ao.quantization.fuse_modules 的源代码
```python
import copy import torch.nn as nn from torch.ao.quantization.fuser_method_mappings import get_fuser_method # 为了向后兼容 from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn # noqa: F401 from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu # noqa: F401 from torch.nn.utils.parametrize import type_before_parametrizations from typing import List, Optional __all__ = [ "fuse_known_modules", "fuse_modules", "fuse_modules_qat", ] # getattr的泛化 def _get_module(model, submodule_key): tokens = submodule_key.split('.') cur_mod = model for s in tokens: cur_mod = getattr(cur_mod, s) return cur_mod # setattr的泛化 def _set_module(model, submodule_key, module): tokens = submodule_key.split('.') sub_tokens = tokens[:-1] cur_mod = model for s in sub_tokens: cur_mod = getattr(cur_mod, s) setattr(cur_mod, tokens[-1], module) def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None): r"""返回已知融合模块的列表。 返回一个模块列表,该列表融合了输入模块列表中指定的操作。 仅融合以下模块序列: conv, bn conv, bn, relu conv, relu linear, bn linear, relu 对于这些序列,输出模块列表中的第一个元素执行融合操作。其余元素设置为nn.Identity() """ types = tuple(type_before_parametrizations(m) for m in mod_list) fuser_method = get_fuser_method(types, additional_fuser_method_mapping) if fuser_method is None: raise NotImplementedError(f"无法融合模块: {types}") new_mod : List[Optional[nn.Module]] = [None] * len(mod_list) fused = fuser_method(is_qat, *mod_list) # 注意:在以下两个for循环中未处理的forward hooks将在融合后丢失 # 将基础模块的pre forward hooks移动到融合后的模块 for pre_hook_fn in mod_list[0]._forward_pre_hooks.values(): fused.register_forward_pre_hook(pre_hook_fn) mod_list[0]._forward_pre_hooks.clear() # 将最后一个模块的post forward hooks移动到融合后的模块 for hook_fn in mod_list[-1]._forward_hooks.values(): fused.register_forward_hook(hook_fn) mod_list[-1]._forward_hooks.clear() new_mod[0] = fused for i in range(1, len(mod_list)): identity = nn.Identity() identity.training = mod_list[0].training new_mod[i] = identity return new_mod def _fuse_modules_helper(model, modules_to_fuse, is_qat, 
fuser_func=fuse_known_modules, fuse_custom_config_dict=None): if fuse_custom_config_dict is None: fuse_custom_config_dict = {} additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {}) mod_list = [] for item in modules_to_fuse: mod_list.append(_get_module(model, item)) # 融合模块列表 new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping) # 用融合后的模块列表替换原始模块列表 for i, item in enumerate(modules_to_fuse): _set_module(model<span