torch.argwhere¶
- torch.argwhere(input) 张量¶
返回一个包含
input中所有非零元素索引的张量。结果中的每一行包含input中一个非零元素的索引。结果按字典顺序排序,最后一个索引变化最快(C 风格)。如果
input有 维,那么生成的索引张量out的大小为 ,其中 是input张量中非零元素的总数。注意
此函数类似于 NumPy 的 argwhere。
当
input在 CUDA 上时,此函数会导致主机-设备同步。- Parameters
{input} –
示例:
>>> t = torch.tensor([1, 0, 1]) >>> torch.argwhere(t) tensor([[0], [2]]) >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) >>> torch.argwhere(t) tensor([[0, 0], [0, 2], [1, 1], [1, 2]])