torch.sparse.mm¶
- torch.sparse.mm()¶
对稀疏矩阵
mat1
和(稀疏或密集)矩阵mat2
进行矩阵乘法。类似于torch.mm()
,如果mat1
是一个 张量,mat2
是一个 张量,输出将是一个 张量。 当mat1
是 COO 张量时,它必须有 sparse_dim = 2。 当输入是 COO 张量时,此函数还支持两个输入的反向传播。支持CSR和COO存储格式。
注意
此函数不支持对CSR矩阵进行导数计算。
此函数还额外接受一个可选的
reduce
参数,该参数允许指定一个可选的归约操作,在数学上执行以下操作:其中 定义了归约运算符。
reduce
仅在 CPU 设备的 CSR 存储格式上实现。- Parameters
- Shape:
该函数的输出张量格式如下: - 稀疏 x 稀疏 -> 稀疏 - 稀疏 x 密集 -> 密集
示例:
>>> a = torch.tensor([[1., 0, 2], [0, 3, 0]]).to_sparse().requires_grad_() >>> a tensor(indices=tensor([[0, 0, 1], [0, 2, 1]]), values=tensor([1., 2., 3.]), size=(2, 3), nnz=3, layout=torch.sparse_coo, requires_grad=True) >>> b = torch.tensor([[0, 1.], [2, 0], [0, 0]], requires_grad=True) >>> b tensor([[0., 1.], [2., 0.], [0., 0.]], requires_grad=True) >>> y = torch.sparse.mm(a, b) >>> y tensor([[0., 1.], [6., 0.]], grad_fn=
) >>> y.sum().backward() >>> a.grad tensor(indices=tensor([[0, 0, 1], [0, 2, 1]]), values=tensor([1., 0., 2.]), size=(2, 3), nnz=3, layout=torch.sparse_coo) >>> c = a.detach().to_sparse_csr() >>> c tensor(crow_indices=tensor([0, 2, 3]), col_indices=tensor([0, 2, 1]), values=tensor([1., 2., 3.]), size=(2, 3), nnz=3, layout=torch.sparse_csr) >>> y1 = torch.sparse.mm(c, b, 'sum') >>> y1 tensor([[0., 1.], [6., 0.]], grad_fn=) >>> y2 = torch.sparse.mm(c, b, 'max') >>> y2 tensor([[0., 1.], [6., 0.]], grad_fn=)