Shortcuts

torch.linalg.pinv

torch.linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) 张量

计算矩阵的伪逆(Moore-Penrose逆)。

伪逆可以通过代数定义,但通过奇异值分解(SVD)来理解它在计算上更为方便。

支持输入 float、double、cfloat 和 cdouble 数据类型。 还支持矩阵的批处理,如果 A 是矩阵的批处理,则输出具有相同的批处理维度。

如果 hermitian= True,则假设 A 在复数情况下为厄米矩阵,在实数情况下为对称矩阵,但内部不会进行检查。相反,计算中仅使用矩阵的下三角部分。

奇异值(或当 hermitian= True 时的特征值的范数) 低于 max(atol,σ1rtol)\max(\text{atol}, \sigma_1 \cdot \text{rtol}) 阈值的值 在计算中被视为零并被丢弃, 其中 σ1\sigma_1 是最大的奇异值(或特征值)。

如果未指定 rtolA 是一个维度为 (m, n) 的矩阵, 相对容差将设置为 rtol=max(m,n)ε\text{rtol} = \max(m, n) \varepsilon 并且 ε\varepsilonA 的数据类型(dtype)的 epsilon 值(参见 finfo)。 如果未指定 rtolatol 被指定为大于零,则 rtol 将设置为零。

如果 atolrtol 是一个 torch.Tensor,它的形状必须可以广播到由 torch.linalg.svd() 返回的 A 的奇异值的形状。

注意

如果 hermitian= False,此函数使用 torch.linalg.svd(); 如果 hermitian= True,则使用 torch.linalg.eigh()。 对于 CUDA 输入,此函数会与 CPU 同步该设备。

注意

如果可能的话,考虑使用 torch.linalg.lstsq() 来将矩阵左乘伪逆,如下所示:

torch.linalg.lstsq(A, B).solution == A.pinv() @ B

在可能的情况下,始终建议使用 lstsq(),因为它更快且在数值上更稳定,而不是显式计算伪逆。

注意

此函数具有与NumPy兼容的变体 linalg.pinv(A, rcond, hermitian=False)。 然而,使用位置参数 rcond 已被弃用,取而代之的是 rtol

警告

此函数内部使用 torch.linalg.svd()(或 torch.linalg.eigh()hermitian= True),因此其导数与这些函数具有相同的问题。有关更多详细信息,请参阅 torch.linalg.svd()torch.linalg.eigh() 中的警告。

另请参阅

torch.linalg.inv() 计算方阵的逆。

torch.linalg.lstsq() 使用数值稳定的算法计算 A.pinv() @ B

Parameters
  • A (张量) – 形状为 (*, m, n) 的张量,其中 * 表示零个或多个批次维度。

  • rcond (float, Tensor, 可选) – [NumPy 兼容]. rtol 的别名。默认值: None

Keyword Arguments
  • atol (float, Tensor, 可选) – 绝对容差值。当为None时,视为零。 默认值: None

  • rtol (float, Tensor, 可选) – 相对容差值。参见上文以了解当 None 时的取值。 默认值: None

  • hermitian (bool, 可选) – 指示如果复数,A 是否为 Hermitian;如果实数,是否为对称。默认值:False

  • 输出 (张量, 可选) – 输出张量。如果为,则忽略。默认值:

示例:

>>> A = torch.randn(3, 5)
>>> A
tensor([[ 0.5495,  0.0979, -1.4092, -0.1128,  0.4132],
        [-1.1143, -0.3662,  0.3042,  1.6374, -0.9294],
        [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]])
>>> torch.linalg.pinv(A)
tensor([[ 0.0600, -0.1933, -0.2090],
        [-0.0903, -0.0817, -0.4752],
        [-0.7124, -0.1631, -0.2272],
        [ 0.1356,  0.3933, -0.5023],
        [-0.0308, -0.1725, -0.5216]])

>>> A = torch.randn(2, 6, 3)
>>> Apinv = torch.linalg.pinv(A)
>>> torch.dist(Apinv @ A, torch.eye(3))
tensor(8.5633e-07)

>>> A = torch.randn(3, 3, dtype=torch.complex64)
>>> A = A + A.T.conj()  # 创建一个厄米矩阵
>>> Apinv = torch.linalg.pinv(A, hermitian=True)
>>> torch.dist(Apinv @ A, torch.eye(3))
tensor(1.0830e-06)
优云智算