KNNGraph
- class dgl.nn.pytorch.factory.KNNGraph(k)[source]
Bases:
Module
将一组点转换为图的层,或者将具有相同点数的多组点转换为这些图的批量联合。
KNNGraph 的实现步骤如下:
计算所有点的NxN成对距离矩阵。
为每个点选择距离最小的k个点作为它们的k个最近邻居。
构建一个图,其中每个点作为一个节点,连接到其k个最近邻居。
整体计算复杂度为 \(O(N^2(logN + D)\)。
如果提供了一批点集,点集 \(i\) 中的点 \(j\) 将被映射到图节点 ID:\(i \times M + j\),其中 \(M\) 是每个点集中的节点数量。
每个节点的前驱节点是对应点的k近邻。
- Parameters:
k (int) – 邻居的数量。
注释
为一个节点找到的最近邻居包括节点本身。
示例
以下示例使用PyTorch后端。
>>> import torch >>> from dgl.nn.pytorch.factory import KNNGraph >>> >>> kg = KNNGraph(2) >>> x = torch.tensor([[0,1], [1,2], [1,3], [100, 101], [101, 102], [50, 50]]) >>> g = kg(x) >>> print(g.edges()) (tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]), tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))
- forward(x, algorithm='bruteforce-blas', dist='euclidean', exclude_self=False)[source]
前向计算。
- Parameters:
x (Tensor) – \((M, D)\) 或 \((N, M, D)\),其中 \(N\) 表示 点集的数量,\(M\) 表示每个点集中的点的数量,\(D\) 表示特征的大小。
algorithm (str, optional) –
用于计算k近邻的算法。
’bruteforce-blas’ 将首先使用后端框架提供的BLAS矩阵乘法操作计算距离矩阵。然后使用topk算法获取k近邻。当点集较小时,此方法速度较快,但具有\(O(N^2)\)的内存复杂度,其中\(N\)是点的数量。
’bruteforce’ 将逐对计算距离,并在距离计算过程中直接选择k近邻。此方法比’bruteforce-blas’慢,但内存开销较小(即\(O(Nk)\),其中\(N\)是点的数量,\(k\)是每个节点的近邻数量),因为我们不需要存储所有距离。
’bruteforce-sharemem’(仅限CUDA)与’bruteforce’类似,但在CUDA设备中使用共享内存作为缓冲区。当输入点的维度不大时,此方法比’bruteforce’更快。此方法仅在CUDA设备上可用。
’kd-tree’ 将使用kd-tree算法(仅限CPU)。此方法适用于低维数据(例如3D点云)。
’nn-descent’ 是一种近似方法,来自论文Efficient k-nearest neighbor graph construction for generic similarity measures。此方法将在“邻居的邻居”中搜索近邻候选。
(默认:’bruteforce-blas’)
dist (str, optional) –
用于计算点之间距离的距离度量。它可以是以下度量: * ‘euclidean’: 使用欧几里得距离(L2范数)
\(\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}\).
’cosine’: 使用余弦距离。
(默认: ‘euclidean’)
exclude_self (bool, optional) – 如果为True,输出图将不包含自循环边,并且每个节点不会被视为其自身的k个邻居之一。如果为False,输出图将包含自循环边,并且节点将被视为其自身的k个邻居之一。
- Returns:
一个没有特征的DGLGraph。
- Return type: