dgl.ops.gather_mm

dgl.ops.gather_mm(a, b, *, idx_b)[source]

根据给定的索引收集数据并执行矩阵乘法。

让结果张量为 c,操作符执行以下计算:

c[i] = a[i] @ b[idx_b[i]] , 其中 len(c) == len(idx_b)

Parameters:
  • a (张量) – 一个形状为 (N, D1) 的二维张量

  • b (张量) – 一个形状为 (R, D1, D2) 的三维张量

  • idx_b (Tensor, optional) – 一个形状为 (N,) 的一维整数张量。

Returns:

The output dense matrix of shape (N, D2)

Return type:

张量