GroupNorm¶
- class torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)[源代码]¶
对小批量输入应用组归一化。
该层实现了论文中描述的操作 Group Normalization
输入通道被分成
num_groups组,每组包含num_channels / num_groups个通道。num_channels必须能被num_groups整除。每个组的均值和标准差分别计算。和是可学习的 每个通道的仿射变换参数向量,大小为num_channels,如果affine为True。 标准差通过有偏估计量计算,等价于 torch.var(input, unbiased=False)。该层在训练和评估模式下均使用从输入数据计算的统计数据。
- Parameters
- Shape:
输入: 其中
输出: (与输入形状相同)
示例:
>>> input = torch.randn(20, 6, 10, 10) >>> # 将6个通道分成3组 >>> m = nn.GroupNorm(3, 6) >>> # 将6个通道分成6组(等同于InstanceNorm) >>> m = nn.GroupNorm(6, 6) >>> # 将所有6个通道放入一个组(等同于LayerNorm) >>> m = nn.GroupNorm(1, 6) >>> # 激活模块 >>> output = m(input)