异构图卷积
- class dgl.nn.pytorch.HeteroGraphConv(mods, aggregate='sum')[source]
Bases:
Module
一个用于在异质图上计算卷积的通用模块。
异构图卷积在其关联的关系图上应用子模块,该模块从源节点读取特征并将更新后的特征写入目标节点。如果多个关系具有相同的目标节点类型,则它们的结果将通过指定的方法进行聚合。如果关系图没有边,则不会调用相应的模块。
伪代码:
outputs = {nty : [] for nty in g.dsttypes} # Apply sub-modules on their associating relation graphs in parallel for relation in g.canonical_etypes: stype, etype, dtype = relation dstdata = relation_submodule(g[relation], ...) outputs[dtype].append(dstdata) # Aggregate the results for each destination node type rsts = {} for ntype, ntype_outputs in outputs.items(): if len(ntype_outputs) != 0: rsts[ntype] = aggregate(ntype_outputs) return rsts
示例
创建一个具有三种关系和节点类型的异构图。
>>> import dgl >>> g = dgl.heterograph({ ... ('user', 'follows', 'user') : edges1, ... ('user', 'plays', 'game') : edges2, ... ('store', 'sells', 'game') : edges3})
创建一个
HeteroGraphConv
,它将不同的卷积模块应用于不同的关系。请注意,'follows'
和'plays'
的模块不共享权重。>>> import dgl.nn.pytorch as dglnn >>> conv = dglnn.HeteroGraphConv({ ... 'follows' : dglnn.GraphConv(...), ... 'plays' : dglnn.GraphConv(...), ... 'sells' : dglnn.SAGEConv(...)}, ... aggregate='sum')
使用一些
'user'
特征进行前向调用。这将为'user'
和'game'
节点计算新的特征。>>> import torch as th >>> h1 = {'user' : th.randn((g.num_nodes('user'), 5))} >>> h2 = conv(g, h1) >>> print(h2.keys()) dict_keys(['user', 'game'])
使用
'user'
和'store'
特征进行调用转发。因为'plays'
和'sells'
关系都会更新'game'
特征,它们的结果通过指定的方法(即这里的求和)进行聚合。>>> f1 = {'user' : ..., 'store' : ...} >>> f2 = conv(g, f1) >>> print(f2.keys()) dict_keys(['user', 'game'])
调用转发带有一些
'store'
功能。这仅计算'game'
节点的新功能。>>> g1 = {'store' : ...} >>> g2 = conv(g, g1) >>> print(g2.keys()) dict_keys(['game'])
允许使用一对输入进行调用转发,并且每个子模块也将使用一对输入进行调用。
>>> x_src = {'user' : ..., 'store' : ...} >>> x_dst = {'user' : ..., 'game' : ...} >>> y_dst = conv(g, (x_src, x_dst)) >>> print(y_dst.keys()) dict_keys(['user', 'game'])
- Parameters:
mods (dict[str, nn.Module]) – 与每种边类型相关联的模块。每个模块的前向函数必须将DGLGraph对象作为第一个参数,第二个参数是表示节点特征的张量对象或表示源节点和目标节点特征的张量对象对。
aggregate (str, callable, optional) –
用于聚合由不同关系生成的节点特征的方法。 允许的字符串值为‘sum’、‘max’、‘min’、‘mean’、‘stack’。 ‘stack’聚合是沿着第二个维度执行的,其顺序是确定的。 用户还可以通过提供一个可调用的实例来自定义聚合器。 例如,通过求和进行聚合等同于以下内容:
def my_agg_func(tensors, dsttype): # tensors: 是要聚合的张量列表 # dsttype: 执行聚合的目标节点类型的字符串名称 stacked = torch.stack(tensors, dim=0) return torch.sum(stacked, dim=0)