torch.linalg.inv¶
- torch.linalg.inv(A, *, out=None) 张量¶
如果存在,计算方阵的逆矩阵。 如果矩阵不可逆,则抛出RuntimeError。
设 为 或 , 对于矩阵 , 其 逆矩阵 (如果存在)定义为
其中 是 n 维单位矩阵。
当且仅当 是 可逆 时,逆矩阵存在。在这种情况下,逆矩阵是唯一的。
支持输入 float、double、cfloat 和 cdouble 数据类型。 还支持矩阵的批处理,如果
A是矩阵的批处理, 则输出具有相同的批处理维度。注意
当输入位于CUDA设备上时,此函数会与CPU同步该设备。如需不进行同步的版本,请参阅
torch.linalg.inv_ex()。注意
考虑使用
torch.linalg.solve()如果可能的话,用于通过逆矩阵在左侧乘以矩阵,如:linalg.solve(A, B) == linalg.inv(A) @ B # 当 B 是矩阵时
尽可能使用
solve()总是更可取的,因为它更快且在数值上更稳定,而不是显式计算逆矩阵。- Parameters
A (张量) – 形状为 (*, n, n) 的张量,其中 * 表示零个或多个批次维度,由可逆矩阵组成。
- Keyword Arguments
输出 (张量, 可选) – 输出张量。如果为无,则忽略。默认值:无。
- Raises
RuntimeError – 如果矩阵
A或批量矩阵中的任何矩阵A不可逆。
示例:
>>> A = torch.randn(4, 4) >>> Ainv = torch.linalg.inv(A) >>> torch.dist(A @ Ainv, torch.eye(4)) tensor(1.1921e-07) >>> A = torch.randn(2, 3, 4, 4) # 矩阵批次 >>> Ainv = torch.linalg.inv(A) >>> torch.dist(A @ Ainv, torch.eye(4)) tensor(1.9073e-06) >>> A = torch.randn(4, 4, dtype=torch.complex128) # 复数矩阵 >>> Ainv = torch.linalg.inv(A) >>> torch.dist(A @ Ainv, torch.eye(4)) tensor(7.5107e-16, dtype=torch.float64)