Shortcuts

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) 张量

沿由dim指定的轴收集值。

对于一个3-D张量,输出由以下内容指定:

out[i][j][k] = input[index[i][j][k]][j][k]  # 如果 dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # 如果 dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # 如果 dim == 2

inputindex 必须具有相同的维度数量。 还要求对于所有维度 d != dimindex.size(d) <= input.size(d)out 将具有与 index 相同的形状。 请注意,inputindex 不会相互广播。

Parameters
  • 输入 (张量) – 源张量

  • dim (int) – 要沿其索引的轴

  • 索引 (LongTensor) – 要收集的元素的索引

Keyword Arguments
  • sparse_grad (布尔值, 可选) – 如果True,相对于input的梯度将是一个稀疏张量。

  • 输出 (张量, 可选) – 目标张量

示例:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])
优云智算