Shortcuts

torch.ao.quantization.fuse_modules 的源代码

```python
import copy

import torch.nn as nn

from torch.ao.quantization.fuser_method_mappings import get_fuser_method
# 为了向后兼容
from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn  # noqa: F401
from torch.ao.quantization.fuser_method_mappings import fuse_conv_bn_relu  # noqa: F401
from torch.nn.utils.parametrize import type_before_parametrizations

from typing import List, Optional

# Public API of this module; the underscore-prefixed helpers are private.
__all__ = ["fuse_known_modules", "fuse_modules", "fuse_modules_qat"]

# getattr的泛化
def _get_module(model, submodule_key):
    tokens = submodule_key.split('.')
    cur_mod = model
    for s in tokens:
        cur_mod = getattr(cur_mod, s)
    return cur_mod

# setattr的泛化
def _set_module(model, submodule_key, module):
    tokens = submodule_key.split('.')
    sub_tokens = tokens[:-1]
    cur_mod = model
    for s in sub_tokens:
        cur_mod = getattr(cur_mod, s)

    setattr(cur_mod, tokens[-1], module)

def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
    r"""Return the replacement module list for a known fusable sequence.

    The types of ``mod_list`` (taken with ``type_before_parametrizations``) are
    used to look up a fuser via ``get_fuser_method``. That fuser is then called
    as ``fuser_method(is_qat, *mod_list)``.

    Module sequences documented as fusable:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, bn
    linear, relu

    ``additional_fuser_method_mapping`` is forwarded to ``get_fuser_method``
    (presumably to register further sequences; confirm against that helper).

    Forward pre-hooks of the first module and forward hooks of the last module
    are moved onto the fused module. Any other forward hooks are dropped.

    Returns:
        A list the same length as ``mod_list``. Element 0 is the fused module.
        Every other element is an ``nn.Identity()`` whose ``training`` flag
        copies ``mod_list[0].training``.

    Raises:
        NotImplementedError: if ``get_fuser_method`` returns ``None``.
    """
    types = tuple(type_before_parametrizations(m) for m in mod_list)
    fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError(f"无法融合模块: {types}")

    fused = fuser_method(is_qat, *mod_list)
    head_mod = mod_list[0]
    tail_mod = mod_list[-1]

    # NOTE: forward hooks not migrated by the two loops below are lost after fusion.
    # The first module's pre-forward hooks now fire on the fused module.
    for pre_hook in head_mod._forward_pre_hooks.values():
        fused.register_forward_pre_hook(pre_hook)
    head_mod._forward_pre_hooks.clear()
    # The last module's post-forward hooks now fire on the fused module.
    for post_hook in tail_mod._forward_hooks.values():
        fused.register_forward_hook(post_hook)
    tail_mod._forward_hooks.clear()

    replacements: List[Optional[nn.Module]] = [fused]
    for _ in range(len(mod_list) - 1):
        # Placeholders keep the list aligned with the original positions and
        # mirror the base module's train/eval mode.
        placeholder = nn.Identity()
        placeholder.training = head_mod.training
        replacements.append(placeholder)
    return replacements

def _fuse_modules_helper(model, modules_to_fuse, is_qat, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
    if fuse_custom_config_dict is None:
        fuse_custom_config_dict = {}
    additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
    mod_list = []
    for item in modules_to_fuse:
        mod_list.append(_get_module(model, item))

    # 融合模块列表
    new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)

    # 用融合后的模块列表替换原始模块列表
    for i, item in enumerate(modules_to_fuse):
        _set_module(model<span