展平¶
- class torch.nn.Flatten(start_dim=1, end_dim=-1)[源码]¶
将一个连续范围的维度展平为一个张量。
用于与
Sequential
一起使用,详情请参见torch.flatten()
。- Shape:
输入:,其中是维度的大小,表示任意数量的维度,包括没有维度。
输出: 。
- Examples::
>>> input = torch.randn(32, 1, 5, 5) >>> # 使用默认参数 >>> m = nn.Flatten() >>> output = m(input) >>> output.size() torch.Size([32, 25]) >>> # 使用非默认参数 >>> m = nn.Flatten(0, 2) >>> output = m(input) >>> output.size() torch.Size([160, 5])