Shortcuts

torch.nn.utils.parametrizations.spectral_norm

torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[源代码]

对给定模块中的参数应用谱归一化。

WSN=Wσ(W),σ(W)=maxh:h0Wh2h2\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

当应用于向量时,它简化为

xSN=xx2\mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

谱归一化通过降低模型的Lipschitz常数,稳定了生成对抗网络(GANs)中判别器(批评者)的训练。σ\sigma是通过每次访问权重时执行一次幂方法来近似的。如果权重张量的维度大于2,则在幂迭代方法中将其重塑为2D以获得谱范数。

参见 用于生成对抗网络的光谱归一化

注意

此函数使用register_parametrization()中的参数化功能实现。它是torch.nn.utils.spectral_norm()的重新实现。

注意

当注册此约束时,与最大奇异值相关的奇异向量是通过估计而非随机采样得到的。然后,每当以训练模式访问张量时,都会执行n_power_iterations幂方法来更新这些向量。

注意

如果_SpectralNorm模块,即module.parametrization.weight[idx],在移除时处于训练模式,它将执行另一次幂迭代。如果您希望避免这次迭代,请在移除之前将模块设置为评估模式。

Parameters
  • 模块 (nn.Module) – 包含的模块

  • 名称 (字符串, 可选) – 权重参数的名称。默认值: "weight"

  • n_power_iterations (整数, 可选) – 计算谱范数的幂迭代次数。默认值:1

  • eps (float, 可选) – 在计算范数时用于数值稳定的epsilon。默认值:1e-12

  • dim (int, 可选) – 对应于输出数量的维度。 默认值: 0,除了当模块是 ConvTranspose{1,2,3}d 的实例时,默认值为 1

Returns

原始模块已使用新的参数化方法注册到指定的权重

Return type

模块

示例:

>>> snm = spectral_norm(nn.Linear(20, 40))
>>> snm
参数化线性(
  in_features=20, out_features=40, bias=True
  (参数化): ModuleDict(
    (weight): 参数化列表(
      (0): _SpectralNorm()
    )
  )
)
>>> torch.linalg.matrix_norm(snm.weight, 2)
张量(1.0081, grad_fn=)