torch_geometric.nn.pool.knn_graph
- knn_graph(x: Tensor, k: int, batch: Optional[Tensor] = None, loop: bool = False, flow: str = 'source_to_target', cosine: bool = False, num_workers: int = 1, batch_size: Optional[int] = None) Tensor[源代码]
计算到最近的
k个点的图边。import torch from torch_geometric.nn import knn_graph x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) batch = torch.tensor([0, 0, 0, 0]) edge_index = knn_graph(x, k=2, batch=batch, loop=False)
- Parameters:
x (torch.Tensor) – 节点特征矩阵 \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
k (int) – 邻居的数量。
batch (torch.Tensor, optional) – 批次向量 \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), 它将每个 节点分配给一个特定的示例。(默认:
None)flow (str, 可选) – 当与消息传递结合使用时,流动方向 (
"source_to_target"或"target_to_source")。(默认:"source_to_target")cosine (bool, optional) – 如果为
True,将使用余弦距离而不是欧几里得距离来查找最近邻。 (默认:False)num_workers (int, optional) – 用于计算的工人数量。 如果
batch不是None,或者输入位于 GPU 上,则此参数无效。(默认值:1)batch_size (int, optional) – 示例的数量 \(B\)。 如果未给出,则自动计算。(默认值:
None)
- Return type: