TypedLinear

class dgl.nn.pytorch.TypedLinear(in_size, out_size, num_types, regularizer=None, num_bases=None)[source]

Bases: Module

根据类型进行线性变换。

对于输入批次中的每个样本 \(x \in X\),应用线性变换 \(xW_t\),其中 \(t\)\(x\) 的类型。

该模块支持两种正则化方法(基础分解和块对角分解),这些方法由“使用图卷积网络建模关系数据”提出。

基础正则化通过以下方式分解 \(W_t\)

\[W_t^{(l)} = \sum_{b=1}^B a_{tb}^{(l)}V_b^{(l)}\]

其中 \(B\) 是基数的数量,\(V_b^{(l)}\) 与系数 \(a_{tb}^{(l)}\) 线性组合。

块对角分解正则化将 \(W_t\) 分解为 \(B\) 个块对角矩阵。我们将 \(B\) 称为基的数量:

\[W_t^{(l)} = \oplus_{b=1}^B Q_{tb}^{(l)}\]

其中 \(B\) 是基数的数量,\(Q_{tb}^{(l)}\) 是具有形状 \(R^{(d^{(l+1)}/B)\times(d^{l}/B)}\) 的块基数。

Parameters:
  • in_size (int) – 输入特征大小。

  • out_size (int) – Output feature size.

  • num_types (int) – 类型的总数。

  • regularizer (str, optional) –

    使用哪种权重正则化器“basis”或“bdd”:

    • “basis”是basis-decomposition的缩写。

    • “bdd”是block-diagonal-decomposition的缩写。

    默认情况下不应用正则化。

  • num_bases (int, 可选) – 基础数量。当指定regularizer时需要。通常小于num_types。 默认值: None

示例

无正则化。

>>> from dgl.nn import TypedLinear
>>> import torch
>>>
>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])

基于正则化

>>> x = torch.randn(100, 32)
>>> x_type = torch.randint(0, 5, (100,))
>>> m = TypedLinear(32, 64, 5, regularizer='basis', num_bases=4)
>>> y = m(x, x_type)
>>> print(y.shape)
torch.Size([100, 64])
forward(x, x_type, sorted_by_type=False)[source]

前向计算。

Parameters:
  • x (torch.Tensor) – 一个2D输入张量。形状:(N, D1)

  • x_type (torch.Tensor) – 一个一维整数张量,存储x中元素的类型,具有一一对应关系。形状:(N,)

  • sorted_by_type (bool, optional) – 输入是否已按类型排序。在预排序的输入上向前传递可能会更快。

Returns:

y – 转换后的输出张量。形状:(N, D2)

Return type:

torch.Tensor

reset_parameters()[source]

重置参数