dgl.to_block
- dgl.to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None)[source]
将图转换为用于消息传递的二分结构块。
一个块是由两组节点组成的图:源节点和目标节点。源节点和目标节点可以有多种节点类型。所有的边都从源节点连接到目标节点。
具体来说,源节点和目标节点将具有与原始图中相同的节点类型。DGL将原始图中具有边类型
(utype, etype, vtype)
的每条边(u, v)
映射为连接源侧类型为utype
的节点IDu
到目标侧类型为vtype
的节点IDv
的边类型etype
。对于由
to_block()
返回的块,块的目标节点将仅包含至少有一条任何类型的入边的节点。块的源节点将仅包含出现在目标节点中的节点,以及至少有一条出边连接到其中一个目标节点的节点。目标节点由
dst_nodes
参数指定,如果该参数不为None。- Parameters:
graph (DGLGraph) – 图。可以在CPU或GPU上。
dst_nodes (Tensor 或 dict[str, Tensor], optional) –
目标节点列表。
如果给定的是一个张量,图必须只有一种节点类型。
如果给定,它必须是所有至少有一个入边的节点的超集。否则将引发错误。
include_dst_in_src (bool) –
如果为False,则不将目标节点包含在源节点中。
(默认值: True)
src_nodes (Tensor 或 disct[str, Tensor], optional) –
源节点列表(如果include_dst_in_src为True,则前缀为目标节点)。
如果给定一个张量,图必须只有一种节点类型。
- Returns:
描述块的新图。
在两侧为每种类型诱导的节点ID将存储在特征
dgl.NID
中。为每种类型诱导的边ID将存储在特征
dgl.EID
中。- Return type:
DGLBlock
- Raises:
DGLError – 如果指定了
dst_nodes
,但它不是所有至少有一条入边的节点的超集。 如果dst_nodes
不是None,并且g
和dst_nodes
不在同一个上下文中。
注释
to_block()
最常用于在大图上进行随机训练的邻居采样定制。请参阅用户指南中的 第6章:大图上的随机训练 以获取关于随机训练方法的更详细讨论。另请参阅
create_block()
以获取更灵活的块构建方法。示例
将同质图转换为块,如上所述:
>>> g = dgl.graph(([1, 2], [2, 3])) >>> block = dgl.to_block(g, torch.LongTensor([3, 2]))
目标节点将与给定的节点完全相同:[3, 2]。
>>> induced_dst = block.dstdata[dgl.NID] >>> induced_dst tensor([3, 2])
最初的几个源节点也将与给定的节点完全相同。其余的节点是消息传递到节点3和2所必需的。这意味着节点1将被包括在内。
>>> induced_src = block.srcdata[dgl.NID] >>> induced_src tensor([3, 2, 1])
你可以注意到前两个节点与给定的节点以及目标节点相同。
诱导边也可以通过以下方式获得:
>>> block.edata[dgl.EID] tensor([2, 1])
这表明边 (2, 3) 和 (1, 2) 包含在结果图中。你可以验证块中的第一条边确实映射到边 (2, 3),块中的第二条边确实映射到边 (1, 2):
>>> src, dst = block.edges(order='eid') >>> induced_src[src], induced_dst[dst] (tensor([2, 1]), tensor([3, 2]))
指定的目标节点必须是具有连接到它们的边的节点的超集。例如,以下内容将引发错误,因为目标节点不包含节点3,而节点3具有连接到它的边。
>>> g = dgl.graph(([1, 2], [2, 3])) >>> dgl.to_block(g, torch.LongTensor([2])) # error
将异构图转换为块的过程类似,只是在指定目标节点时,你需要提供一个字典:
>>> g = dgl.heterograph({('A', '_E', 'B'): ([1, 2], [2, 3])})
如果您在目标端没有指定任何类型为A的节点,块中的节点类型
A
在目标端将没有节点。>>> block = dgl.to_block(g, {'B': torch.LongTensor([3, 2])}) >>> block.number_of_dst_nodes('A') 0 >>> block.number_of_dst_nodes('B') 2 >>> block.dstnodes['B'].data[dgl.NID] tensor([3, 2])
源端将包含目标端的所有节点:
>>> block.srcnodes['B'].data[dgl.NID] tensor([3, 2])
以及所有与目标侧节点有连接的节点:
>>> block.srcnodes['A'].data[dgl.NID] tensor([2, 1])
另请参阅