Shortcuts

torch.linalg.lstsq

torch.linalg.lstsq(A, B, rcond=None, *, driver=None)

计算线性方程组的最小二乘问题的解。

K\mathbb{K}R\mathbb{R}C\mathbb{C}, 线性系统 AX=BAX = B最小二乘问题 定义为 AKm×n,BKm×kA \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k}

minXKn×kAXBF\min_{X \in \mathbb{K}^{n \times k}} \|AX - B\|_F

其中 F\|-\|_F 表示 Frobenius 范数。

支持float、double、cfloat和cdouble数据类型的输入。 还支持矩阵的批处理,如果输入是矩阵的批处理,则输出具有相同的批处理维度。

driver 选择将使用的后端函数。 对于CPU输入,有效值为 ‘gels’, ‘gelsy’, ‘gelsd’, ‘gelss’。 要在CPU上选择最佳驱动程序,请考虑:

  • 如果 A 是良态的(它的 条件数 不是太大),或者你不介意一些精度损失。

    • 对于一般矩阵:‘gelsy’(带旋转的QR分解)(默认)

    • 如果 A 是满秩的:‘gels’ (QR)

  • 如果 A 不是良态的。

    • ‘gelsd’(三对角矩阵约简和SVD)

    • 但如果你遇到内存问题:‘gelss’(完整SVD)。

对于CUDA输入,唯一有效的驱动程序是‘gels’,它假设A是满秩的。

参见这些驱动程序的完整描述

rcond 用于在 driver 为 (‘gelsy’, ‘gelsd’, ‘gelss’) 之一时确定矩阵 A 的有效秩。 在这种情况下,如果 σi\sigma_i 是按降序排列的 A 的奇异值, σi\sigma_i 将被向下舍入为零,如果 σircondσ1\sigma_i \leq \text{rcond} \cdot \sigma_1。 如果 rcond= None(默认),rcond 被设置为 A 的 dtype 的机器精度乘以 max(m, n)

此函数返回问题的解决方案以及一些额外的信息,这些信息以包含四个张量的命名元组形式提供 (solution, residuals, rank, singular_values)。对于形状分别为 (*, m, n)(*, m, k) 的输入 AB,它包含

  • : 最小二乘解。它的形状为 (*, n, k)

  • 残差:解的平方残差,即AXBF2\|AX - B\|_F^2。 它的形状等于A的批次维度。 当m > nA中的每个矩阵都是满秩时,它会计算, 否则,它是一个空张量。 如果A是一批矩阵,并且批次中的任何矩阵不是满秩, 则返回一个空张量。此行为可能会在未来的PyTorch版本中更改。

  • rank: 矩阵在 A 中的秩的张量。 它的形状等于 A 的批次维度。 当 driver 是 (‘gelsy’, ‘gelsd’, ‘gelss’) 之一时计算, 否则它是一个空张量。

  • singular_values: 矩阵的奇异值张量,位于 A 中。 其形状为 (*, min(m, n))。 当 driver 为 (‘gelsd’, ‘gelss’) 之一时计算, 否则为空张量。

注意

此函数以比分别执行计算更快且更数值稳定的方式计算 X = A.pinverse() @ B

警告

在未来的 PyTorch 版本中,rcond 的默认值可能会发生变化。 因此,建议使用固定值以避免潜在的破坏性更改。

Parameters
  • A (张量) – 形状为 (*, m, n) 的左操作数张量,其中 * 表示零个或多个批次维度。

  • B (张量) – 形状为 (*, m, k) 的右侧张量,其中 * 表示零个或多个批次维度。

  • rcond (float, 可选) – 用于确定A的有效秩。 如果rcond= Nonercond被设置为A的数据类型机器精度乘以max(m, n)。默认值:None

Keyword Arguments

驱动程序 (字符串, 可选) – 要使用的LAPACK/MAGMA方法的名称。 如果为,则对于CPU输入使用‘gelsy’,对于CUDA输入使用‘gels’。 默认值:

Returns

一个命名元组 (解, 残差, 秩, 奇异值)

示例:

>>> A = torch.randn(1,3,3)
>>> A
tensor([[[-1.0838,  0.0225,  0.2275],
     [ 0.2438,  0.3844,  0.5499],
     [ 0.1175, -0.9102,  2.0870]]])
>>> B = torch.randn(2,3,3)
>>> B
tensor([[[-0.6772,  0.7758,  0.5109],
     [-1.4382,  1.3769,  1.1818],
     [-0.3450,  0.0806,  0.3967]],
    [[-1.3994, -0.1521, -0.1473],
     [ 1.9194,  1.0458,  0.6705],
     [-1.1802, -0.9796,  1.4086]]])
>>> X = torch.linalg.lstsq(A, B).solution # A 被广播到形状 (2, 3, 3)
>>> torch.dist(X, torch.linalg.pinv(A) @ B)
tensor(1.5152e-06)

>>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values
>>> torch.dist(S, torch.linalg.svdvals(A))
tensor(2.3842e-07)

>>> A[:, 0].zero_()  # 减少 A 的秩
>>> rank = torch.linalg.lstsq(A, B).rank
>>> rank
tensor([2])
优云智算