Shortcuts

torch.kthvalue

torch.kthvalue(input, k, dim=None, keepdim=False, *, out=None)

返回一个命名元组 (values, indices),其中 values 是每行在给定维度 dim 中的第 k 个最小元素,indices 是找到的每个元素的索引位置。

如果未给出dim,则选择输入的最后一个维度。

如果 keepdimTrue,则 valuesindices 张量 与 input 的大小相同,除了在 dim 维度上 它们的大小为1。否则,dim 会被压缩 (参见 torch.squeeze()),导致 valuesindices 张量的维度比 input 张量少1。

注意

input 是 CUDA 张量并且存在多个有效的 k 值时,此函数可能会非确定性地返回 indices 中的任何一个。

Parameters
  • 输入 (张量) – 输入张量。

  • k (int) – 第k小的元素的k值

  • dim (int, 可选) – 沿着该维度查找第k个值

  • keepdim (布尔值) – 输出张量是否保留dim

Keyword Arguments

输出 (元组, 可选) – 输出元组 (Tensor, LongTensor) 可以可选地提供以用作输出缓冲区

示例:

>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1.,  2.,  3.,  4.,  5.])
>>> torch.kthvalue(x, 4)
torch.return_types.kthvalue(values=tensor(4.), indices=tensor(3))

>>> x= torch.arange(1., 7.).resize_(2, 3)
>>> x
tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.]])
>>> torch.kthvalue(x, 2, 0, True)
torch.return_types.kthvalue(values=tensor([[4., 5., 6.]]), indices=tensor([[1, 1, 1]]))
优云智算