GroupRevRes
- class dgl.nn.pytorch.conv.GroupRevRes(gnn_module, groups=2)[source]
Bases:
Module
为GNNs引入的分组可逆残差连接,如使用1000层训练图神经网络中所述
它将输入节点特征 \(X\) 均匀地划分为 \(C\) 组 \(X_1, X_2, \cdots, X_C\),跨通道维度。此外,它还会制作 \(C\) 个输入 GNN 模块 \(f_{w1}, \cdots, f_{wC}\) 的副本。在 前向传播过程中,每个 GNN 模块仅处理相应的节点特征组。
输出节点表示 \(X^{'}\) 的计算方式如下。
\[ \begin{align}\begin{aligned}X_0^{'} = \sum_{i=2}^{C}X_i\\X_i^{'} = f_{wi}(X_{i-1}^{'}, g, U) + X_i, i\in\{1,\cdots,C\}\\X^{'} = X_1^{'} \, \Vert \, \ldots \, \Vert \, X_C^{'}\end{aligned}\end{align} \]其中 \(g\) 是输入图,\(U\) 是任意附加的输入参数,如边特征,\(\, \Vert \,\) 是连接操作。
- Parameters:
gnn_module (nn.Module) – 用于消息传递的GNN模块。
GroupRevRes
将会克隆该模块groups
-1 次,总共生成groups
个副本。 输入和输出的节点表示大小需要相同。它的前向 函数需要按顺序接收一个DGLGraph和相关的输入节点特征, 可选地跟随其他参数如边特征。groups (int, 可选) – 组的数量。
示例
>>> import dgl >>> import torch >>> import torch.nn as nn >>> from dgl.nn import GraphConv, GroupRevRes
>>> class GNNLayer(nn.Module): ... def __init__(self, feats, dropout=0.2): ... super(GNNLayer, self).__init__() ... # Use BatchNorm and dropout to prevent gradient vanishing ... # In particular if you use a large number of GNN layers ... self.norm = nn.BatchNorm1d(feats) ... self.conv = GraphConv(feats, feats) ... self.dropout = nn.Dropout(dropout) ... ... def forward(self, g, x): ... x = self.norm(x) ... x = self.dropout(x) ... return self.conv(g, x)
>>> num_nodes = 5 >>> num_edges = 20 >>> feats = 32 >>> groups = 2 >>> g = dgl.rand_graph(num_nodes, num_edges) >>> x = torch.randn(num_nodes, feats) >>> conv = GNNLayer(feats // groups) >>> model = GroupRevRes(conv, groups) >>> out = model(g, x)