dgl.DGLGraph.batch_num_nodes

DGLGraph.batch_num_nodes(ntype=None)[source]

返回批次中每个具有指定节点类型的图的节点数量。

Parameters:

ntype (str, optional) – 用于查询的节点类型。如果图有多个节点类型,则必须指定该参数。否则,可以省略。如果图不是批处理的,它将返回一个长度为1的列表,其中包含图中的节点数量。

Returns:

批次中每个图具有指定类型的节点数。它的第i个元素是第i个图中具有指定类型的节点数。

Return type:

张量

示例

以下示例使用PyTorch后端。

>>> import dgl
>>> import torch

查询同构图。

>>> g1 = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
>>> g1.batch_num_nodes()
tensor([4])
>>> g2 = dgl.graph((torch.tensor([0, 0, 0, 1]), torch.tensor([0, 1, 2, 0])))
>>> bg = dgl.batch([g1, g2])
>>> bg.batch_num_nodes()
tensor([4, 3])

查询异构图。

>>> hg1 = dgl.heterograph({
...       ('user', 'plays', 'game') : (torch.tensor([0, 1]), torch.tensor([0, 0]))})
>>> hg2 = dgl.heterograph({
...       ('user', 'plays', 'game') : (torch.tensor([0, 0]), torch.tensor([1, 0]))})
>>> bg = dgl.batch([hg1, hg2])
>>> bg.batch_num_nodes('user')
tensor([2, 1])