torch.Tensor.masked_scatter_¶
- Tensor.masked_scatter_(mask, source)¶
从
source中复制元素到self张量中,位置在mask为 True 的地方。source中的元素从self的第 0 个位置开始,按顺序逐个复制,每次mask为 True 时复制一个元素。mask的形状必须与底层张量的形状 可广播。source应至少包含与mask中 1 的数量相同的元素。- Parameters
mask (BoolTensor) – 布尔掩码
源 (Tensor) – 要复制的张量
注意
The
mask作用于self张量,而不是给定的source张量。示例
>>> self = torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]) >>> mask = torch.tensor([[0, 0, 0, 1, 1], [1, 1, 0, 1, 1]]) >>> source = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]) >>> self.masked_scatter_(mask, source) tensor([[0, 0, 0, 0, 1], [2, 3, 0, 4, 5]])