Shortcuts

torch.nn.utils.fusion 的源代码

from __future__ import annotations

import copy
from typing import Optional, Tuple, TypeVar

import torch

__all__ = ['fuse_conv_bn_eval', 'fuse_conv_bn_weights', 'fuse_linear_bn_eval', 'fuse_linear_bn_weights']

ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd")
LinearT = TypeVar("LinearT", bound="torch.nn.Linear")

[docs]def fuse_conv_bn_eval(conv: ConvT, bn: torch.nn.modules.batchnorm._BatchNorm, transpose: bool = False) -> ConvT: r"""将一个卷积模块和一个BatchNorm模块融合成一个单一的、新的卷积模块。 参数: conv (torch.nn.modules.conv._ConvNd): 一个卷积模块。 bn (torch.nn.modules.batchnorm._BatchNorm): 一个BatchNorm模块。 transpose (bool, 可选): 如果为True,转置卷积权重。默认为False。 返回: torch.nn.modules.conv._ConvNd: 融合后的卷积模块。 .. 注意:: 两个模块 ``conv`` 和 ``bn`` 都必须处于eval模式,并且 ``bn`` 必须已经计算了其运行缓冲区。 """ assert not (conv.training or bn.training), "仅在eval模式下进行融合!" fused_conv = copy.deepcopy(conv) assert bn.running_mean is not None and bn.running_var is not None fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights( fused_conv.weight, fused_conv.bias, bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose) return fused_conv
[docs]def fuse_conv_bn_weights( conv_w: torch.Tensor, conv_b: Optional[torch.Tensor], bn_rm: torch.Tensor, bn_rv: torch.Tensor, bn_eps: float, bn_w: Optional[torch.Tensor], bn_b: Optional[torch.Tensor], transpose: bool = False ) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]: r"""将卷积模块参数和BatchNorm模块参数融合成新的卷积模块参数。 参数: conv_w (torch.Tensor): 卷积权重。 conv_b (Optional[torch.Tensor]): 卷积偏置。 bn_rm (torch.Tensor): BatchNorm运行均值。 bn_rv (torch.Tensor): BatchNorm运行方差。 bn_eps (float): BatchNorm epsilon。 bn_w (Optional[torch.Tensor]): BatchNorm权重。 bn_b (Optional[torch.Tensor]): BatchNorm偏置。 transpose (bool, 可选): 如果为True,转置卷积权重。默认为False。 返回: Tuple[torch.nn.Parameter, torch.nn.Parameter]: 融合后的卷积权重和偏置。 """ conv_weight_dtype = conv_w.dtype conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype if conv_b is None: conv_b = torch.zeros_like(bn_rm) if bn_w is None: bn_w = torch.ones_like(bn_rm) if bn_b is None: bn_b = torch.zeros_like(bn_rm) bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps) if transpose: shape = [1, -1] + [1] * (len(conv_w.shape) - 2) else: shape = [-1, 1] + [1] * (len(conv_w.shape) - 2) fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to<span class="