Shortcuts

torch.nn.functional.one_hot

torch.nn.functional.one_hot(tensor, num_classes=-1) LongTensor

接受形状为 (*) 的 LongTensor 索引值,并返回一个形状为 (*, num_classes) 的张量,该张量在除最后一个维度的索引与输入张量的相应值匹配的位置外,其他位置均为零,在这些匹配的位置上值为 1。

参见 维基百科上的One-hot

Parameters
  • 张量 (LongTensor) – 任意形状的类别值。

  • num_classes (int) – 类别总数。如果设置为 -1,类别数量将被推断为输入张量中最大类别值加一。

Returns

具有一个额外维度的LongTensor,在输入指示的最后一个维度的索引处具有1值,其他地方为0。

示例

>>> F.one_hot(torch.arange(0, 5) % 3)
tensor([[1, 0, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 1, 0]])
>>> F.one_hot(torch.arange(0, 5) % 3, num_classes=5)
tensor([[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0]])
>>> F.one_hot(torch.arange(0, 6).view(3,2) % 3)
tensor([[[1, 0, 0],
         [0, 1, 0]],
        [[0, 0, 1],
         [1, 0, 0]],
        [[0, 1, 0],
         [0, 0, 1]]])
优云智算