SegmentedKNNGraph
- class dgl.nn.pytorch.factory.SegmentedKNNGraph(k)
Bases: Module
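Layer that transforms one point set into a graph, or a batch of point sets with different numbers of points into a batched union of those graphs.
If a batch of point sets is provided, then point \(j\) in point set \(i\) is mapped to graph node ID \(\sum_{p<i} |V_p| + j\), where \(|V_p|\) is the number of points in point set \(p\).
The predecessors of each node are the k-nearest neighbors of the corresponding point.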
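The node-ID mapping above is an exclusive prefix sum over the point-set sizes. A minimal plain-Python sketch, assuming the sizes are given as a list like the `segs` argument of `forward()`; `segment_offsets` and `global_node_id` are hypothetical helpers for illustration, not part of DGL:

```python
from itertools import accumulate


def segment_offsets(segs):
    """First global node ID of each point set, i.e. sum_{p<i} |V_p|."""
    # Exclusive prefix sum: offsets[i] = segs[0] + ... + segs[i-1]
    return [0] + list(accumulate(segs))[:-1]


def global_node_id(segs, i, j):
    """Map point j of point set i to its node ID in the batched graph."""
    if not 0 <= j < segs[i]:
        raise IndexError(f"point {j} is out of range for point set {i}")
    return segment_offsets(segs)[i] + j


segs = [3, 3, 2]
print(segment_offsets(segs))       # [0, 3, 6]
print(global_node_id(segs, 1, 2))  # 5
```

With `segs = [3, 3, 2]`, the third point of the second set becomes node 5, which is the point `[50, 50]` in the example below.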
- Parameters:
k (int) – The number of neighbors.
Notes
By default (exclude_self=False in forward()), the nearest neighbors found for a node include the node itself.
Examples
The following example uses the PyTorch backend.
>>> import torch
>>> from dgl.nn.pytorch.factory import SegmentedKNNGraph
>>>
>>> kg = SegmentedKNNGraph(2)
>>> x = torch.tensor([[0,1],
...                   [1,2],
...                   [1,3],
...                   [100, 101],
...                   [101, 102],
...                   [50, 50],
...                   [24,25],
...                   [25,24]])
>>> g = kg(x, [3,3,2])
>>> print(g.edges())
(tensor([0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 6, 6, 7, 7]),
 tensor([0, 0, 1, 2, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 6, 7]))
>>>
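The edge list printed above can be reproduced without DGL by a brute-force reference that follows the documented semantics: each node's predecessors are its k nearest points within the same point set, the node itself included. This is a plain-Python sketch, not DGL's implementation; `segmented_knn_edges` is a hypothetical name, and sorting edges by `(src, dst)` is an assumption made only so the output lines up with the printed tensors (this page does not promise an edge order).

```python
import math


def segmented_knn_edges(points, segs, k):
    """Brute-force reference for a segmented k-NN graph (Euclidean, self included).

    Returns (src, dst) lists in which every src is one of the k nearest
    neighbours of dst inside the same point set, sorted by (src, dst).
    """
    if sum(segs) != len(points):
        raise ValueError("sum(segs) must equal the number of points")
    edges = []
    offset = 0
    for size in segs:
        if size < k:
            raise ValueError(f"point set of size {size} has fewer than k={k} points")
        ids = range(offset, offset + size)
        for dst in ids:
            # Rank candidates by distance; on ties prefer the node itself, then smaller IDs.
            ranked = sorted(ids, key=lambda s: (math.dist(points[s], points[dst]), s != dst, s))
            # Edges point from each neighbour (predecessor) to the query node.
            edges.extend((src, dst) for src in ranked[:k])
        offset += size
    edges.sort()
    return [s for s, _ in edges], [d for _, d in edges]


x = [[0, 1], [1, 2], [1, 3], [100, 101], [101, 102], [50, 50], [24, 25], [25, 24]]
src, dst = segmented_knn_edges(x, [3, 3, 2], k=2)
print(src)  # [0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 6, 6, 7, 7]
print(dst)  # [0, 0, 1, 2, 1, 2, 3, 4, 5, 3, 4, 5, 6, 7, 6, 7]
```

Note that node 5 (`[50, 50]`) has no close partner in its set, yet it still receives an edge from node 3, because every node gets exactly k predecessors regardless of how far away they are.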
- forward(x, segs, algorithm='bruteforce-blas', dist='euclidean', exclude_self=False)
Forward computation.
- Parameters:
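x (Tensor) – \((M, D)\), where \(M\) is the total number of points across all point sets and \(D\) is the feature size.
segs (iterable of int) – \((N)\) integers, where \(N\) is the number of point sets. The elements must sum to \(M\), and each element (the number of points in a point set) should be \(\ge k\).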
algorithm (str, optional) –
Algorithm used to compute the k-nearest neighbors.
* 'bruteforce-blas' first computes the distance matrix using the BLAS matrix multiplication operation provided by the backend framework, then uses a top-k algorithm to get the k-nearest neighbors. This method is fast when the point set is small but has \(O(N^2)\) memory complexity, where \(N\) is the number of points.
* 'bruteforce' computes distances pair by pair and directly selects the k-nearest neighbors during distance computation. This method is slower than 'bruteforce-blas' but has less memory overhead (i.e., \(O(Nk)\), where \(N\) is the number of points and \(k\) is the number of nearest neighbors per node), since it does not need to store all distances.
* 'bruteforce-sharemem' (CUDA only) is similar to 'bruteforce' but uses shared memory on CUDA devices as a buffer. This method is faster than 'bruteforce' when the dimension of the input points is not large.
* 'kd-tree' uses the kd-tree algorithm (CPU only). This method is suitable for low-dimensional data (e.g. 3D point clouds).
* 'nn-descent' is an approximate approach from the paper Efficient k-nearest neighbor graph construction for generic similarity measures. This method searches for nearest-neighbor candidates among "neighbors' neighbors".
(default: 'bruteforce-blas')
dist (str, optional) –
The distance metric used to compute the distance between points. It can be one of the following metrics:
* 'euclidean': Use Euclidean distance (L2 norm) \(\sqrt{\sum_{i} (x_{i} - y_{i})^{2}}\).
* 'cosine': Use cosine distance.
(default: 'euclidean')
exclude_self (bool, optional) – If True, the output graph does not contain self-loop edges, and each node is not counted as one of its own k neighbors. If False, the output graph contains self-loop edges, and each node is counted as one of its own k neighbors.
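The 'bruteforce' option is described as selecting the k nearest neighbors while distances are being computed, instead of materializing the full \(N \times N\) matrix that 'bruteforce-blas' builds. A plain-Python sketch of that idea, keeping a bounded max-heap of k candidates per query point and honoring `dist` and `exclude_self`. Assumptions: `knn_streaming` and `_distance` are hypothetical helpers, not DGL functions; cosine distance is taken to be 1 minus cosine similarity; ties are broken toward the smaller node ID; and requiring \(k + 1\) points per set when `exclude_self=True` is this sketch's own rule.

```python
import heapq
import math


def _distance(a, b, dist):
    """Distance between two points under the given metric name."""
    if dist == "euclidean":
        return math.dist(a, b)  # L2 norm of a - b
    if dist == "cosine":
        # Assumed definition: 1 - cosine similarity (points must be non-zero).
        dot = sum(x * y for x, y in zip(a, b))
        return 1.0 - dot / (math.hypot(*a) * math.hypot(*b))
    raise ValueError(f"unsupported dist: {dist!r}")


def knn_streaming(points, segs, k, dist="euclidean", exclude_self=False):
    """For every global node ID, return its k nearest neighbours in its own point set.

    Only a k-entry heap is kept per query point, so no N x N distance matrix
    is ever stored; the result itself is O(Nk).
    """
    if sum(segs) != len(points):
        raise ValueError("sum(segs) must equal the number of points")
    neighbours = []
    offset = 0
    for size in segs:
        # Sketch's own rule: excluding self leaves only size - 1 candidates.
        needed = k + 1 if exclude_self else k
        if size < needed:
            raise ValueError(f"point set of size {size} has too few points for k={k}")
        for dst in range(offset, offset + size):
            heap = []  # entries (-distance, -src): heap[0] is the worst kept candidate
            for src in range(offset, offset + size):
                if exclude_self and src == dst:
                    continue
                item = (-_distance(points[src], points[dst], dist), -src)
                if len(heap) < k:
                    heapq.heappush(heap, item)
                elif item > heap[0]:  # closer, or equally close with a smaller ID
                    heapq.heapreplace(heap, item)
            neighbours.append(sorted(-neg_src for _, neg_src in heap))
        offset += size
    return neighbours


x = [[0, 1], [1, 2], [1, 3], [100, 101], [101, 102], [50, 50], [24, 25], [25, 24]]
print(knn_streaming(x, [3, 3, 2], k=2))
# [[0, 1], [1, 2], [1, 2], [3, 4], [3, 4], [3, 5], [6, 7], [6, 7]]
print(knn_streaming(x, [3, 3, 2], k=1, exclude_self=True))
# [[1], [2], [1], [4], [3], [3], [7], [6]]
```

The trade-off mirrors the option list: each query scans all candidates in its set (so time is still quadratic per set), but memory per query stays at k entries.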
- Returns:
A batched DGLGraph without features.
- Return type: