Shortcuts

torch.lu

torch.lu(*args, **kwargs)

计算矩阵或矩阵批次的LU分解。返回一个包含A的LU分解和枢轴的元组。如果pivot设置为True,则进行枢轴转换。

警告

torch.lu() 已被弃用,取而代之的是 torch.linalg.lu_factor()torch.linalg.lu_factor_ex()torch.lu() 将在未来的 PyTorch 版本中被移除。 LU, pivots, info = torch.lu(A, compute_pivots) 应替换为

LU, pivots = torch.linalg.lu_factor(A, compute_pivots)

LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True) 应该替换为

LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)

注意

  • 返回的每个矩阵的置换矩阵由大小为 min(A.shape[-2], A.shape[-1]) 的1索引向量表示。 pivots[i] == j 表示在算法的第 i 步中, 第 i 行与第 j-1 行进行了置换。

  • LU 分解在 pivot = False 时对 CPU 不可用,尝试这样做会抛出错误。然而,LU 分解在 pivot = False 时对 CUDA 可用。

  • 如果 get_infosTrue,此函数不会检查因式分解是否成功,因为因式分解的状态存在于返回元组的第三个元素中。

  • 在CUDA设备上,对于大小小于或等于32的方形矩阵批次,由于MAGMA库中的错误(参见magma问题13),LU分解会重复进行以处理奇异矩阵。

  • L, U, 和 P 可以使用 torch.lu_unpack() 导出。

警告

A 是满秩时,此函数的梯度才会是有限的。 这是因为LU分解仅在满秩矩阵处可微。 此外,如果 A 接近于非满秩, 梯度将会在数值上不稳定,因为它依赖于 L1L^{-1}U1U^{-1} 的计算。

Parameters
  • A (张量) – 要分解的张量,大小为 (,m,n)(*, m, n)

  • pivot (bool, 可选) – 控制是否进行透视。默认值: True

  • get_infos (bool, 可选) – 如果设置为 True,则返回一个信息 IntTensor。 默认值:False

  • out (元组, 可选) – 可选的输出元组。如果 get_infosTrue, 则元组中的元素是 Tensor, IntTensor, 和 IntTensor。如果 get_infosFalse,则元组中的 元素是 Tensor, IntTensor。默认值:None

Returns

包含张量的元组

  • 因式分解 (张量): 大小为 (,m,n)(*, m, n)

  • 枢轴 (IntTensor): 大小为 (,min(m,n))(*, \text{min}(m, n))pivots 存储了所有行的中间转置。 最终的排列 perm 可以通过 应用 swap(perm[i], perm[pivots[i] - 1]) 对于 i = 0, ..., pivots.size(-1) - 1, 其中 perm 最初是 mm 元素的单位排列 (本质上这就是 torch.lu_unpack() 所做的)。

  • 信息 (IntTensor, 可选): 如果 get_infosTrue,这是一个大小为 ()(*) 的张量,其中非零值表示矩阵或 每个小批次的因式分解是否成功或失败

Return type

(张量, 整型张量, 整型张量 (可选))

示例:

>>> A = torch.randn(2, 3, 3)
>>> A_LU, pivots = torch.lu(A)
>>> A_LU
tensor([[[ 1.3506,  2.5558, -0.0816],
         [ 0.1684,  1.1551,  0.1940],
         [ 0.1193,  0.6189, -0.5497]],

        [[ 0.4526,  1.2526, -0.3285],
         [-0.7988,  0.7175, -0.9701],
         [ 0.2634, -0.9255, -0.3459]]])
>>> pivots
tensor([[ 3,  3,  3],
        [ 3,  3,  3]], dtype=torch.int32)
>>> A_LU, pivots, info = torch.lu(A, get_infos=True)
>>> if info.nonzero().size(0) == 0:
...     print('LU factorization succeeded for all samples!')
LU factorization succeeded for all samples!