Shortcuts

torch.unflatten

torch.unflatten(input, dim, sizes) 张量

扩展输入张量的一个维度到多个维度。

另请参阅

torch.flatten() 是此函数的逆函数。它将多个维度合并为一个。

Parameters
  • 输入 (张量) – 输入张量。

  • dim (int) – 要展开的维度,指定为索引到 input.shape

  • sizes (Tuple[int]) – 未展平维度的新形状。 其中一个元素可以是 -1,在这种情况下,相应的输出维度将被推断。否则,sizes 的乘积 必须 等于 input.shape[dim]

Returns

指定维度未展平的输入视图。

Examples::
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (2, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(3, 4, 1), 1, (-1, 2)).shape
torch.Size([3, 2, 2, 1])
>>> torch.unflatten(torch.randn(5, 12, 3), -2, (2, 2, 3, 1, 1)).shape
torch.Size([5, 2, 2, 3, 1, 1, 3])
优云智算