torch.nn.utils.parametrizations.spectral_norm¶
- torch.nn.utils.parametrizations.spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None)[源代码]¶
对给定模块中的参数应用谱归一化。
当应用于向量时,它简化为
谱归一化通过降低模型的Lipschitz常数,稳定了生成对抗网络(GANs)中判别器(批评者)的训练。是通过每次访问权重时执行一次幂方法来近似的。如果权重张量的维度大于2,则在幂迭代方法中将其重塑为2D以获得谱范数。
参见 用于生成对抗网络的光谱归一化。
注意
此函数使用
register_parametrization()
中的参数化功能实现。它是torch.nn.utils.spectral_norm()
的重新实现。注意
当注册此约束时,与最大奇异值相关的奇异向量是通过估计而非随机采样得到的。然后,每当以训练模式访问张量时,都会执行
n_power_iterations
次幂方法来更新这些向量。注意
如果_SpectralNorm模块,即module.parametrization.weight[idx],在移除时处于训练模式,它将执行另一次幂迭代。如果您希望避免这次迭代,请在移除之前将模块设置为评估模式。
- Parameters
- Returns
原始模块已使用新的参数化方法注册到指定的权重
- Return type
示例:
>>> snm = spectral_norm(nn.Linear(20, 40)) >>> snm 参数化线性( in_features=20, out_features=40, bias=True (参数化): ModuleDict( (weight): 参数化列表( (0): _SpectralNorm() ) ) ) >>> torch.linalg.matrix_norm(snm.weight, 2) 张量(1.0081, grad_fn=
)