展开¶
- class torch.nn.Unflatten(dim, unflattened_size)[源代码]¶
将一个张量维度展开为所需的形状。用于与
Sequential
一起使用。dim
指定输入张量要被展平的维度,当使用 Tensor 或 NamedTensor 时,它可以是 int 或 str。unflattened_size
是张量未展平维度的新形状,它可以是 tuple 或 list 或 torch.Size 用于 Tensor 输入;对于 NamedTensor 输入,它可以是 NamedShape((name, size) 元组的元组)。
- Shape:
输入:,其中 是维度
dim
的大小,而 表示包括无在内的任意数量的维度。输出: , 其中 =
unflattened_size
且 .
- Parameters
示例
>>> input = torch.randn(2, 50) >>> # 使用整数元组 >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, (2, 5, 5)) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # 使用torch.Size >>> m = nn.Sequential( >>> nn.Linear(50, 50), >>> nn.Unflatten(1, torch.Size([2, 5, 5])) >>> ) >>> output = m(input) >>> output.size() torch.Size([2, 2, 5, 5]) >>> # 使用命名形状(元组的元组) >>> input = torch.randn(2, 50, names=('N', 'features')) >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5))) >>> output = unflatten(input) >>> output.size() torch.Size([2, 2, 5, 5])