dgl.batch

dgl.batch(graphs, ndata='__ALL__', edata='__ALL__')[source]

将一组DGLGraph批量处理为一个图,以提高图计算的效率。

每个输入图成为批处理图的一个不相交组件。节点和边被重新标记为不相交的段:

原始节点ID

0 ~ N_0

0 ~ N_1

0 ~ N_k

新节点ID

0 ~ N_0

N_0 ~ N_0+N_1

sum_{i=0}^{k-1} N_i ~ sum_{i=0}^k N_i

因此,在批处理图上的许多计算与在单个图上执行的计算相同,但由于可以轻松并行化,因此效率更高。这使得dgl.batch在处理许多图样本的任务(如图分类任务)时非常有用。

对于异构图输入,它们必须共享相同的关系集(即节点类型和边类型),并且该函数将逐个对每个关系进行批处理。因此,结果也是一个异构图,并且具有与输入相同的关系集。

输入图的节点和边的数量可以通过结果图的DGLGraph.batch_num_nodes()DGLGraph.batch_num_edges()属性访问。对于同构图,它们是一维整数张量,每个元素是对应输入图的节点/边数量。对于异构图,它们是一维整数张量的字典,键为节点类型或边类型。

该函数支持批处理批处理图。结果图的批处理大小是所有输入图的批处理大小的总和。

默认情况下,节点/边特征通过将所有输入图的特征张量连接起来进行批处理。因此,这要求相同名称的特征具有相同的数据类型和特征大小。可以将None传递给ndataedata参数以防止特征批处理,或者传递一个字符串列表来指定要批处理的特征。

要将图解批处理回列表,请使用 dgl.unbatch() 函数。

Parameters:
  • graphs (list[DGLGraph]) – Input graphs.

  • ndata (list[str], None, optional) – 要批处理的节点特征。

  • edata (list[str], None, optional) – 要批处理的边特征。

Returns:

批量图。

Return type:

DGLGraph

示例

批量同构图

>>> import dgl
>>> import torch as th
>>> # 4 nodes, 3 edges
>>> g1 = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
>>> # 3 nodes, 4 edges
>>> g2 = dgl.graph((th.tensor([0, 0, 0, 1]), th.tensor([0, 1, 2, 0])))
>>> bg = dgl.batch([g1, g2])
>>> bg
Graph(num_nodes=7, num_edges=7,
      ndata_schemes={}
      edata_schemes={})
>>> bg.batch_size
2
>>> bg.batch_num_nodes()
tensor([4, 3])
>>> bg.batch_num_edges()
tensor([3, 4])
>>> bg.edges()
(tensor([0, 1, 2, 4, 4, 4, 5], tensor([1, 2, 3, 4, 5, 6, 4]))

批量批处理图形

>>> bbg = dgl.batch([bg, bg])
>>> bbg.batch_size
4
>>> bbg.batch_num_nodes()
tensor([4, 3, 4, 3])
>>> bbg.batch_num_edges()
tensor([3, 4, 3, 4])

批量处理带有特征数据的图表

>>> g1.ndata['x'] = th.zeros(g1.num_nodes(), 3)
>>> g1.edata['w'] = th.ones(g1.num_edges(), 2)
>>> g2.ndata['x'] = th.ones(g2.num_nodes(), 3)
>>> g2.edata['w'] = th.zeros(g2.num_edges(), 2)
>>> bg = dgl.batch([g1, g2])
>>> bg.ndata['x']
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]])
>>> bg.edata['w']
tensor([[1, 1],
        [1, 1],
        [1, 1],
        [0, 0],
        [0, 0],
        [0, 0],
        [0, 0]])

批量异构图

>>> hg1 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 1]), th.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
...     ('user', 'plays', 'game') : (th.tensor([0, 0, 0]), th.tensor([1, 0, 2]))})
>>> bhg = dgl.batch([hg1, hg2])
>>> bhg
Graph(num_nodes={'user': 3, 'game': 4},
      num_edges={('user', 'plays', 'game'): 5},
      metagraph=[('drug', 'game')])
>>> bhg.batch_size
2
>>> bhg.batch_num_nodes()
{'user' : tensor([2, 1]), 'game' : tensor([1, 3])}
>>> bhg.batch_num_edges()
{('user', 'plays', 'game') : tensor([2, 3])}

另请参阅

unbatch