torch.einsum¶
- torch.einsum(equation, *operands) 张量 [源代码]¶
基于爱因斯坦求和约定,对输入的
操作数
元素的乘积沿指定维度求和。Einsum允许通过基于爱因斯坦求和约定的简写格式来计算许多常见的多维线性代数数组操作,该格式由
equation
给出。此格式的详细信息如下所述,但总体思路是为输入operands
的每个维度标记一些下标,并定义哪些下标是输出的一部分。然后通过沿不包含在输出中的下标的维度对operands
元素的乘积进行求和来计算输出。例如,矩阵乘法可以使用einsum计算为torch.einsum(“ij,jk->ik”, A, B)。这里,j是求和下标,i和k是输出下标(详见下文部分)。方程式:
字符串
equation
指定了输入operands
的每个维度的下标([a-zA-Z] 中的字母),顺序与维度相同,并通过逗号(‘,’)分隔每个操作数的下标,例如 ‘ij,jk’ 指定了两个二维操作数的下标。标记有相同下标的维度必须是可广播的,即它们的大小必须匹配或为 1。例外情况是,如果同一个输入操作数的下标重复,则标记有此下标的该操作数的维度必须在大小上匹配,并且该操作数将被替换为其在这些维度上的对角线。在equation
中恰好出现一次的下标将成为输出的一部分,按字母顺序递增排序。输出是通过将输入operands
逐元素相乘,并根据下标对齐其维度,然后对那些下标不属于输出的维度进行求和计算得出的。可选地,可以通过在方程末尾添加箭头(‘->’)并随后定义输出的下标来显式定义输出下标。例如,以下方程计算矩阵乘法的转置:‘ij,jk->ki’。输出下标必须至少出现在某个输入操作数中一次,并且在输出中最多出现一次。
省略号(‘…’)可以用来代替下标,以广播省略号所覆盖的维度。 每个输入操作数最多可以包含一个省略号,该省略号将覆盖未被下标覆盖的维度, 例如,对于具有5个维度的输入操作数,方程式‘ab…c’中的省略号覆盖第三和第四维度。省略号不需要在
操作数
之间覆盖相同数量的维度,但省略号的“形状”(它们所覆盖的维度的大小)必须一起广播。如果输出没有用箭头(‘->’)符号显式定义,省略号将首先出现在输出中(最左边的维度),然后是输入操作数中恰好出现一次的下标标签。例如,以下方程式实现了批量矩阵乘法‘…ij,…jk’。几点最后的说明:方程式中可能包含不同元素(下标、省略号、箭头和逗号)之间的空白,但像‘…’这样的内容是无效的。对于标量操作数,空字符串‘’是有效的。
注意
torch.einsum
处理省略号(‘…’)的方式与NumPy不同,它允许省略号覆盖的维度被求和,也就是说,省略号不需要成为输出的一部分。注意
此函数使用 opt_einsum(https://optimized-einsum.readthedocs.io/en/stable/)来加速计算或通过优化收缩顺序来减少内存消耗。当至少有三个输入时,会发生这种优化,因为否则顺序并不重要。请注意,找到_最优_路径是一个NP难问题,因此,opt_einsum 依赖于不同的启发式方法来实现接近最优的结果。如果 opt_einsum 不可用,默认顺序是从左到右进行收缩。
要绕过此默认行为,请添加以下行以禁用opt_einsum的使用并跳过路径计算:torch.backends.opt_einsum.enabled = False
要指定您希望 opt_einsum 用于计算收缩路径的策略,请添加以下行: torch.backends.opt_einsum.strategy = ‘auto’。默认策略是 ‘auto’,我们还支持 ‘greedy’ 和 ‘optimal’。需要注意的是,‘optimal’ 的运行时间是输入数量的阶乘!更多详情请参阅 opt_einsum 文档(https://optimized-einsum.readthedocs.io/en/stable/path_finding.html)。
注意
截至 PyTorch 1.10,
torch.einsum()
也支持子列表格式(见下文示例)。在这种格式中, 每个操作数的下标由子列表指定,子列表是范围在 [0, 52) 内的整数列表。这些子列表 跟随其操作数,并且可以在输入的末尾出现一个额外的子列表来指定输出的 下标,例如 torch.einsum(op1, sublist1, op2, sublist2, …, [subslist_out])。Python 的 Ellipsis 对象 可以在子列表中提供,以启用如上文“方程”部分所述的广播。示例:
>>> # 迹 >>> torch.einsum('ii', torch.randn(4, 4)) tensor(-1.2104) >>> # 对角线 >>> torch.einsum('ii->i', torch.randn(4, 4)) tensor([-0.1034, 0.7952, -0.2433, 0.4545]) >>> # 外积 >>> x = torch.randn(5) >>> y = torch.randn(4) >>> torch.einsum('i,j->ij', x, y) tensor([[ 0.1156, -0.2897, -0.3918, 0.4963], [-0.3744, 0.9381, 1.2685, -1.6070], [ 0.7208, -1.8058, -2.4419, 3.0936], [ 0.1713, -0.4291, -0.5802, 0.7350], [ 0.5704, -1.4290, -1.9323, 2.4480]]) >>> # 批量矩阵乘法 >>> As = torch.randn(3, 2, 5) >>> Bs = torch.randn(3, 5, 4) >>> torch.einsum('bij,bjk->bik', As, Bs) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # 使用子列表格式和省略号 >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2]) tensor([[[-1.0564, -1.5904, 3.2023, 3.1271], [-1.6706, -0.8097, -0.8025, -2.1183]], [[ 4.2239, 0.3107, -0.5756, -0.2354], [-1.4558, -0.3460, 1.5087, -0.8530]], [[ 2.8153, 1.8787, -4.3839, -1.2112], [ 0.3728, -2.1131, 0.0921, 0.8305]]]) >>> # 批量置换 >>> A = torch.randn(2, 3, 4, 5) >>> torch.einsum('...ij->...ji', A).shape torch.Size([2, 3, 5, 4]) >>> # 等价于 torch.nn.functional.bilinear >>> A = torch.randn(3, 5, 4) >>> l = torch.randn(2, 5) >>> r = torch.randn(2, 4) >>> torch.einsum('bn,anm,bm->ba', l, A, r) tensor([[-0.3430, -5.2405, 0.4494], [ 0.3311, 5.5201, -3.0356]])