Source code for torch.nn.utils.weight_norm

r"""来自 https://arxiv.org/abs/1602.07868 的权重归一化。"""
from torch.nn.parameter import Parameter, UninitializedParameter
from torch import _weight_norm, norm_except_dim
from typing import Any, TypeVar
import warnings
from ..modules import Module

__all__ = ['WeightNorm', 'weight_norm', 'remove_weight_norm']

class WeightNorm:
    name: str
    dim: int

    def __init__(self, name: str, dim: int) -> None:
        if dim is None:
            dim = -1
        self.name = name
        self.dim = dim

    # TODO Make return type more specific
    def compute_weight(self, module: Module) -> Any:
        g = getattr(module, self.name + '_g')
        v = getattr(module, self.name + '_v')
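        # Comment added for clarity: _weight_norm(v, g, dim) computes
        # w = g * v / ||v||, where the norm is taken over every dimension of
        # v except `dim`; it is a fused equivalent of
        # v * (g / norm_except_dim(v, 2, dim)).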
        return _weight_norm(v, g, self.dim)

    @staticmethod
    def apply(module, name: str, dim: int) -> 'WeightNorm':
        warnings.warn("torch.nn.utils.weight_norm 已被弃用,建议使用 torch.nn.utils.parametrizations.weight_norm。")

        for hook in module._forward_pre_hooks.values():
            if isinstance(hook, WeightNorm) and hook.name == name:
                raise RuntimeError(f"不能在同一个参数 {name} 上注册两个 weight_norm 钩子")

        if dim is None:
            dim = -1

        fn = WeightNorm(name, dim)

        weight = getattr(module, name)
        if isinstance(weight, UninitializedParameter):
            raise ValueError(
                'The module passed to `WeightNorm` can\'t have uninitialized parameters. '
                'Make sure to run a dummy forward pass before applying weight normalization')
        # remove w from the parameter list
        del module._parameters[name]

        # add g and v as new parameters and express w as g/||v|| * v
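        # Comment added for clarity: norm_except_dim keeps size-1 axes
        # everywhere except `dim`, so for a Linear weight of shape
        # (out_features, in_features) with dim=0, g has shape
        # (out_features, 1): one learnable magnitude per output row of v.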
        module.register_parameter(name + '_g', Parameter(norm_except_dim(weight, 2, dim).data))
        module.register_parameter(name + '_v', Parameter(weight.data))
        setattr(module, name, fn.compute_weight(module))

        # recompute weight before every forward()
        module.register_forward_pre_hook(fn)

        return fn

    def remove(self, module: Module) -> None:
        weight = self.compute_weight(module)
        delattr(module, self.name)
        del module._parameters[self.name + '_g']
        del module._parameters[self.name + '_v']
        setattr(module, self.name, Parameter(weight.data))

    def __call__(self, module: Module, inputs: Any) -> None:
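        # Comment added for clarity: this runs as the registered forward
        # pre-hook, refreshing `weight` from g and v on every forward so
        # gradients flow to both parameters.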
        setattr(module, self.name, self.compute_weight(module))


T_module = TypeVar('T_module', bound=Module)

def weight_norm(module: T_module, name: str = 'weight', dim: int = 0) -> T_module:
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction.
    """
    WeightNorm.apply(module, name, dim)
    return module
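
A minimal usage sketch (not part of the source above), based on the behavior the code implements; note the source itself warns that torch.nn.utils.parametrizations.weight_norm is the preferred replacement:

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm, remove_weight_norm

m = weight_norm(nn.Linear(20, 40), name='weight', dim=0)
print(m.weight_g.size())   # torch.Size([40, 1])  -- magnitude g, one per output row
print(m.weight_v.size())   # torch.Size([40, 20]) -- direction v, same shape as weight

out = m(torch.randn(2, 20))   # the forward pre-hook recomputes m.weight from g and v

remove_weight_norm(m)         # fuse g and v back into a single weight Parameter
print(m.weight.size())        # torch.Size([40, 20])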