torch.linalg.matrix_power¶
- torch.linalg.matrix_power(A, n, *, out=None) 张量¶
计算方阵的n次幂,其中n为整数。
支持输入 float、double、cfloat 和 cdouble 数据类型。 还支持矩阵的批处理,如果
A是矩阵的批处理,则输出具有相同的批处理维度。如果
n= 0,它返回与A相同形状的单位矩阵(或批次)。如果n为负数,它返回每个矩阵的逆(如果可逆)的 abs(n) 次幂。注意
考虑使用
torch.linalg.solve()如果可能的话,用于在左侧乘以一个负幂矩阵,因为如果n> 0:torch.linalg.solve(matrix_power(A, n), B) == matrix_power(A, -n) @ B
总是优先使用
solve(),因为它更快且在数值上更稳定,而不是显式计算 。另请参阅
torch.linalg.solve()使用数值稳定的算法计算A.inverse() @B。- Parameters
- Keyword Arguments
输出 (张量, 可选) – 输出张量。如果为无,则忽略。默认值:无。
- Raises
RuntimeError – 如果
n< 0 并且矩阵A或批量矩阵中的任何矩阵A不可逆。
示例:
>>> A = torch.randn(3, 3) >>> torch.linalg.matrix_power(A, 0) tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) >>> torch.linalg.matrix_power(A, 3) tensor([[ 1.0756, 0.4980, 0.0100], [-1.6617, 1.4994, -1.9980], [-0.4509, 0.2731, 0.8001]]) >>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2) tensor([[[ 0.2640, 0.4571, -0.5511], [-1.0163, 0.3491, -1.5292], [-0.4899, 0.0822, 0.2773]], [[ 0.2640, 0.4571, -0.5511], [-1.0163, 0.3491, -1.5292], [-0.4899, 0.0822, 0.2773]]])