torch.gather¶
- torch.gather(input, dim, index, *, sparse_grad=False, out=None) 张量¶
沿由dim指定的轴收集值。
对于一个3-D张量,输出由以下内容指定:
out[i][j][k] = input[index[i][j][k]][j][k] # 如果 dim == 0 out[i][j][k] = input[i][index[i][j][k]][k] # 如果 dim == 1 out[i][j][k] = input[i][j][index[i][j][k]] # 如果 dim == 2
input和index必须具有相同的维度数量。 还要求对于所有维度d != dim,index.size(d) <= input.size(d)。out将具有与index相同的形状。 请注意,input和index不会相互广播。- Parameters
- Keyword Arguments
示例:
>>> t = torch.tensor([[1, 2], [3, 4]]) >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]])) tensor([[ 1, 1], [ 4, 3]])