torch_geometric.nn.unpool.knn_interpolate
- knn_interpolate(x: Tensor, pos_x: Tensor, pos_y: Tensor, batch_x: Optional[Tensor] = None, batch_y: Optional[Tensor] = None, k: int = 3, num_workers: int = 1)[source]
来自“PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space”论文的k-NN插值。
对于每个点 \(y\) 的位置 \(\mathbf{p}(y)\),其插值特征 \(\mathbf{f}(y)\) 由以下公式给出:
\[\mathbf{f}(y) = \frac{\sum_{i=1}^k w(x_i) \mathbf{f}(x_i)}{\sum_{i=1}^k w(x_i)} \textrm{, where } w(x_i) = \frac{1}{d(\mathbf{p}(y), \mathbf{p}(x_i))^2}\]并且 \(\{ x_1, \ldots, x_k \}\) 表示 \(k\) 个最近的点 到 \(y\)。
- Parameters:
x (torch.Tensor) – Node feature matrix \(\mathbf{X} \in \mathbb{R}^{N \times F}\).
pos_x (torch.Tensor) – 节点位置矩阵 \(\in \mathbb{R}^{N \times d}\).
pos_y (torch.Tensor) – 上采样的节点位置矩阵 \(\in \mathbb{R}^{M \times d}\).
batch_x (torch.Tensor, optional) – 批次向量 \(\mathbf{b_x} \in {\{ 0, \ldots, B-1\}}^N\), 它将 每个节点从 \(\mathbf{X}\) 分配到特定的示例。 (默认:
None)batch_y (torch.Tensor, optional) – 批次向量 \(\mathbf{b_y} \in {\{ 0, \ldots, B-1\}}^N\), 它将 \(\mathbf{Y}\) 中的每个节点分配给特定的示例。 (默认:
None)k (int, 可选) – 邻居的数量。(默认:
3)num_workers (int, optional) – 用于计算的工人数量。 在
batch_x或batch_y不为None,或输入位于GPU上的情况下无效。(默认值:1)