Shortcuts

torch.linalg.matrix_power

torch.linalg.matrix_power(A, n, *, out=None) 张量

计算方阵的n次幂,其中n为整数。

支持输入 float、double、cfloat 和 cdouble 数据类型。 还支持矩阵的批处理,如果 A 是矩阵的批处理,则输出具有相同的批处理维度。

如果 n= 0,它返回与 A 相同形状的单位矩阵(或批次)。如果 n 为负数,它返回每个矩阵的逆(如果可逆)的 abs(n) 次幂。

注意

考虑使用 torch.linalg.solve() 如果可能的话,用于在左侧乘以一个负幂矩阵,因为如果 n> 0

torch.linalg.solve(matrix_power(A, n), B) == matrix_power(A, -n)  @ B

总是优先使用 solve(),因为它更快且在数值上更稳定,而不是显式计算 AnA^{-n}

另请参阅

torch.linalg.solve() 使用数值稳定的算法计算 A.inverse() @ B

Parameters
  • A (张量) – 形状为 (*, m, m) 的张量,其中 * 表示零个或多个批次维度。

  • n (int) – 指数。

Keyword Arguments

输出 (张量, 可选) – 输出张量。如果为,则忽略。默认值:

Raises

RuntimeError – 如果 n< 0 并且矩阵 A 或批量矩阵中的任何矩阵 A 不可逆。

示例:

>>> A = torch.randn(3, 3)
>>> torch.linalg.matrix_power(A, 0)
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
>>> torch.linalg.matrix_power(A, 3)
tensor([[ 1.0756,  0.4980,  0.0100],
        [-1.6617,  1.4994, -1.9980],
        [-0.4509,  0.2731,  0.8001]])
>>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2)
tensor([[[ 0.2640,  0.4571, -0.5511],
        [-1.0163,  0.3491, -1.5292],
        [-0.4899,  0.0822,  0.2773]],
        [[ 0.2640,  0.4571, -0.5511],
        [-1.0163,  0.3491, -1.5292],
        [-0.4899,  0.0822,  0.2773]]])
优云智算