torch_geometric.nn.pool.fps

fps(x: Tensor, batch: Optional[Tensor] = None, ratio: float = 0.5, random_start: bool = True, batch_size: Optional[int] = None) → Tensor

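A sampling algorithm from the "PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" paper, which iteratively samples the point that is farthest from all points sampled so far (i.e., the point whose distance to its nearest already-selected point is largest).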
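To make the greedy selection rule concrete, here is a minimal sketch of farthest point sampling on a single point set in plain PyTorch. It is an illustrative reimplementation, not the library's code; the function name `farthest_point_sampling` and its `k`/`start` parameters are hypothetical and chosen only for this example.

```python
import torch


def farthest_point_sampling(x: torch.Tensor, k: int, start: int = 0) -> torch.Tensor:
    """Hypothetical helper: greedy FPS on one point set x of shape [N, F].

    Returns a LongTensor with the indices of k selected rows of x.
    """
    n = x.size(0)
    k = min(k, n)
    if k <= 0:
        return torch.empty(0, dtype=torch.long)
    selected = torch.empty(k, dtype=torch.long)
    selected[0] = start
    # Squared distance from every point to its nearest already-selected point.
    min_dist = ((x - x[start]) ** 2).sum(dim=-1)
    for i in range(1, k):
        # The next sample is the point farthest from the current selection.
        nxt = int(torch.argmax(min_dist))
        selected[i] = nxt
        # Update nearest-selected distances with the newly added point.
        min_dist = torch.minimum(min_dist, ((x - x[nxt]) ** 2).sum(dim=-1))
    return selected
```

For the four-point square used in the usage example on this page, starting at node 0 selects `[0, 3]`: the second point is the diagonally opposite corner, since it is farthest from node 0. Squared distances are used because they give the same argmax as Euclidean distances while avoiding a square root.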

import torch
from torch_geometric.nn import fps

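# Four 2-D points forming a square, all belonging to the same example (batch 0)
x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
batch = torch.tensor([0, 0, 0, 0])
# Keep half of the points; index holds the indices of the selected rows of x
index = fps(x, batch, ratio=0.5)
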
Parameters:
  • x (torch.Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • batch (torch.Tensor, optional) – Batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each node to a specific example. (default: None)

  • ratio (float, optional) – Sampling ratio, i.e., the fraction of nodes in each example to sample. (default: 0.5)

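  • random_start (bool, optional) – If set to False, the first node of each example in \(\mathbf{X}\) is used as the starting node; otherwise the starting node is chosen at random. (default: True)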

  • batch_size (int, optional) – The number of examples \(B\). Automatically calculated if not given. (default: None)
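The way `batch`, `ratio`, and `random_start` interact can be sketched as a per-example loop. This is a hedged, pure-PyTorch approximation rather than the library's implementation; the name `fps_sketch` is hypothetical, and two behaviors are assumptions: each example keeps ceil(ratio * N_b) points, and the returned indices are global row indices into `x`, concatenated example by example.

```python
import math
from typing import Optional

import torch


def fps_sketch(x: torch.Tensor, batch: Optional[torch.Tensor] = None,
               ratio: float = 0.5, random_start: bool = True) -> torch.Tensor:
    """Hypothetical batched FPS; returns global indices into x."""
    # Treat all nodes as one example when no batch vector is given.
    if batch is None:
        batch = torch.zeros(x.size(0), dtype=torch.long)
    out = []
    for b in torch.unique(batch):  # unique() returns sorted example ids
        # Global indices of the nodes that belong to example b.
        node_idx = (batch == b).nonzero(as_tuple=True)[0]
        pts = x[node_idx]
        n = pts.size(0)
        # Assumption: each example keeps ceil(ratio * n) points.
        k = min(n, math.ceil(ratio * n))
        if k == 0:
            continue
        # Random start per example, or the example's first node.
        start = int(torch.randint(n, (1,))) if random_start else 0
        selected = [start]
        min_dist = ((pts - pts[start]) ** 2).sum(dim=-1)
        for _ in range(1, k):
            nxt = int(torch.argmax(min_dist))
            selected.append(nxt)
            min_dist = torch.minimum(min_dist, ((pts - pts[nxt]) ** 2).sum(dim=-1))
        # Map local positions back to global node indices.
        out.append(node_idx[torch.tensor(selected, dtype=torch.long)])
    return torch.cat(out) if out else torch.empty(0, dtype=torch.long)
```

With a square (example 0) and three collinear points `[0, 0]`, `[5, 0]`, `[10, 0]` (example 1), `ratio=0.5` and `random_start=False` would give `[0, 3, 4, 6]` under these assumptions: two points from each example. The returned indices can then gather the downsampled points and their batch assignment via `x[idx]` and `batch[idx]`.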

Return type:

torch.Tensor