torch_geometric.nn.conv.PointGNNConv

class PointGNNConv(mlp_h: Module, mlp_f: Module, mlp_g: Module, **kwargs)[source]

Bases: MessagePassing

来自“Point-GNN: Graph Neural Network for 3D Object Detection in a Point Cloud”论文的PointGNN操作符。

\[ \begin{align}\begin{aligned}\Delta \textrm{pos}_i &= h_{\mathbf{\Theta}}(\mathbf{x}_i)\\\mathbf{e}_{j,i} &= f_{\mathbf{\Theta}}(\textrm{pos}_j - \textrm{pos}_i + \Delta \textrm{pos}_i, \mathbf{x}_j)\\\mathbf{x}^{\prime}_i &= g_{\mathbf{\Theta}}(\max_{j \in \mathcal{N}(i)} \mathbf{e}_{j,i}) + \mathbf{x}_i\end{aligned}\end{align} \]

相对位置用于消息传递步骤中,以引入全局平移不变性。为了应对中心节点局部邻域的偏移,作者提出利用对齐偏移。图应使用基于半径的截止静态构建。

Parameters:
  • mlp_h (torch.nn.Module) – 一个神经网络 \(h_{\mathbf{\Theta}}\) 将大小为 \(F_{in}\) 的节点特征映射到三维坐标偏移 \(\Delta \textrm{pos}_i\)

  • mlp_f (torch.nn.Module) – 一个神经网络 \(f_{\mathbf{\Theta}}\) 从大小为 \(F_{in}\) 的邻居特征和三维向量 \(\textrm{pos_j} - \textrm{pos_i} + \Delta \textrm{pos}_i\) 计算 \(\mathbf{e}_{j,i}\)

  • mlp_g (torch.nn.Module) – 一个神经网络 \(g_{\mathbf{\Theta}}\),它将聚合的边缘特征映射回 \(F_{in}\)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • 输入: 节点特征 \((|\mathcal{V}|, F_{in})\), 位置 \((|\mathcal{V}|, 3)\), 边索引 \((2, |\mathcal{E}|)\),

  • 输出: 节点特征 \((|\mathcal{V}|, F_{in})\)

forward(x: Tensor, pos: Tensor, edge_index: Union[Tensor, SparseTensor]) Tensor[source]

运行模块的前向传播。

Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。