权重基础

class dgl.nn.pytorch.utils.WeightBasis(shape, num_bases, num_outputs)[source]

Bases: Module

基于使用图卷积网络建模关系数据的基础分解

可以描述如下:

\[W_o = \sum_{b=1}^B a_{ob} V_b\]

每个权重输出 \(W_o\) 本质上是基变换 \(V_b\) 与系数 \(a_{ob}\) 的线性组合。

如果作为一种正则化形式在大参数矩阵上有用。因此,权重输出的数量通常大于基的数量。

Parameters:
  • shape (tuple[int]) – 基础参数的形状。

  • num_bases (int) – 基础数量。

  • num_outputs (int) – 输出的数量。

forward()[source]

前向计算

Returns:

weight – 由形状为 (num_outputs,) + shape 组成的权重张量

Return type:

torch.Tensor