Shortcuts

展开

class torch.nn.Unflatten(dim, unflattened_size)[源代码]

将一个张量维度展开为所需的形状。用于与Sequential一起使用。

  • dim 指定输入张量要被展平的维度,当使用 TensorNamedTensor 时,它可以是 intstr

  • unflattened_size 是张量未展平维度的新形状,它可以是 tuplelisttorch.Size 用于 Tensor 输入;对于 NamedTensor 输入,它可以是 NamedShape(name, size) 元组的元组)。

Shape:
  • 输入:(,Sdim,)(*, S_{\text{dim}}, *),其中 SdimS_{\text{dim}} 是维度 dim 的大小,而 * 表示包括无在内的任意数量的维度。

  • 输出: (,U1,...,Un,)(*, U_1, ..., U_n, *), 其中 UU = unflattened_sizei=1nUi=Sdim\prod_{i=1}^n U_i = S_{\text{dim}}.

Parameters
  • dim (联合[int, str]) – 要展开的维度

  • unflattened_size (Union[torch.Size, Tuple, List, NamedShape]) – 未展平维度的新形状

示例

>>> input = torch.randn(2, 50)
>>> # 使用整数元组
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, (2, 5, 5))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # 使用torch.Size
>>> m = nn.Sequential(
>>>     nn.Linear(50, 50),
>>>     nn.Unflatten(1, torch.Size([2, 5, 5]))
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # 使用命名形状(元组的元组)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])
优云智算