torch.linalg.multi_dot¶
- torch.linalg.multi_dot(tensors, *, out=None)¶
通过重新排序乘法操作,使得执行的算术运算最少,从而高效地对两个或更多矩阵进行乘法运算。
支持float、double、cfloat和cdouble数据类型的输入。 此函数不支持批量输入。
在
tensors中的每个张量都必须是二维的,除了第一个和最后一个张量可以是1D。如果第一个张量是形状为 (n,) 的1D向量,它将被视为形状为 (1, n) 的行向量,同样,如果最后一个张量是形状为 (n,) 的1D向量,它将被视为形状为 (n, 1) 的列向量。如果第一个和最后一个张量是矩阵,输出将是一个矩阵。 然而,如果其中一个是1D向量,那么输出将是一个1D向量。
与numpy.linalg.multi_dot的区别:
与numpy.linalg.multi_dot不同,第一个和最后一个张量必须是1D或2D,而NumPy允许它们为nD
警告
此函数不进行广播。
注意
此函数通过在计算最佳矩阵乘法顺序后链接
torch.mm()调用来实现。注意
两个形状为(a, b)和(b, c)的矩阵相乘的成本是a * b * c。给定矩阵A、B、C,其形状分别为(10, 100)、(100, 5)、(5, 50),我们可以计算不同乘法顺序的成本如下:
在这种情况下,先乘以A和B,然后再乘以C的速度快了10倍。
- Parameters
张量 (序列[张量]) – 要相乘的两个或多个张量。第一个和最后一个张量可以是1D或2D。其他所有张量必须是2D。
- Keyword Arguments
输出 (张量, 可选) – 输出张量。如果为无,则忽略。默认值:无。
示例:
>>> from torch.linalg import multi_dot >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])]) tensor(8) >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])]) tensor([8]) >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) tensor([[8]]) >>> A = torch.arange(2 * 3).view(2, 3) >>> B = torch.arange(3 * 2).view(3, 2) >>> C = torch.arange(2 * 2).view(2, 2) >>> multi_dot((A, B, C)) tensor([[ 26, 49], [ 80, 148]])