torch.Tensor.scatter_add_¶
- Tensor.scatter_add_(dim, index, src) 张量¶
将张量
src中的所有值添加到self中,索引由index张量指定,类似于scatter_()。对于src中的每个值,它被添加到self中的一个索引,该索引由src中的索引指定 对于dimension != dim,并且由index中的相应值指定 对于dimension = dim。对于一个3维张量,
self被更新为:self[index[i][j][k]][j][k] += src[i][j][k] # 如果 dim == 0 self[i][index[i][j][k]][k] += src[i][j][k] # 如果 dim == 1 self[i][j][index[i][j][k]] += src[i][j][k] # 如果 dim == 2
self,index和src应该具有相同的维度数量。还要求对于所有维度d,index.size(d) <= src.size(d),并且对于所有维度d != dim,index.size(d) <= self.size(d)。请注意,index和src不会进行广播。注意
当在CUDA设备上使用张量时,此操作可能会表现出不确定性行为。更多信息请参见可重复性。
注意
反向传播仅在
src.shape == index.shape时实现。- Parameters
示例:
>>> src = torch.ones((2, 5)) >>> index = torch.tensor([[0, 1, 2, 0, 0]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) tensor([[1., 0., 0., 1., 1.], [0., 1., 0., 0., 0.], [0., 0., 1., 0., 0.]]) >>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src) tensor([[2., 0., 0., 1., 1.], [0., 2., 0., 0., 0.], [0., 0., 2., 1., 1.]])