dgl.unbatch
- dgl.unbatch(g, node_split=None, edge_split=None)[source]
通过将给定的图拆分为一系列小图来撤销批处理操作。
这是 :func:
dgl.batch
的反向操作。如果未提供node_split
或edge_split
,它将调用DGLGraph.batch_num_nodes()
和DGLGraph.batch_num_edges()
来获取输入图的信息。如果提供了
node_split
或edge_split
参数, 它将根据给定的段对图进行分区。必须确保分区是有效的——第i个图的边仅连接属于第i个图的节点。 否则,DGL将抛出错误。该函数支持异构图输入,在这种情况下,两个分割部分的参数应为字典类型——类似于异构图中的
DGLGraph.batch_num_nodes()
和DGLGraph.batch_num_edges()
属性。- Parameters:
- Returns:
未批处理的图列表。
- Return type:
示例
解批处理一个批处理图
>>> import dgl >>> import torch as th >>> # 4 nodes, 3 edges >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3]))) >>> # 3 nodes, 4 edges >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0]))) >>> # add features >>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3) >>> g1.edata['w'] = th.ones(g1.num_edges(), 2) >>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3) >>> g2.edata['w'] = th.zeros(g2.num_edges(), 2) >>> bg = dgl.batch([g1, g2]) >>> f1, f2 = dgl.unbatch(bg) >>> f1 Graph(num_nodes=4, num_edges=3, ndata_schemes={‘x’ : Scheme(shape=(3,), dtype=torch.float32)} edata_schemes={‘w’ : Scheme(shape=(2,), dtype=torch.float32)}) >>> f2 Graph(num_nodes=3, num_edges=4, ndata_schemes={‘x’ : Scheme(shape=(3,), dtype=torch.float32)} edata_schemes={‘w’ : Scheme(shape=(2,), dtype=torch.float32)})
使用提供的分割参数:
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3]))) >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0]))) >>> g3 = dgl.graph((th.tensor([0]), th.tensor([1]))) >>> bg = dgl.batch([g1, g2, g3]) >>> bg.batch_num_nodes() tensor([4, 3, 2]) >>> bg.batch_num_edges() tensor([3, 4, 1]) >>> # unbatch but merge g2 and g3 >>> f1, f2 = dgl.unbatch(bg, th.tensor([4, 5]), th.tensor([3, 5])) >>> f1 Graph(num_nodes=4, num_edges=3, ndata_schemes={} edata_schemes={}) >>> f2 Graph(num_nodes=5, num_edges=5, ndata_schemes={} edata_schemes={})
异构图输入
>>> hg1 = dgl.heterograph({ ... ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))}) >>> hg2 = dgl.heterograph({ ... ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))}) >>> bhg = dgl.batch([hg1, hg2]) >>> f1, f2 = dgl.unbatch(bhg) >>> f1 Graph(num_nodes={'user': 2, 'game': 1}, num_edges={('user', 'plays', 'game'): 2}, metagraph=[('drug', 'game')]) >>> f2 Graph(num_nodes={'user': 1, 'game': 3}, num_edges={('user', 'plays', 'game'): 3}, metagraph=[('drug', 'game')])
另请参阅