dgl.DGLGraph.filter_nodes

DGLGraph.filter_nodes(predicate, nodes='__ALL__', ntype=None)[source]

返回满足给定谓词的具有给定节点类型的节点的ID。

Parameters:
  • predicate (callable) – 一个签名为 func(nodes) -> Tensor 的函数。 nodesdgl.NodeBatch 对象。 它的输出张量应该是一个一维布尔张量, 每个元素指示批次中的相应节点是否满足谓词。

  • 节点 (节点 ID(s), 可选) –

    查询的节点。允许的格式有:

    • 张量:一个包含查询节点的1D张量,其数据类型和设备应与图的idtype和设备相同。

    • 可迭代[int]:类似于张量,但将节点ID存储在序列中(例如列表、元组、numpy.ndarray)。

    默认情况下,它考虑所有节点。

  • ntype (str, optional) – 查询的节点类型。如果图中有多种节点类型,则必须指定该参数。否则,可以省略。

Returns:

一个包含满足谓词的节点ID的一维张量。

Return type:

张量

示例

以下示例使用PyTorch后端。

>>> import dgl
>>> import torch

定义一个谓词函数。

>>> def nodes_with_feature_one(nodes):
...     # Whether a node has feature 1
...     return (nodes.data['h'] == 1.).squeeze(1)

为同构图过滤节点。

>>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))
>>> g.ndata['h'] = torch.tensor([[0.], [1.], [1.], [0.]])
>>> print(g.filter_nodes(nodes_with_feature_one))
tensor([1, 2])

过滤ID为0和1的节点

>>> print(g.filter_nodes(nodes_with_feature_one, nodes=torch.tensor([0, 1])))
tensor([1])

为异构图过滤节点。

>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]),
...                                 torch.tensor([0, 0, 1, 1]))})
>>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.]])
>>> g.nodes['game'].data['h'] = torch.tensor([[0.], [1.]])
>>> # Filter for 'user' nodes
>>> print(g.filter_nodes(nodes_with_feature_one, ntype='user'))
tensor([1, 2])