dgl.ops.gather_mm
- dgl.ops.gather_mm(a, b, *, idx_b)[source]
根据给定的索引收集数据并执行矩阵乘法。
让结果张量为
c
,操作符执行以下计算:c[i] = a[i] @ b[idx_b[i]] , 其中 len(c) == len(idx_b)
- Parameters:
a (张量) – 一个形状为
(N, D1)
的二维张量b (张量) – 一个形状为
(R, D1, D2)
的三维张量idx_b (Tensor, optional) – 一个形状为
(N,)
的一维整数张量。
- Returns:
The output dense matrix of shape
(N, D2)
- Return type:
张量