torch.stack¶
- torch.stack(tensors, dim=0, *, out=None) 张量¶
沿新维度连接一系列张量。
所有张量需要具有相同的大小。
另请参阅
torch.cat()沿现有维度连接给定的序列。- Parameters
张量 (序列 的 张量) – 要连接的张量序列
dim (int, 可选) – 要插入的维度。必须在 0 和连接张量的维度数量之间(包括 0 和维度数量)。默认值:0
- Keyword Arguments
输出 (张量, 可选) – 输出张量。
示例:
>>> x = torch.randn(2, 3) >>> x tensor([[ 0.3367, 0.1288, 0.2345], [ 0.2303, -1.1229, -0.1863]]) >>> x = torch.stack((x, x)) # 与 torch.stack((x, x), dim=0) 相同 >>> x tensor([[[ 0.3367, 0.1288, 0.2345], [ 0.2303, -1.1229, -0.1863]], [[ 0.3367, 0.1288, 0.2345], [ 0.2303, -1.1229, -0.1863]]]) >>> x.size() torch.Size([2, 2, 3]) >>> x = torch.stack((x, x), dim=1) tensor([[[ 0.3367, 0.1288, 0.2345], [ 0.3367, 0.1288, 0.2345]], [[ 0.2303, -1.1229, -0.1863], [ 0.2303, -1.1229, -0.1863]]]) >>> x = torch.stack((x, x), dim=2) tensor([[[ 0.3367, 0.3367], [ 0.1288, 0.1288], [ 0.2345, 0.2345]], [[ 0.2303, 0.2303], [-1.1229, -1.1229], [-0.1863, -0.1863]]]) >>> x = torch.stack((x, x), dim=-1) tensor([[[ 0.3367, 0.3367], [ 0.1288, 0.1288], [ 0.2345, 0.2345]], [[ 0.2303, 0.2303], [-1.1229, -1.1229], [-0.1863, -0.1863]]])