torch.nn.utils.fusion 的源代码
from __future__ import annotations
import copy
from typing import Optional, Tuple, TypeVar
import torch
__all__ = ['fuse_conv_bn_eval', 'fuse_conv_bn_weights', 'fuse_linear_bn_eval', 'fuse_linear_bn_weights']
ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd")
LinearT = TypeVar("LinearT", bound="torch.nn.Linear")
[docs]def fuse_conv_bn_eval(conv: ConvT, bn: torch.nn.modules.batchnorm._BatchNorm, transpose: bool = False) -> ConvT:
r"""将一个卷积模块和一个BatchNorm模块融合成一个单一的、新的卷积模块。
参数:
conv (torch.nn.modules.conv._ConvNd): 一个卷积模块。
bn (torch.nn.modules.batchnorm._BatchNorm): 一个BatchNorm模块。
transpose (bool, 可选): 如果为True,转置卷积权重。默认为False。
返回:
torch.nn.modules.conv._ConvNd: 融合后的卷积模块。
.. 注意::
两个模块 ``conv`` 和 ``bn`` 都必须处于eval模式,并且 ``bn`` 必须已经计算了其运行缓冲区。
"""
assert not (conv.training or bn.training), "仅在eval模式下进行融合!"
fused_conv = copy.deepcopy(conv)
assert bn.running_mean is not None and bn.running_var is not None
fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
fused_conv.weight, fused_conv.bias,
bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)
return fused_conv
[docs]def fuse_conv_bn_weights(
conv_w: torch.Tensor,
conv_b: Optional[torch.Tensor],
bn_rm: torch.Tensor,
bn_rv: torch.Tensor,
bn_eps: float,
bn_w: Optional[torch.Tensor],
bn_b: Optional[torch.Tensor],
transpose: bool = False
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
r"""将卷积模块参数和BatchNorm模块参数融合成新的卷积模块参数。
参数:
conv_w (torch.Tensor): 卷积权重。
conv_b (Optional[torch.Tensor]): 卷积偏置。
bn_rm (torch.Tensor): BatchNorm运行均值。
bn_rv (torch.Tensor): BatchNorm运行方差。
bn_eps (float): BatchNorm epsilon。
bn_w (Optional[torch.Tensor]): BatchNorm权重。
bn_b (Optional[torch.Tensor]): BatchNorm偏置。
transpose (bool, 可选): 如果为True,转置卷积权重。默认为False。
返回:
Tuple[torch.nn.Parameter, torch.nn.Parameter]: 融合后的卷积权重和偏置。
"""
conv_weight_dtype = conv_w.dtype
conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
if conv_b is None:
conv_b = torch.zeros_like(bn_rm)
if bn_w is None:
bn_w = torch.ones_like(bn_rm)
if bn_b is None:
bn_b = torch.zeros_like(bn_rm)
bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
if transpose:
shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
else:
shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)
fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to<span class="