torch_geometric.nn.unpool.knn_interpolate

knn_interpolate(x: Tensor, pos_x: Tensor, pos_y: Tensor, batch_x: Optional[Tensor] = None, batch_y: Optional[Tensor] = None, k: int = 3, num_workers: int = 1)[source]

来自“PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space”论文的k-NN插值。

对于每个点 \(y\) 的位置 \(\mathbf{p}(y)\),其插值特征 \(\mathbf{f}(y)\) 由以下公式给出:

\[\mathbf{f}(y) = \frac{\sum_{i=1}^k w(x_i) \mathbf{f}(x_i)}{\sum_{i=1}^k w(x_i)} \textrm{, where } w(x_i) = \frac{1}{d(\mathbf{p}(y), \mathbf{p}(x_i))^2}\]

并且 \(\{ x_1, \ldots, x_k \}\) 表示 \(k\) 个最近的点 到 \(y\)

Parameters:
  • x (torch.Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).

  • pos_x (torch.Tensor) – 节点位置矩阵 \(\in \mathbb{R}^{N \times d}\).

  • pos_y (torch.Tensor) – 上采样的节点位置矩阵 \(\in \mathbb{R}^{M \times d}\).

  • batch_x (torch.Tensor, optional) – 批次向量 \(\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个节点从 \(\mathbf{X}\) 分配到特定的示例。 (默认: None)

  • batch_y (torch.Tensor, optional) – 批次向量 \(\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^N\), 它将 \(\mathbf{Y}\) 中的每个节点分配给特定的示例。 (默认: None)

  • k (int, 可选) – 邻居的数量。(默认: 3)

  • num_workers (int, optional) – 用于计算的工人数量。 在batch_xbatch_y不为 None,或输入位于GPU上的情况下无效。(默认值:1