torch_geometric.nn.conv.PointGNNConv
- class PointGNNConv(mlp_h: Module, mlp_f: Module, mlp_g: Module, **kwargs)[source]
Bases:
MessagePassing来自“Point-GNN: Graph Neural Network for 3D Object Detection in a Point Cloud”论文的PointGNN操作符。
\[ \begin{align}\begin{aligned}\Delta \textrm{pos}_i &= h_{\mathbf{\Theta}}(\mathbf{x}_i)\\\mathbf{e}_{j,i} &= f_{\mathbf{\Theta}}(\textrm{pos}_j - \textrm{pos}_i + \Delta \textrm{pos}_i, \mathbf{x}_j)\\\mathbf{x}^{\prime}_i &= g_{\mathbf{\Theta}}(\max_{j \in \mathcal{N}(i)} \mathbf{e}_{j,i}) + \mathbf{x}_i\end{aligned}\end{align} \]相对位置用于消息传递步骤中,以引入全局平移不变性。为了应对中心节点局部邻域的偏移,作者提出利用对齐偏移。图应使用基于半径的截止静态构建。
- Parameters:
mlp_h (torch.nn.Module) – 一个神经网络 \(h_{\mathbf{\Theta}}\) 将大小为 \(F_{in}\) 的节点特征映射到三维坐标偏移 \(\Delta \textrm{pos}_i\)。
mlp_f (torch.nn.Module) – 一个神经网络 \(f_{\mathbf{\Theta}}\) 从大小为 \(F_{in}\) 的邻居特征和三维向量 \(\textrm{pos_j} - \textrm{pos_i} + \Delta \textrm{pos}_i\) 计算 \(\mathbf{e}_{j,i}\)。
mlp_g (torch.nn.Module) – 一个神经网络 \(g_{\mathbf{\Theta}}\),它将聚合的边缘特征映射回 \(F_{in}\)。
**kwargs (optional) – Additional arguments of
torch_geometric.nn.conv.MessagePassing.
- Shapes:
输入: 节点特征 \((|\mathcal{V}|, F_{in})\), 位置 \((|\mathcal{V}|, 3)\), 边索引 \((2, |\mathcal{E}|)\),
输出: 节点特征 \((|\mathcal{V}|, F_{in})\)