Shortcuts

张量视图

PyTorch 允许一个张量成为现有张量的 View。视图张量与其基础张量共享相同的基础数据。支持 View 避免了显式的数据复制,从而使我们能够进行快速且内存高效的重新塑形、切片和元素操作。

例如,要获取现有张量 t 的视图,您可以调用 t.view(...)

>>> t = torch.rand(4, 4)
>>> b = t.view(2, 8)
>>> t.storage().data_ptr() == b.storage().data_ptr()  # `t` 和 `b` 共享相同的基础数据。
True
# 修改视图张量也会改变基础张量。
>>> b[0][0] = 3.14
>>> t[0][0]
tensor(3.14)

由于视图与其基础张量共享底层数据,如果您在视图中编辑数据,它也会反映在基础张量中。

通常,PyTorch操作会返回一个新的张量作为输出,例如 add()。 但在视图操作的情况下,输出是输入张量的视图,以避免不必要的数据复制。 创建视图时不会发生数据移动,视图张量只是改变了它解释相同数据的方式。对连续张量进行视图操作可能会产生非连续张量。 用户应额外注意,因为连续性可能会对性能产生隐含的影响。 transpose() 是一个常见的例子。

>>> base = torch.tensor([[0, 1],[2, 3]])
>>> base.is_contiguous()
True
>>> t = base.transpose(0, 1)  # `t` 是 `base` 的一个视图。这里没有发生数据移动。
# 视图张量可能是非连续的。
>>> t.is_contiguous()
False
# 要获得一个连续的张量,调用 `.contiguous()` 以在 `t` 不连续时强制复制数据。
>>> c = t.contiguous()

供参考,以下是PyTorch中所有视图操作的完整列表:

注意

当通过索引访问张量的内容时,PyTorch遵循Numpy的行为,即基本索引返回视图,而高级索引返回副本。通过基本或高级索引进行的赋值是就地进行的。更多示例请参见Numpy索引文档

也值得提一下一些具有特殊行为的操作:

  • reshape(), reshape_as()flatten() 可以返回视图或新的张量,用户代码不应依赖于它是视图还是否。

  • contiguous() 如果输入张量已经是连续的,则返回自身,否则通过复制数据返回一个新的连续张量。

有关PyTorch内部实现的更详细介绍,请参阅ezyang关于PyTorch内部实现的博客文章

优云智算