dgl.sparse.sddmm
- dgl.sparse.sddmm(A: SparseMatrix, X1: Tensor, X2: Tensor) SparseMatrix [source]
采样-密集-密集矩阵乘法 (SDDMM)。
sddmm
将两个密集矩阵X1
和X2
进行矩阵乘法, 然后在稀疏矩阵A
的非零位置处对结果进行逐元素乘法。数学上
sddmm
的公式为:\[out = (X1 @ X2) * A\]特别是,
X1
和X2
可以是一维的,那么X1 @ X2
就变成了两个向量的外积(结果是一个矩阵)。- Parameters:
A (SparseMatrix) – 形状为
(L, N)
的稀疏矩阵X1 (torch.Tensor) – 形状为
(L, M)
或(L,)
的密集矩阵X2 (torch.Tensor) – 形状为
(M, N)
或(N,)
的密集矩阵
- Returns:
形状为
(L, N)
的稀疏矩阵- Return type:
示例
>>> indices = torch.tensor([[1, 1, 2], [2, 3, 3]]) >>> val = torch.arange(1, 4).float() >>> A = dglsp.spmatrix(indices, val, (3, 4)) >>> X1 = torch.randn(3, 5) >>> X2 = torch.randn(5, 4) >>> dglsp.sddmm(A, X1, X2) SparseMatrix(indices=tensor([[1, 1, 2], [2, 3, 3]]), values=tensor([-1.6585, -3.9714, -0.5406]), shape=(3, 4), nnz=3)