Shortcuts

torch.linalg.tensorinv

torch.linalg.tensorinv(A, ind=2, *, out=None) 张量

计算 torch.tensordot() 的乘法逆元。

如果 mA 的前 ind 维度的乘积,而 n 是其余维度的乘积,此函数期望 mn 相等。 如果满足此条件,它将计算一个张量 X,使得 tensordot(A, X, ind) 在维度 m 上是单位矩阵。 X 将具有 A 的形状,但前 ind 维度被推到末尾。

X.shape == A.shape[ind:] + A.shape[:ind]

支持float、double、cfloat和cdouble数据类型的输入。

注意

A 是一个 2 维张量且 ind= 1 时, 此函数计算 A 的(乘法)逆(参见 torch.linalg.inv())。

注意

考虑使用 torch.linalg.tensorsolve() 来在左侧通过张量逆乘以一个张量,如:

linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B)  # 当 B 是一个形状为 A.shape[:B.ndim] 的张量时

在可能的情况下,总是优先使用 tensorsolve(),因为它更快且在数值上更稳定,而不是显式计算伪逆。

另请参阅

torch.linalg.tensorsolve() 计算 torch.tensordot(tensorinv(A), B).

Parameters
  • A (张量) – 要反转的张量。其形状必须满足 prod(A.shape[:ind]) == prod(A.shape[ind:]).

  • ind (int) – 计算 torch.tensordot() 的逆的索引位置。默认值:2

Keyword Arguments

输出 (张量, 可选) – 输出张量。如果为,则忽略。默认值:

Raises

RuntimeError – 如果重塑后的 A 不可逆或前 ind 维度的乘积不等于其余维度的乘积。

示例:

>>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3))
>>> Ainv = torch.linalg.tensorinv(A, ind=2)
>>> Ainv.shape
torch.Size([8, 3, 4, 6])
>>> B = torch.randn(4, 6)
>>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B))
True

>>> A = torch.randn(4, 4)
>>> Atensorinv = torch.linalg.tensorinv(A, ind=1)
>>> Ainv = torch.linalg.inv(A)
>>> torch.allclose(Atensorinv, Ainv)
True
优云智算