Shortcuts

torch.where

torch.where(condition, input, other, *, out=None) 张量

返回一个张量,其中的元素根据conditioninputother中选择。

该操作定义为:

outi={inputiif conditioniotheriotherwise\text{out}_i = \begin{cases} \text{input}_i & \text{if } \text{condition}_i \\ \text{other}_i & \text{otherwise} \\ \end{cases}

注意

张量 condition, input, other 必须是 可广播的

Parameters
  • 条件 (BoolTensor) – 当为真(非零)时,生成输入,否则生成其他值

  • 输入 (张量标量) – 值(如果 input 是标量)或选择在索引处的值 其中 conditionTrue

  • 其他 (张量标量) – 值(如果 other 是标量)或选择在索引处的值 其中 conditionFalse

Keyword Arguments

输出 (张量, 可选) – 输出张量。

Returns

形状等于 conditioninputother 广播后的形状的张量

Return type

张量

示例:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, 1.0, 0.0)
tensor([[0., 1.],
        [1., 0.],
        [1., 0.]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)
torch.where(condition) tuple of LongTensor

torch.where(condition)torch.nonzero(condition, as_tuple=True) 相同。

注意

另请参阅 torch.nonzero()

优云智算