torch.linalg.tensorinv¶
- torch.linalg.tensorinv(A, ind=2, *, out=None) 张量 ¶
计算
torch.tensordot()
的乘法逆元。如果 m 是
A
的前ind
维度的乘积,而 n 是其余维度的乘积,此函数期望 m 和 n 相等。 如果满足此条件,它将计算一个张量 X,使得 tensordot(A
, X,ind
) 在维度 m 上是单位矩阵。 X 将具有A
的形状,但前ind
维度被推到末尾。X.shape == A.shape[ind:] + A.shape[:ind]
支持float、double、cfloat和cdouble数据类型的输入。
注意
当
A
是一个 2 维张量且ind
= 1 时, 此函数计算A
的(乘法)逆(参见torch.linalg.inv()
)。注意
考虑使用
torch.linalg.tensorsolve()
来在左侧通过张量逆乘以一个张量,如:linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B) # 当 B 是一个形状为 A.shape[:B.ndim] 的张量时
在可能的情况下,总是优先使用
tensorsolve()
,因为它更快且在数值上更稳定,而不是显式计算伪逆。另请参阅
torch.linalg.tensorsolve()
计算 torch.tensordot(tensorinv(A
),B
).- Parameters
A (张量) – 要反转的张量。其形状必须满足 prod(
A
.shape[:ind
]) == prod(A
.shape[ind
:]).ind (int) – 计算
torch.tensordot()
的逆的索引位置。默认值:2。
- Keyword Arguments
输出 (张量, 可选) – 输出张量。如果为无,则忽略。默认值:无。
- Raises
RuntimeError – 如果重塑后的
A
不可逆或前ind
维度的乘积不等于其余维度的乘积。
示例:
>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3)) >>> Ainv = torch.linalg.tensorinv(A, ind=2) >>> Ainv.shape torch.Size([8, 3, 4, 6]) >>> B = torch.randn(4, 6) >>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B)) True >>> A = torch.randn(4, 4) >>> Atensorinv = torch.linalg.tensorinv(A, ind=1) >>> Ainv = torch.linalg.inv(A) >>> torch.allclose(Atensorinv, Ainv) True