命名张量操作符覆盖范围¶
请先阅读命名张量以了解命名张量的介绍。
本文档是关于名称推断的参考资料,该过程定义了如何处理命名张量:
使用名称来提供额外的自动运行时正确性检查
将输入张量的名称传播到输出张量
以下是支持命名张量的所有操作及其相关的名称推断规则列表。
如果您在此处没有看到列出的操作,但它对您的用例有帮助,请搜索是否已提交问题,如果没有,请提交一个。
警告
命名张量 API 是实验性的,可能会发生变化。
API |
名称推断规则 |
|---|---|
查看文档 |
|
查看文档 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
|
无 |
|
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
查看文档 |
|
无 |
|
无 |
|
|
无 |
无 |
|
|
查看文档 |
|
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
|
无 |
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
|
|
无 |
|
对齐掩码到输入,然后从输入张量统一名称 |
|
查看文档 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
|
无 |
无 |
|
无 |
|
查看文档 |
|
无 |
|
无 |
|
查看文档 |
|
查看文档 |
|
无 |
|
无 |
|
仅允许不改变形状的调整大小 |
|
仅允许不改变形状的调整大小 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
无 |
|
查看文档 |
|
无 |
|
无 |
|
保留输入名称¶
所有点对点一元函数也遵循这一规则,以及其他一些一元函数。
检查名称:无
传播名称:输入张量的名称被传播到输出。
>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')
移除维度¶
所有归约操作,如sum(),通过在所需维度上进行归约来移除维度。其他操作,如select()和
squeeze(),也会移除维度。
无论何处,只要可以传递一个整数维度索引给操作符,也可以传递一个维度名称。接受维度索引列表的函数也可以接受一个维度名称列表。
检查名称:如果
dim或dims作为名称列表传递进来,检查这些名称是否存在于self中。传播名称:如果输入张量的维度由
dim或dims指定,但这些维度在输出张量中不存在,那么这些维度的相应名称 不会出现在output.names中。
>>> x = torch.randn(1, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.squeeze('N').names
('C', 'H', 'W')
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C']).names
('H', 'W')
# 使用keepdim=True的缩减操作实际上不会移除维度。
>>> x = torch.randn(3, 3, 3, 3, names=('N', 'C', 'H', 'W'))
>>> x.sum(['N', 'C'], keepdim=True).names
('N', 'C', 'H', 'W')
统一输入中的名称¶
所有二元算术运算都遵循这一规则。广播操作仍然从右到左按位置广播,以保持与未命名张量的兼容性。要通过名称执行显式广播,请使用 Tensor.align_as()。
检查名称:所有名称必须从右侧开始按位置匹配。即,在
tensor + other中,对于所有i在(-min(tensor.dim(), other.dim()) + 1, -1]范围内,match(tensor.names[i], other.names[i])必须为真。检查名称:此外,所有命名维度必须从右侧对齐。 在匹配过程中,如果我们匹配一个命名维度
A与一个未命名维度None,那么A不得出现在具有未命名维度的张量中。传播名称:从两个张量的右侧统一成对的名称以生成输出名称。
例如,
# tensor: 张量[ N, None]
# other: 张量[None, C]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, 3, names=(None, 'C'))
>>> (tensor + other).names
('N', 'C')
检查名称:
match(tensor.names[-1], other.names[-1])是Truematch(tensor.names[-2], tensor.names[-2])是True因为我们匹配了
None在tensor中与'C', 检查以确保'C'不存在于tensor中(它不存在)。检查确保
'N'不存在于other中(确实不存在)。
最后,输出名称通过以下方式计算:
[unify('N', None), unify(None, 'C')] = ['N', 'C']
更多示例:
# 从右边开始的维度不匹配:
# 张量: Tensor[N, C]
# 其他: Tensor[ N]
>>> tensor = torch.randn(3, 3, names=('N', 'C'))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: 尝试广播维度时出错 ['N', 'C'] 和维度
['N']: 维度 'C' 和维度 'N' 从右边开始处于相同位置但
不匹配.
# 匹配tensor.names[-1]和other.names[-1]时维度未对齐:
# 张量: Tensor[N, None]
# 其他: Tensor[ N]
>>> tensor = torch.randn(3, 3, names=('N', None))
>>> other = torch.randn(3, names=('N',))
>>> (tensor + other).names
RuntimeError: 尝试广播维度时维度未对齐 ['N'] 和
维度 ['N', None]: 维度 'N' 在两个列表中从右边开始出现在不同位置
不同位置.
注意
在最后两个示例中,可以通过名称对齐张量,然后执行加法操作。使用 Tensor.align_as() 按名称对齐张量,或使用 Tensor.align_to() 将张量对齐到自定义的维度顺序。
置换维度¶
一些操作,如Tensor.t(),会改变维度的顺序。维度名称
附加在各个维度上,因此它们也会随之改变顺序。
如果操作符接受位置索引 dim,它也可以接受一个维度名称作为 dim。
检查名称:如果将
dim作为名称传递,检查它是否存在于张量中。传播名称:以与正在排列的维度相同的方式排列维度名称。
>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.transpose('N', 'C').names
('C', 'N')
合同之外的维度¶
矩阵乘法函数遵循某种变体。让我们先了解torch.mm(),然后再推广到批量矩阵乘法的规则。
对于 torch.mm(tensor, other):
检查名称:无
传播名称:结果名称是
(tensor.names[-2], other.names[-1])。
>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, 3, names=('in', 'out'))
>>> x.mm(y).names
('N', 'out')
本质上,矩阵乘法在两个维度上执行点积,将其折叠。当两个张量进行矩阵乘法时,收缩的维度消失,不会出现在输出张量中。
torch.mv(), torch.dot() 以类似的方式工作:名称推断不检查输入名称并删除参与点积的维度:
>>> x = torch.randn(3, 3, names=('N', 'D'))
>>> y = torch.randn(3, names=('something',))
>>> x.mv(y).names
('N',)
现在,让我们来看一下 torch.matmul(tensor, other)。假设 tensor.dim() >= 2
并且 other.dim() >= 2。
检查名称:检查输入的批次维度是否对齐且可广播。 参见统一输入名称以了解输入对齐的含义。
传播名称:结果名称通过统一批次维度并移除被收缩的维度来获得:
unify(tensor.names[:-2], other.names[:-2]) + (tensor.names[-2], other.names[-1])。
示例:
# 矩阵张量['C', 'D']和矩阵张量['E', 'F']的批量矩阵乘法。
# 'A', 'B' 是批次维度。
>>> x = torch.randn(3, 3, 3, 3, names=('A', 'B', 'C', 'D'))
>>> y = torch.randn(3, 3, 3, names=('B', 'E', 'F'))
>>> torch.matmul(x, y).names
('A', 'B', 'C', 'F')
最后,有许多矩阵乘法函数的融合add版本。例如,addmm()
和 addmv()。这些被视为组合名称推断,例如mm()和
名称推断为add()。
工厂函数¶
工厂函数现在接受一个新的 names 参数,用于将名称与每个维度关联起来。
>>> torch.zeros(2, 3, names=('N', 'C'))
张量([[0., 0., 0.],
[0., 0., 0.]], names=('N', 'C'))
输出函数和就地变体¶
指定为 out= 张量的张量具有以下行为:
如果没有命名维度,则从操作中计算出的名称会传播到它。
如果它有任何命名维度,那么从操作中计算出的名称必须与现有名称完全相等。否则,操作会出错。
所有就地方法都会修改输入,使其名称等于从名称推断中计算出的名称。例如:
>>> x = torch.randn(3, 3)
>>> y = torch.randn(3, 3, names=('N', 'C'))
>>> x.names
(None, None)
>>> x += y
>>> x.names
('N', 'C')