Shortcuts

torch.take

torch.take(input, index) 张量

返回一个新张量,其中包含给定索引处的 input 元素。 输入张量被视为被视为一维张量。结果与索引具有相同的形状。

Parameters
  • 输入 (张量) – 输入张量。

  • 索引 (LongTensor) – 张量的索引

示例:

>>> src = torch.tensor([[4, 3, 5],
...                     [6, 7, 8]])
>>> torch.take(src, torch.tensor([0, 2, 5]))
tensor([ 4,  5,  8])