dgl.DGLGraph.filter_nodes
- DGLGraph.filter_nodes(predicate, nodes='__ALL__', ntype=None)[source]
返回满足给定谓词的具有给定节点类型的节点的ID。
- Parameters:
predicate (callable) – 一个签名为
func(nodes) -> Tensor
的函数。nodes
是dgl.NodeBatch
对象。 它的输出张量应该是一个一维布尔张量, 每个元素指示批次中的相应节点是否满足谓词。节点 (节点 ID(s), 可选) –
查询的节点。允许的格式有:
张量:一个包含查询节点的1D张量,其数据类型和设备应与图的
idtype
和设备相同。可迭代[int]:类似于张量,但将节点ID存储在序列中(例如列表、元组、numpy.ndarray)。
默认情况下,它考虑所有节点。
ntype (str, optional) – 查询的节点类型。如果图中有多种节点类型,则必须指定该参数。否则,可以省略。
- Returns:
一个包含满足谓词的节点ID的一维张量。
- Return type:
张量
示例
以下示例使用PyTorch后端。
>>> import dgl >>> import torch
定义一个谓词函数。
>>> def nodes_with_feature_one(nodes): ... # Whether a node has feature 1 ... return (nodes.data['h'] == 1.).squeeze(1)
为同构图过滤节点。
>>> g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3]))) >>> g.ndata['h'] = torch.tensor([[0.], [1.], [1.], [0.]]) >>> print(g.filter_nodes(nodes_with_feature_one)) tensor([1, 2])
过滤ID为0和1的节点
>>> print(g.filter_nodes(nodes_with_feature_one, nodes=torch.tensor([0, 1]))) tensor([1])
为异构图过滤节点。
>>> g = dgl.heterograph({ ... ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]), ... torch.tensor([0, 0, 1, 1]))}) >>> g.nodes['user'].data['h'] = torch.tensor([[0.], [1.], [1.]]) >>> g.nodes['game'].data['h'] = torch.tensor([[0.], [1.]]) >>> # Filter for 'user' nodes >>> print(g.filter_nodes(nodes_with_feature_one, ntype='user')) tensor([1, 2])