torch.take_along_dim¶
- torch.take_along_dim(input, indices, dim=None, *, out=None) 张量¶
从
input中选择沿给定dim的 1 维索引处的值。如果
dim为 None,输入数组将被视为已展平为 1D。返回沿某个维度的索引的函数,如
torch.argmax()和torch.argsort(), 是为与此函数配合使用而设计的。请参见下面的示例。注意
此函数类似于 NumPy 的 take_along_axis。 另请参阅
torch.gather()。- Parameters
- Keyword Arguments
输出 (张量, 可选) – 输出张量。
示例:
>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]]) >>> max_idx = torch.argmax(t) >>> torch.take_along_dim(t, max_idx) tensor([60]) >>> sorted_idx = torch.argsort(t, dim=1) >>> torch.take_along_dim(t, sorted_idx, dim=1) tensor([[10, 20, 30], [40, 50, 60]])