Shortcuts

torch.nn.utils.parametrizations 的源代码

```html
from enum import Enum, auto

import torch
from torch import Tensor
from ..utils import parametrize
from ..modules import Module
from .. import functional as F

from typing import Optional

__all__ = ['orthogonal', 'spectral_norm', 'weight_norm']


def _is_orthogonal(Q, eps=None):
    n, k = Q.size(-2), Q.size(-1)
    Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
    # 一个合理的eps,但不要太大
    eps = 10. * n * torch.finfo(Q.dtype).eps
    return torch.allclose(Q.mH @ Q, Id, atol=eps)


def _make_orthogonal(A):
    """假设A是一个高矩阵。

    计算Q因子使得A = QR(A可能是复数)且diag(R)是实数且非负。
    """
    X, tau = torch.geqrf(A)
    Q = torch.linalg.householder_product(X, tau)
    # X的对角线是R的对角线(总是实数),所以我们通过其符号进行归一化
    Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
    return Q


class _OrthMaps(Enum):
    matrix_exp = auto()
    cayley = auto()
    householder = auto()


class _Orthogonal(Module):
    base: Tensor

    def __init__(self,
                 weight,
                 orthogonal_map: _OrthMaps,
                 *,
                 use_trivialization=True) -> None:
        super().__init__()

        # 注意 [Householder 复数]
        # 对于复数张量,无法计算张量`tau`,这是linalg.householder_product所需的反射器。
        # 要看到这一点,请注意反射器的形状如下:
        # 0 0 0
        # * 0 0
        # * * 0
        # 对于复数矩阵,给出n(n-1)(实数)参数。现在,你需要n^2个参数
        # 来参数化酉矩阵。单独保存tau不起作用,因为
        # 并非每对`(A, tau)`都给出酉矩阵,这意味着如果我们优化
        # 它们作为独立张量,我们将不会保持约束
        # 矩形矩阵的等效推理成立
        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
            raise ValueError("Householder参数化不支持复数张量。")

        self.shape = weight.shape
        self.orthogonal_map = orthogonal_map
        if use_trivialization:
            self.register_buffer("base", None)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        n, k = X.size(-2), X.size(-1)
        transposed = n < k
        if transposed:
            X = X.mT
            n, k = k, n
        # 这里n > k且X是一个高矩阵
        if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
            # 我们只需要n x k - k(k-1)/2个参数
            X = X.tril()
            if n != k:
                # 嵌入到一个方阵
                X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            A = X - X.mH
            # A是斜对称(或斜厄米特)
            if</
优云智算