torch_geometric.nn.pool.KNNIndex
- class KNNIndex(index_factory: Optional[str] = None, emb: Optional[Tensor] = None, reserve: Optional[int] = None)[source]
Bases:
object一个基类,用于通过
faiss库执行快速的\(k\)-最近邻搜索(\(k\)-NN)。请确保通过运行以下命令安装
faisspip install faiss-cpu # or pip install faiss-gpu
取决于是否计划使用GPU处理进行\(k\)-NN搜索。
- Parameters:
- add(emb: Tensor)[source]
向
KNNIndex添加新的数据点以进行搜索。- Parameters:
emb (torch.Tensor) – 要添加的数据点。
- search(emb: Tensor, k: int, exclude_links: Optional[Tensor] = None) KNNOutput[source]
搜索给定数据点的\(k\)个最近邻居。返回最近邻居的距离/相似度分数及其索引。
- Parameters:
emb (torch.Tensor) – 要添加的数据点。
k (int) – 返回的最近邻居的数量。
exclude_links (torch.Tensor) – 需要从搜索中排除的链接。 需要是一个形状为
[2, num_links]的 COO 张量,其中exclude_links[0]指的是emb中的索引,而exclude_links[1]指的是KNNIndex中的数据点。(默认值:None)
- Return type:
KNNOutput