Source code for torch.nn.utils.parametrizations
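Before the listing, the sketch below shows how the three public entry points declared in `__all__` (`orthogonal`, `spectral_norm`, `weight_norm`) are typically registered on a module. It is an illustrative usage sketch, not part of the source file; it assumes a PyTorch version that exports all three functions, and the layer sizes and the choice of `orthogonal_map="matrix_exp"` are arbitrary example values.

```python
# Illustrative usage sketch (not part of the source listing below).
# Assumes torch.nn.utils.parametrizations exports all three functions in __all__;
# layer sizes are arbitrary example values.
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal, spectral_norm, weight_norm

# Register an orthogonal parametrization on a Linear layer's weight.
# The weight is then recomputed from an unconstrained tensor on each access.
linear = orthogonal(nn.Linear(40, 20), name="weight", orthogonal_map="matrix_exp")

Q = linear.weight  # shape (20, 40): a wide matrix, so its rows are orthonormal
print(torch.allclose(Q @ Q.mH, torch.eye(20), atol=1e-5))  # True

# spectral_norm and weight_norm are registered in the same way.
sn_conv = spectral_norm(nn.Conv2d(3, 8, kernel_size=3))
wn_linear = weight_norm(nn.Linear(8, 4), name="weight", dim=0)
```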
```python
from enum import Enum, auto

import torch
from torch import Tensor
from ..utils import parametrize
from ..modules import Module
from .. import functional as F

from typing import Optional

__all__ = ['orthogonal', 'spectral_norm', 'weight_norm']


def _is_orthogonal(Q, eps=None):
    n, k = Q.size(-2), Q.size(-1)
    Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
    # A reasonable eps, but not too large
    eps = 10. * n * torch.finfo(Q.dtype).eps
    return torch.allclose(Q.mH @ Q, Id, atol=eps)


def _make_orthogonal(A):
    """Assume that A is a tall matrix.

    Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative.
    """
    X, tau = torch.geqrf(A)
    Q = torch.linalg.householder_product(X, tau)
    # The diagonal of X is the diagonal of R (which is always real), so we normalise by its signs
    Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
    return Q


class _OrthMaps(Enum):
    matrix_exp = auto()
    cayley = auto()
    householder = auto()


class _Orthogonal(Module):
    base: Tensor

    def __init__(self,
                 weight,
                 orthogonal_map: _OrthMaps,
                 *,
                 use_trivialization=True) -> None:
        super().__init__()

        # Note [Householder complex]
        # For complex tensors, it is not possible to compute the tensor `tau` necessary for
        # linalg.householder_product from the reflectors.
        # To see this, note that the reflectors have a shape like:
        # 0 0 0
        # * 0 0
        # * * 0
        # which, for complex matrices, gives n(n-1) (real) parameters. Now, you need n^2 parameters
        # to parametrize the unitary matrices. Saving tau on its own does not work either, because
        # not every pair `(A, tau)` gives a unitary matrix, meaning that if we optimised
        # them as independent tensors we would not maintain the constraint.
        # An equivalent reasoning holds for rectangular matrices.
        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
            raise ValueError("The householder parametrization does not support complex tensors.")

        self.shape = weight.shape
        self.orthogonal_map = orthogonal_map
        if use_trivialization:
            self.register_buffer("base", None)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        n, k = X.size(-2), X.size(-1)
        transposed = n < k
        if transposed:
            X = X.mT
            n, k = k, n
        # Here n > k and X is a tall matrix
        if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
            # We just need n x k - k(k-1)/2 parameters
            X = X.tril()
            if n != k:
                # Embed into a square matrix
                X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            A = X - X.mH
            # A is skew-symmetric (or skew-Hermitian)
            if