dgl.DGLGraph.find_edges
- DGLGraph.find_edges(eid, etype=None)[source]
根据边的ID返回源节点和目标节点的ID。
- Parameters:
eid (边 ID(s)) –
边的 ID。允许的格式有:
int
: 单个 ID。Int Tensor: 每个元素都是一个 ID。张量必须与图的设备类型和 ID 数据类型相同。
iterable[int]: 每个元素都是一个 ID。
etype (str or (str, str, str), optional) –
The type names of the edges. The allowed type name formats are:
(str, str, str)
for source node type, edge type and destination node type.or one
str
edge type name if the name can uniquely identify a triplet format in the graph.
Can be omitted if the graph has only one type of edges.
- Returns:
Tensor – 边的源节点ID。第i个元素是第i条边的源节点ID。
Tensor – 边的目标节点ID。第i个元素是第i条边的目标节点ID。
示例
以下示例使用PyTorch后端。
>>> import dgl >>> import torch
创建一个同构图。
>>> g = dgl.graph((torch.tensor([0, 0, 1, 1]), torch.tensor([1, 0, 2, 3])))
查找ID为0和2的边。
>>> g.find_edges(torch.tensor([0, 2])) (tensor([0, 1]), tensor([1, 2]))
对于具有多种边类型的图,需要在查询中指定边类型。
>>> hg = dgl.heterograph({ ... ('user', 'follows', 'user'): (torch.tensor([0, 1]), torch.tensor([1, 2])), ... ('user', 'plays', 'game'): (torch.tensor([3, 4]), torch.tensor([5, 6])) ... }) >>> hg.find_edges(torch.tensor([1, 0]), 'plays') (tensor([4, 3]), tensor([6, 5]))