torch.linalg.lu_solve¶
- torch.linalg.lu_solve(LU, pivots, B, *, left=True, adjoint=False, out=None) 张量¶
计算具有唯一解的方阵线性方程组的解,给定LU分解。
设 为 或 , 此函数计算与 相关的 线性系统 的解 ,其定义为
其中 被分解为
lu_factor()返回的形式。如果
left= False,此函数返回矩阵 解决系统如果
adjoint= True(并且left= True),给定一个 的LU分解,此函数返回 解该系统其中 是当 为复数时的共轭转置,以及当 为实值时的转置。
left= False 的情况类似。支持float、double、cfloat和cdouble数据类型的输入。 还支持矩阵的批处理,如果输入是矩阵的批处理,则输出具有相同的批处理维度。
- Parameters
LU (张量) – 形状为 (*, n, n)(或 (*, k, k) 如果
left= True)的张量 其中 * 是由lu_factor()返回的零个或多个批次维度。pivots (张量) – 形状为 (*, n) 的张量(如果
left= True,则为 (*, k)) 其中 * 是零个或多个批次维度,由lu_factor()返回。B (张量) – 形状为 (*, n, k) 的右侧张量。
- Keyword Arguments
示例:
>>> A = torch.randn(3, 3) >>> LU, pivots = torch.linalg.lu_factor(A) >>> B = torch.randn(3, 2) >>> X = torch.linalg.lu_solve(LU, pivots, B) >>> torch.allclose(A @ X, B) True >>> B = torch.randn(3, 3, 2) # 广播规则适用:A被广播 >>> X = torch.linalg.lu_solve(LU, pivots, B) >>> torch.allclose(A @ X, B) True >>> B = torch.randn(3, 5, 3) >>> X = torch.linalg.lu_solve(LU, pivots, B, left=False) >>> torch.allclose(X @ A, B) True >>> B = torch.randn(3, 3, 4) # 现在求解 A^T >>> X = torch.linalg.lu_solve(LU, pivots, B, adjoint=True) >>> torch.allclose(A.mT @ X, B) True