torch.diagonal_scatter¶
- torch.diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) 张量¶
将
src张量的值嵌入到input中,沿着input的对角元素,相对于dim1和dim2。此函数返回一个具有新存储的张量;它不会返回一个视图。
参数
offset控制要考虑的对角线:如果
offset= 0,则是主对角线。如果
offset大于 0,则位于主对角线上方。如果
offset小于 0,则位于主对角线下方。
- Parameters
注意
src必须具有适当的大小才能嵌入到input中。具体来说,它应该与torch.diagonal(input, offset, dim1, dim2)具有相同的形状。示例:
>>> a = torch.zeros(3, 3) >>> a tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]) >>> torch.diagonal_scatter(a, torch.ones(3), 0) tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) >>> torch.diagonal_scatter(a, torch.ones(2), 1) tensor([[0., 1., 0.], [0., 0., 1.], [0., 0., 0.]])