torch.linalg.svd¶
- torch.linalg.svd(A, full_matrices=True, *, driver=None, out=None)¶
计算矩阵的奇异值分解(SVD)。
设 为 或 , 矩阵的 完全奇异值分解 ,如果 k = min(m,n),则定义为
其中 , 是当 为复数时的共轭转置,以及当 为实数时的转置。 矩阵 , (以及因此 )在实数情况下是正交的,在复数情况下是酉的。
当 m > n(或 m < n)时,我们可以去掉最后的 m - n(或 n - m)列的 U(或 V)来形成 简化SVD:
其中 。 在这种情况下, 和 也具有正交列。
支持输入 float、double、cfloat 和 cdouble 数据类型。 还支持矩阵的批处理,如果
A是矩阵的批处理,则输出具有相同的批处理维度。返回的分解是一个命名元组 (U, S, Vh),它对应于 , , 上述内容。
奇异值按降序返回。
参数
full_matrices选择全(默认)和简化的SVD。在CUDA中使用cuSOLVER后端时,可以使用
driver关键字参数来选择用于计算SVD的算法。 选择驱动程序是在准确性和速度之间进行权衡。如果
A是良态的(它的 条件数 不是太大),或者你不介意一些精度损失。对于一般矩阵:‘gesvdj’(Jacobi 方法)
如果
A是高或宽的(m >> n 或 m << n):‘gesvda’(近似方法)
如果
A条件不好或精度相关:‘gesvd’(基于QR分解)
默认情况下(
driver= None),我们调用‘gesvdj’,如果失败,则回退到‘gesvd’。与numpy.linalg.svd的区别:
与numpy.linalg.svd不同,此函数始终返回三个张量的元组,并且不支持compute_uv参数。 请使用
torch.linalg.svdvals(),它仅计算奇异值, 而不是compute_uv=False。
注意
当
full_matrices= True 时,对于 U[…, :, min(m, n):] 和 Vh[…, min(m, n):, :] 的梯度将被忽略,因为这些向量可以是相应子空间的任意基。警告
返回的张量U和V不是唯一的,也不是关于
A连续的。由于这种非唯一性,不同的硬件和软件可能会计算出不同的奇异向量。这种非唯一性是由于在实数情况下,任何一对奇异向量 乘以 -1 或在复数情况下乘以 会产生另外两个 有效的矩阵奇异向量。 因此,损失函数不应依赖于这个 量, 因为它没有明确定义。 在计算此函数的梯度时,会检查复数输入。因此, 当输入是复数且位于CUDA设备上时,此函数的梯度计算会同步该设备与CPU。
警告
使用 U 或 Vh 计算的梯度只有在
A没有重复的奇异值时才是有限的。如果A是矩形的,则其奇异值中也不能包含零。 此外,如果任意两个奇异值之间的距离接近于零,梯度将会在数值上不稳定,因为它依赖于奇异值 通过计算 。 在矩形情况下,当A具有较小的奇异值时,梯度也会在数值上不稳定,因为它还依赖于计算 。另请参阅
torch.linalg.svdvals()仅计算奇异值。 与torch.linalg.svd()不同,svdvals()的梯度总是 数值稳定的。torch.linalg.eig()用于计算矩阵的另一种谱分解的函数。特征分解仅适用于方阵。torch.linalg.eigh()用于计算厄米矩阵和对称矩阵的特征值分解的(更快)函数。torch.linalg.qr()用于另一种(速度快得多)适用于一般矩阵的分解方法。- Parameters
- Keyword Arguments
- Returns
一个命名元组 (U, S, Vh),对应于 、、 如上所述。
S 将始终为实数值,即使
A是复数。它也将按降序排列。U 和 Vh 将具有与
A相同的 dtype。左/右奇异向量将分别由 U 的列和 Vh 的行给出。
示例:
>>> A = torch.randn(5, 3) >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) >>> U.shape, S.shape, Vh.shape (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3])) >>> torch.dist(A, U @ torch.diag(S) @ Vh) tensor(1.0486e-06) >>> U, S, Vh = torch.linalg.svd(A) >>> U.shape, S.shape, Vh.shape (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3])) >>> torch.dist(A, U[:, :3] @ torch.diag(S) @ Vh) tensor(1.0486e-06) >>> A = torch.randn(7, 5, 3) >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) >>> torch.dist(A, U @ torch.diag_embed(S) @ Vh) tensor(3.0957e-06)