dgl.batch
- dgl.batch(graphs, ndata='__ALL__', edata='__ALL__')[source]
将一组
DGLGraph
批量处理为一个图,以提高图计算的效率。每个输入图成为批处理图的一个不相交组件。节点和边被重新标记为不相交的段:
原始节点ID
0 ~ N_0
0 ~ N_1
…
0 ~ N_k
新节点ID
0 ~ N_0
N_0 ~ N_0+N_1
…
sum_{i=0}^{k-1} N_i ~ sum_{i=0}^k N_i
因此,在批处理图上的许多计算与在单个图上执行的计算相同,但由于可以轻松并行化,因此效率更高。这使得
dgl.batch
在处理许多图样本的任务(如图分类任务)时非常有用。对于异构图输入,它们必须共享相同的关系集(即节点类型和边类型),并且该函数将逐个对每个关系进行批处理。因此,结果也是一个异构图,并且具有与输入相同的关系集。
输入图的节点和边的数量可以通过结果图的
DGLGraph.batch_num_nodes()
和DGLGraph.batch_num_edges()
属性访问。对于同构图,它们是一维整数张量,每个元素是对应输入图的节点/边数量。对于异构图,它们是一维整数张量的字典,键为节点类型或边类型。该函数支持批处理批处理图。结果图的批处理大小是所有输入图的批处理大小的总和。
默认情况下,节点/边特征通过将所有输入图的特征张量连接起来进行批处理。因此,这要求相同名称的特征具有相同的数据类型和特征大小。可以将
None
传递给ndata
或edata
参数以防止特征批处理,或者传递一个字符串列表来指定要批处理的特征。要将图解批处理回列表,请使用
dgl.unbatch()
函数。- Parameters:
- Returns:
批量图。
- Return type:
示例
批量同构图
>>> import dgl >>> import torch as th >>> # 4 nodes, 3 edges >>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3]))) >>> # 3 nodes, 4 edges >>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0]))) >>> bg = dgl.batch([g1, g2]) >>> bg Graph(num_nodes=7, num_edges=7, ndata_schemes={} edata_schemes={}) >>> bg.batch_size 2 >>> bg.batch_num_nodes() tensor([4, 3]) >>> bg.batch_num_edges() tensor([3, 4]) >>> bg.edges() (tensor([0, 1, 2, 4, 4, 4, 5], tensor([1, 2, 3, 4, 5, 6, 4]))
批量批处理图形
>>> bbg = dgl.batch([bg, bg]) >>> bbg.batch_size 4 >>> bbg.batch_num_nodes() tensor([4, 3, 4, 3]) >>> bbg.batch_num_edges() tensor([3, 4, 3, 4])
批量处理带有特征数据的图表
>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3) >>> g1.edata['w'] = th.ones(g1.num_edges(), 2) >>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3) >>> g2.edata['w'] = th.zeros(g2.num_edges(), 2) >>> bg = dgl.batch([g1, g2]) >>> bg.ndata['x'] tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 1, 1], [1, 1, 1], [1, 1, 1]]) >>> bg.edata['w'] tensor([[1, 1], [1, 1], [1, 1], [0, 0], [0, 0], [0, 0], [0, 0]])
批量异构图
>>> hg1 = dgl.heterograph({ ... ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))}) >>> hg2 = dgl.heterograph({ ... ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))}) >>> bhg = dgl.batch([hg1, hg2]) >>> bhg Graph(num_nodes={'user': 3, 'game': 4}, num_edges={('user', 'plays', 'game'): 5}, metagraph=[('drug', 'game')]) >>> bhg.batch_size 2 >>> bhg.batch_num_nodes() {'user' : tensor([2, 1]), 'game' : tensor([1, 3])} >>> bhg.batch_num_edges() {('user', 'plays', 'game') : tensor([2, 3])}
另请参阅