KNNGraph

class dgl.nn.pytorch.factory.KNNGraph(k)[source]

Bases: Module

将一组点转换为图的层,或者将具有相同点数的多组点转换为这些图的批量联合。

KNNGraph 的实现步骤如下:

  1. 计算所有点的NxN成对距离矩阵。

  2. 为每个点选择距离最小的k个点作为它们的k个最近邻居。

  3. 构建一个图,其中每个点作为一个节点,连接到其k个最近邻居。

整体计算复杂度为 \(O(N^2(logN + D)\)

如果提供了一批点集,点集 \(i\) 中的点 \(j\) 将被映射到图节点 ID:\(i \times M + j\),其中 \(M\) 是每个点集中的节点数量。

每个节点的前驱节点是对应点的k近邻。

Parameters:

k (int) – 邻居的数量。

注释

为一个节点找到的最近邻居包括节点本身。

示例

以下示例使用PyTorch后端。

>>> import torch
>>> from dgl.nn.pytorch.factory import KNNGraph
>>>
>>> kg = KNNGraph(2)
>>> x = torch.tensor([[0,1],
                      [1,2],
                      [1,3],
                      [100, 101],
                      [101, 102],
                      [50, 50]])
>>> g = kg(x)
>>> print(g.edges())
    (tensor([0, 1, 1, 1, 2, 2, 2, 3, 3, 4, 4, 5]),
     tensor([0, 0, 1, 2, 1, 2, 5, 3, 4, 3, 4, 5]))
forward(x, algorithm='bruteforce-blas', dist='euclidean', exclude_self=False)[source]

前向计算。

Parameters:
  • x (Tensor) – \((M, D)\)\((N, M, D)\),其中 \(N\) 表示 点集的数量,\(M\) 表示每个点集中的点的数量,\(D\) 表示特征的大小。

  • algorithm (str, optional) –

    用于计算k近邻的算法。

    • ’bruteforce-blas’ 将首先使用后端框架提供的BLAS矩阵乘法操作计算距离矩阵。然后使用topk算法获取k近邻。当点集较小时,此方法速度较快,但具有\(O(N^2)\)的内存复杂度,其中\(N\)是点的数量。

    • ’bruteforce’ 将逐对计算距离,并在距离计算过程中直接选择k近邻。此方法比’bruteforce-blas’慢,但内存开销较小(即\(O(Nk)\),其中\(N\)是点的数量,\(k\)是每个节点的近邻数量),因为我们不需要存储所有距离。

    • ’bruteforce-sharemem’(仅限CUDA)与’bruteforce’类似,但在CUDA设备中使用共享内存作为缓冲区。当输入点的维度不大时,此方法比’bruteforce’更快。此方法仅在CUDA设备上可用。

    • ’kd-tree’ 将使用kd-tree算法(仅限CPU)。此方法适用于低维数据(例如3D点云)。

    • ’nn-descent’ 是一种近似方法,来自论文Efficient k-nearest neighbor graph construction for generic similarity measures。此方法将在“邻居的邻居”中搜索近邻候选。

    (默认:’bruteforce-blas’)

  • dist (str, optional) –

    用于计算点之间距离的距离度量。它可以是以下度量: * ‘euclidean’: 使用欧几里得距离(L2范数)

    \(\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}\).

    • ’cosine’: 使用余弦距离。

    (默认: ‘euclidean’)

  • exclude_self (bool, optional) – 如果为True,输出图将不包含自循环边,并且每个节点不会被视为其自身的k个邻居之一。如果为False,输出图将包含自循环边,并且节点将被视为其自身的k个邻居之一。

Returns:

一个没有特征的DGLGraph。

Return type:

DGLGraph