Shortcuts

torch.Tensor.masked_scatter_

Tensor.masked_scatter_(mask, source)

source 中复制元素到 self 张量中,位置在 mask 为 True 的地方。source 中的元素从 self 的第 0 个位置开始,按顺序逐个复制,每次 mask 为 True 时复制一个元素。 mask 的形状必须与底层张量的形状 可广播source 应至少包含与 mask 中 1 的数量相同的元素。

Parameters
  • mask (BoolTensor) – 布尔掩码

  • (Tensor) – 要复制的张量

注意

The mask 作用于 self 张量,而不是给定的 source 张量。

示例

>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
>>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]])
>>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])
>>> self.masked_scatter_(mask, source)
tensor([[0, 0, 0, 0, 1],
        [2, 3, 0, 4, 5]])
优云智算