• Docs >
  • torch.nn >
  • torch.nn.utils.parametrize.register_parametrization
Shortcuts

torch.nn.utils.parametrize.register_parametrization

torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)[源代码]

在模块中为张量注册一个参数化。

假设为了简单起见,tensor_name="weight"。当访问module.weight时, 模块将返回参数化版本parametrization(module.weight)。 如果原始张量需要梯度,反向传播将通过parametrization进行微分,并且优化器将相应地更新张量。

模块首次注册参数化时,此函数将向模块添加一个属性 parametrizations,其类型为 ParametrizationList

张量 weight 上的参数化列表将可以在 module.parametrizations.weight 下访问。

原始张量将在 module.parametrizations.weight.original下访问。

可以通过在同一属性上注册多个参数化来连接参数化。

注册参数化的训练模式在注册时更新,以匹配宿主模块的训练模式

参数化和缓冲区有一个内置的缓存系统,可以使用上下文管理器 cached() 来激活。

一个 parametrization 可以选择性地实现一个具有签名的方法

def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

当第一个参数化注册以计算原始张量的初始值时,此方法会在未参数化的张量上调用。如果未实现此方法,原始张量将仅是未参数化的张量。

如果注册在张量上的所有参数化方法都实现了right_inverse,则可以通过赋值来初始化一个参数化张量,如下例所示。

第一个参数化可能依赖于多个输入。 这可以通过从right_inverse返回一个张量元组来实现 (参见下面的RankOne参数化的示例实现)。

在这种情况下,无约束的张量也位于 module.parametrizations.weight 下,名称为 original0, original1,…

注意

如果 unsafe=False(默认),forward 和 right_inverse 方法都将被调用一次以执行一系列一致性检查。如果 unsafe=True,则在张量未参数化的情况下将调用 right_inverse,否则不会调用任何方法。

注意

在大多数情况下,right_inverse 将是一个函数,使得 forward(right_inverse(X)) == X(参见 右逆)。 有时,当参数化不是满射时,放宽这一点可能是合理的。

警告

如果参数化依赖于多个输入,register_parametrization() 将会注册多个新参数。如果在优化器创建后注册了这种参数化,这些新参数需要手动添加到优化器中。请参阅torch.Optimizer.add_param_group()

Parameters
  • 模块 (nn.Module) – 要注册参数化的模块

  • tensor_name (str) – 要注册参数化的参数或缓冲区的名称

  • 参数化 (nn.Module) – 要注册的参数化

Keyword Arguments

不安全 (bool) – 一个布尔标志,表示参数化是否可能改变张量的dtype和形状。默认值:False 警告:在注册时不会检查参数化的一致性。启用此标志需自行承担风险。

Raises

ValueError – 如果模块没有名为 tensor_name 的参数或缓冲区

Return type

模块

示例

>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.utils.parametrize as P
>>>
>>> class Symmetric(nn.Module):
>>>     def forward(self, X):
>>>         return X.triu() + X.triu(1).T  # 返回一个对称矩阵
>>>
>>>     def right_inverse(self, A):
>>>         return A.triu()
>>>
>>> m = nn.Linear(5, 5)
>>> P.register_parametrization(m, "weight", Symmetric())
>>> print(torch.allclose(m.weight, m.weight.T))  # m.weight 现在是对称的
True
>>> A = torch.rand(5, 5)
>>> A = A + A.T   # A 现在是对称的
>>> m.weight = A  # 将权重初始化为对称矩阵 A
>>> print(torch.allclose(m.weight, A))
True
>>> class RankOne(nn.Module):
>>>     def forward(self, x, y):
>>>         # 形成一个秩为1的矩阵,通过两个向量相乘
>>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
>>>
>>>     def right_inverse(self, Z):
>>>         # 将Z投影到秩为1的矩阵上
>>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
>>>         # 返回重新缩放的奇异向量
>>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
>>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
>>>
>>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
>>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
1
优云智算