dgl.ops.segment_mm
- dgl.ops.segment_mm(a, b, seglen_a)[source]
根据段执行矩阵乘法。
假设
seglen_a == [10, 5, 0, 3],操作符将执行四次矩阵乘法:a[0:10] @ b[0], a[10:15] @ b[1], a[15:15] @ b[2], a[15:18] @ b[3]
- Parameters:
a (张量) – 左操作数,形状为
(N, D1)的二维张量b (张量) – 右操作数,形状为
(R, D1, D2)的三维张量seglen_a (Tensor) – 一个形状为
(R,)的整数张量。每个元素是输入a的段长度。所有元素的总和必须等于N。
- Returns:
输出的密集矩阵形状为
(N, D2)- Return type:
张量