torch_geometric.nn.pool.KNNIndex

class KNNIndex(index_factory: Optional[str] = None, emb: Optional[Tensor] = None, reserve: Optional[int] = None)[source]

Bases: object

一个基类,用于通过faiss库执行快速的\(k\)-最近邻搜索(\(k\)-NN)。

请确保通过运行以下命令安装faiss

pip install faiss-cpu
# or
pip install faiss-gpu

取决于是否计划使用GPU处理进行\(k\)-NN搜索。

Parameters:
  • index_factory (str, optional) – 要使用的索引工厂的名称, 例如, "IndexFlatL2""IndexFlatIP"。更多信息请参见 这里

  • emb (torch.Tensor, optional) – The data points to add. (default: None)

  • reserve (int, optional) – 在重新分配内存之前保留的元素数量(仅限GPU)。(默认值:None

property numel: int

要搜索的数据点数量。

Return type:

int

add(emb: Tensor)[source]

KNNIndex添加新的数据点以进行搜索。

Parameters:

emb (torch.Tensor) – 要添加的数据点。

search(emb: Tensor, k: int, exclude_links: Optional[Tensor] = None) KNNOutput[source]

搜索给定数据点的\(k\)个最近邻居。返回最近邻居的距离/相似度分数及其索引。

Parameters:
  • emb (torch.Tensor) – 要添加的数据点。

  • k (int) – 返回的最近邻居的数量。

  • exclude_links (torch.Tensor) – 需要从搜索中排除的链接。 需要是一个形状为 [2, num_links] 的 COO 张量,其中 exclude_links[0] 指的是 emb 中的索引,而 exclude_links[1] 指的是 KNNIndex 中的数据点。(默认值:None

Return type:

KNNOutput

get_emb() Tensor[source]

返回存储在KNNIndex中的数据点。

Return type:

Tensor