GroupRevRes

class dgl.nn.pytorch.conv.GroupRevRes(gnn_module, groups=2)[source]

Bases: Module

为GNNs引入的分组可逆残差连接,如使用1000层训练图神经网络中所述

它将输入节点特征 \(X\) 均匀地划分为 \(C\)\(X_1, X_2, \cdots, X_C\),跨通道维度。此外,它还会制作 \(C\) 个输入 GNN 模块 \(f_{w1}, \cdots, f_{wC}\) 的副本。在 前向传播过程中,每个 GNN 模块仅处理相应的节点特征组。

输出节点表示 \(X^{'}\) 的计算方式如下。

\[ \begin{align}\begin{aligned}X_0^{'} = \sum_{i=2}^{C}X_i\\X_i^{'} = f_{wi}(X_{i-1}^{'}, g, U) + X_i, i\in\{1,\cdots,C\}\\X^{'} = X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}\end{aligned}\end{align} \]

其中 \(g\) 是输入图,\(U\) 是任意附加的输入参数,如边特征,\(\, \Vert \,\) 是连接操作。

Parameters:
  • gnn_module (nn.Module) – 用于消息传递的GNN模块。GroupRevRes 将会克隆该模块 groups-1 次,总共生成 groups 个副本。 输入和输出的节点表示大小需要相同。它的前向 函数需要按顺序接收一个DGLGraph和相关的输入节点特征, 可选地跟随其他参数如边特征。

  • groups (int, 可选) – 组的数量。

示例

>>> import dgl
>>> import torch
>>> import torch.nn as nn
>>> from dgl.nn import GraphConv, GroupRevRes
>>> class GNNLayer(nn.Module):
...     def __init__(self, feats, dropout=0.2):
...         super(GNNLayer, self).__init__()
...         # Use BatchNorm and dropout to prevent gradient vanishing
...         # In particular if you use a large number of GNN layers
...         self.norm = nn.BatchNorm1d(feats)
...         self.conv = GraphConv(feats, feats)
...         self.dropout = nn.Dropout(dropout)
...
...     def forward(self, g, x):
...         x = self.norm(x)
...         x = self.dropout(x)
...         return self.conv(g, x)
>>> num_nodes = 5
>>> num_edges = 20
>>> feats = 32
>>> groups = 2
>>> g = dgl.rand_graph(num_nodes, num_edges)
>>> x = torch.randn(num_nodes, feats)
>>> conv = GNNLayer(feats // groups)
>>> model = GroupRevRes(conv, groups)
>>> out = model(g, x)
forward(g, x, *args)[source]

应用带有分组可逆残差连接的GNN模块。

Parameters:
  • g (DGLGraph) – The graph.

  • x (torch.Tensor) – 输入特征的形状为 \((N, D_{in})\),其中 \(D_{in}\) 是输入特征的大小,\(N\) 是节点的数量。

  • args – 传递给 gnn_module 的额外参数。

Returns:

输出特征的形状为 \((N, D_{in})\)

Return type:

torch.Tensor