dgl.DGLGraph.pull
- DGLGraph.pull(v, message_func, reduce_func, apply_node_func=None, etype=None)[source]
从指定节点的前驱节点沿指定边类型拉取消息,聚合它们以更新节点特征。
- Parameters:
v (节点ID) –
节点ID。允许的格式有:
int
: 单个节点。Int Tensor: 每个元素都是一个节点ID。张量必须具有与图相同的设备类型和ID数据类型。
iterable[int]: 每个元素都是一个节点ID。
message_func (dgl.function.BuiltinFunction or callable) – The message function to generate messages along the edges. It must be either a DGL Built-in Function or a User-defined Functions.
reduce_func (dgl.function.BuiltinFunction or callable) – The reduce function to aggregate the messages. It must be either a DGL Built-in Function or a User-defined Functions.
apply_node_func (callable, optional) – An optional apply function to further update the node features after the message reduction. It must be a User-defined Functions.
etype (str or (str, str, str), optional) –
The type name of the edges. The allowed type name formats are:
(str, str, str)
for source node type, edge type and destination node type.or one
str
edge type name if the name can uniquely identify a triplet format in the graph.
Can be omitted if the graph has only one type of edges.
注释
如果给定的某些节点
v
没有入边,DGL 不会为这些节点调用消息和减少函数,并用零填充它们的聚合消息。用户可以通过set_n_initializer()
控制填充的值。如果提供了apply_node_func
,DGL 仍然会调用它。DGL recommends using DGL’s bulit-in function for the
message_func
and thereduce_func
arguments, because DGL will invoke efficient kernels that avoids copying node features to edge features in this case.
示例
>>> import dgl >>> import dgl.function as fn >>> import torch
同构图
>>> g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4])) >>> g.ndata['x'] = torch.ones(5, 2) >>> g.pull([0, 3, 4], fn.copy_u('x', 'm'), fn.sum('m', 'h')) >>> g.ndata['h'] tensor([[0., 0.], [0., 0.], [0., 0.], [1., 1.], [1., 1.]])
异构图
>>> g = dgl.heterograph({ ... ('user', 'follows', 'user'): ([0, 1], [1, 2]), ... ('user', 'plays', 'game'): ([0, 2], [0, 1]) ... }) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [2.]])
拉。
>>> g['follows'].pull(2, fn.copy_u('h', 'm'), fn.sum('m', 'h'), etype='follows') >>> g.nodes['user'].data['h'] tensor([[0.], [1.], [1.]])