6.6 实现自定义GNN模块以进行小批量训练
注意
本教程 的内容与本节的同构图案例相似。
如果您熟悉如何为同质或异质图编写自定义GNN模块以更新整个图(参见第3章:构建GNN模块),那么在MFGs上计算的代码是类似的,不同之处在于节点被分为输入节点和输出节点。
例如,考虑以下自定义图卷积模块代码。请注意,它不一定是最有效的实现之一 - 它们仅作为自定义GNN模块可能是什么样子的示例。
class CustomGraphConv(nn.Module):
def __init__(self, in_feats, out_feats):
super().__init__()
self.W = nn.Linear(in_feats * 2, out_feats)
def forward(self, g, h):
with g.local_scope():
g.ndata['h'] = h
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))
如果你有一个用于全图的自定义消息传递神经网络模块,并且你想让它适用于MFGs,你只需要重写前向函数如下。请注意,来自全图实现的相应语句已被注释;你可以将原始语句与新语句进行比较。
class CustomGraphConv(nn.Module):
def __init__(self, in_feats, out_feats):
super().__init__()
self.W = nn.Linear(in_feats * 2, out_feats)
# h is now a pair of feature tensors for input and output nodes, instead of
# a single feature tensor.
# def forward(self, g, h):
def forward(self, block, h):
# with g.local_scope():
with block.local_scope():
# g.ndata['h'] = h
h_src = h
h_dst = h[:block.number_of_dst_nodes()]
block.srcdata['h'] = h_src
block.dstdata['h'] = h_dst
# g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
block.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
# return self.W(torch.cat([g.ndata['h'], g.ndata['h_neigh']], 1))
return self.W(torch.cat(
[block.dstdata['h'], block.dstdata['h_neigh']], 1))
通常,你需要执行以下操作以使你的NN模块适用于MFGs。
通过切片前几行从输入特征中获取输出节点的特征。行数可以通过
block.number_of_dst_nodes
获取。将
g.ndata
替换为block.srcdata
用于输入节点的特征,或block.dstdata
用于输出节点的特征,如果 原始图只有一个节点类型。如果原始图具有多种节点类型,请将
g.nodes
替换为block.srcnodes
以获取输入节点的特征,或block.dstnodes
以获取输出节点的特征。将
g.num_nodes
替换为block.number_of_src_nodes
或block.number_of_dst_nodes
以分别获取输入节点或输出节点的数量。
Heterogeneous graphs
对于异构图,编写自定义GNN模块的方式是相似的。例如,考虑以下适用于整个图的模块。
class CustomHeteroGraphConv(nn.Module):
def __init__(self, g, in_feats, out_feats):
super().__init__()
self.Ws = nn.ModuleDict()
for etype in g.canonical_etypes:
utype, _, vtype = etype
self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])
for ntype in g.ntypes:
self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])
def forward(self, g, h):
with g.local_scope():
for ntype in g.ntypes:
g.nodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
g.nodes[ntype].data['h_src'] = h[ntype]
for etype in g.canonical_etypes:
utype, _, vtype = etype
g.update_all(
fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
etype=etype)
g.nodes[vtype].data['h_dst'] = g.nodes[vtype].data['h_dst'] + \
self.Ws[etype](g.nodes[vtype].data['h_neigh'])
return {ntype: g.nodes[ntype].data['h_dst'] for ntype in g.ntypes}
对于CustomHeteroGraphConv
,原则是根据特征是用于输入还是输出,将g.nodes
替换为g.srcnodes
或g.dstnodes
。
class CustomHeteroGraphConv(nn.Module):
def __init__(self, g, in_feats, out_feats):
super().__init__()
self.Ws = nn.ModuleDict()
for etype in g.canonical_etypes:
utype, _, vtype = etype
self.Ws[etype] = nn.Linear(in_feats[utype], out_feats[vtype])
for ntype in g.ntypes:
self.Vs[ntype] = nn.Linear(in_feats[ntype], out_feats[ntype])
def forward(self, g, h):
with g.local_scope():
for ntype in g.ntypes:
h_src, h_dst = h[ntype]
g.dstnodes[ntype].data['h_dst'] = self.Vs[ntype](h[ntype])
g.srcnodes[ntype].data['h_src'] = h[ntype]
for etype in g.canonical_etypes:
utype, _, vtype = etype
g.update_all(
fn.copy_u('h_src', 'm'), fn.mean('m', 'h_neigh'),
etype=etype)
g.dstnodes[vtype].data['h_dst'] = \
g.dstnodes[vtype].data['h_dst'] + \
self.Ws[etype](g.dstnodes[vtype].data['h_neigh'])
return {ntype: g.dstnodes[ntype].data['h_dst']
for ntype in g.ntypes}
编写适用于同构图、二分图和MFG的模块
DGL中的所有消息传递模块都适用于同构图、单向二分图(具有两种节点类型和一种边类型)以及具有一种边类型的MFG。本质上,内置DGL神经网络模块的输入图和特征必须满足以下任一情况。
如果输入特征是一对张量,那么输入图必须是单向二分图。
如果输入特征是一个单一的张量,并且输入图是一个MFG,DGL将自动将输出节点上的特征设置为输入节点特征的前几行。
如果输入特征必须是一个单一的张量且输入图不是MFG,那么输入图必须是同质的。
例如,以下是从PyTorch实现的dgl.nn.pytorch.SAGEConv
(也可在MXNet和Tensorflow中使用)中简化的内容(去除了归一化,仅处理均值聚合等)。
import dgl.function as fn
class SAGEConv(nn.Module):
def __init__(self, in_feats, out_feats):
super().__init__()
self.W = nn.Linear(in_feats * 2, out_feats)
def forward(self, g, h):
if isinstance(h, tuple):
h_src, h_dst = h
elif g.is_block:
h_src = h
h_dst = h[:g.number_of_dst_nodes()]
else:
h_src = h_dst = h
g.srcdata['h'] = h_src
g.dstdata['h'] = h_dst
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_neigh'))
return F.relu(
self.W(torch.cat([g.dstdata['h'], g.dstdata['h_neigh']], 1)))
第3章:构建GNN模块 还提供了关于 dgl.nn.pytorch.SAGEConv
的详细说明,
它适用于单向二分图、同构图和MFGs。