• Docs >
  • Broadcasting semantics
Shortcuts

广播语义

许多 PyTorch 操作支持 NumPy 的广播语义。 详情请参见 https://numpy.org/doc/stable/user/basics.broadcasting.html

简而言之,如果一个 PyTorch 操作支持广播,那么它的 Tensor 参数可以自动扩展为相同的大小(而不需要复制数据)。

通用语义

两个张量是“可广播的”如果以下规则成立:

  • 每个张量至少有一个维度。

  • 当从尾随维度开始迭代维度大小时,维度大小必须相等,其中一个为1,或者其中一个不存在。

例如:

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# 相同的形状总是可以广播的(即上述规则总是成立)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x 和 y 不能广播,因为 x 至少没有一维

# 可以对齐尾随维度
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x 和 y 可以广播。
# 第一个尾随维度:两者的大小都是 1
# 第二个尾随维度:y 的大小是 1
# 第三个尾随维度:x 的大小等于 y 的大小
# 第四个尾随维度:y 的维度不存在

# 但:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x 和 y 不能广播,因为在第三个尾随维度中 2 != 3

如果两个张量 x, y 是“可广播的”,则结果张量的大小计算如下:

  • 如果xy的维度数量不相等,则在维度较少的张量前添加1,使其维度数量相等。

  • 然后,对于每个维度大小,结果维度大小是沿该维度的xy大小的最大值。

例如:

# 可以使尾随维度对齐以使阅读更容易
>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

# 但不是必须的:
>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

就地语义

一个复杂的问题是,就地操作不允许就地张量由于广播而改变形状。

例如:

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])

# 但是:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

向后兼容性

PyTorch 的早期版本允许某些逐点函数在具有不同形状的张量上执行,只要每个张量中的元素数量相等即可。逐点操作将通过将每个张量视为一维来进行。现在,PyTorch 支持广播,并且“一维”逐点行为被视为已弃用,在张量不可广播但具有相同元素数量的情况下,将生成 Python 警告。

请注意,广播的引入可能会导致在两个张量形状不同但可广播且元素数量相同的情况下,出现向后不兼容的变化。例如:

>>> torch.add(torch.ones(4,1), torch.randn(4))

之前会生成一个大小为:torch.Size([4,1])的张量,但现在会生成一个大小为:torch.Size([4,4])的张量。 为了帮助识别代码中可能存在的由于广播引入的向后不兼容情况, 您可以将torch.utils.backcompat.broadcast_warning.enabled设置为True,这将在这种情况下生成一个Python警告。

例如:

>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self 和 other 的形状不同,但可以广播,并且具有相同数量的元素。
以向后不兼容的方式更改行为,改为广播而不是视为一维。