dgl.ops.segment_mm
- dgl.ops.segment_mm(a, b, seglen_a)[source]
根据段执行矩阵乘法。
假设
seglen_a == [10, 5, 0, 3]
,操作符将执行四次矩阵乘法:a[0:10] @ b[0], a[10:15] @ b[1], a[15:15] @ b[2], a[15:18] @ b[3]
- Parameters:
a (张量) – 左操作数,形状为
(N, D1)
的二维张量b (张量) – 右操作数,形状为
(R, D1, D2)
的三维张量seglen_a (Tensor) – 一个形状为
(R,)
的整数张量。每个元素是输入a
的段长度。所有元素的总和必须等于N
。
- Returns:
输出的密集矩阵形状为
(N, D2)
- Return type:
张量