torch_geometric.nn.pool.knn_graph

knn_graph(x: Tensor, k: int, batch: Optional[Tensor] = None, loop: bool = False, flow: str = 'source_to_target', cosine: bool = False, num_workers: int = 1, batch_size: Optional[int] = None) Tensor[源代码]

计算到最近的k个点的图边。

import torch
from torch_geometric.nn import knn_graph

x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
batch = torch.tensor([0, 0, 0, 0])
edge_index = knn_graph(x, k=2, batch=batch, loop=False)
Parameters:
  • x (torch.Tensor) – 节点特征矩阵 \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • k (int) – 邻居的数量。

  • batch (torch.Tensor, optional) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将每个 节点分配给一个特定的示例。(默认: None)

  • loop (bool, 可选) – 如果为 True,则图将包含自环。(默认值:False

  • flow (str, 可选) – 当与消息传递结合使用时,流动方向 ("source_to_target""target_to_source")。(默认: "source_to_target")

  • cosine (bool, optional) – 如果为 True,将使用余弦距离而不是欧几里得距离来查找最近邻。 (默认: False)

  • num_workers (int, optional) – 用于计算的工人数量。 如果 batch 不是 None,或者输入位于 GPU 上,则此参数无效。(默认值:1

  • batch_size (int, optional) – 示例的数量 \(B\)。 如果未给出,则自动计算。(默认值:None

Return type:

torch.Tensor