torch_geometric.nn.conv.PointTransformerConv

class PointTransformerConv(in_channels: Union[int, Tuple[int, int]], out_channels: int, pos_nn: Optional[Callable] = None, attn_nn: Optional[Callable] = None, add_self_loops: bool = True, **kwargs)

Bases: MessagePassing

The Point Transformer layer from the "Point Transformer" paper.

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j} \left(\mathbf{W}_3 \mathbf{x}_j + \delta_{i,j} \right),\]

where the attention coefficients \(\alpha_{i,j}\) and the positional embedding \(\delta_{i,j}\) are computed as

\[\alpha_{i,j}= \textrm{softmax} \left( \gamma_\mathbf{\Theta} (\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j + \delta_{i,j}) \right)\]

\[\delta_{i,j}= h_{\mathbf{\Theta}}(\mathbf{p}_i - \mathbf{p}_j),\]

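with \(\gamma_\mathbf{\Theta}\) and \(h_\mathbf{\Theta}\) denoting neural networks, i.e. MLPs, and \(\mathbf{P} \in \mathbb{R}^{N \times D}\) defining the position of each point. The softmax normalizes over all \(j \in \mathcal{N}(i) \cup \{ i \}\) and is applied to each feature channel separately.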
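The equations above can be traced by hand with a small NumPy sketch. This is not PyG's implementation: it assumes \(h_\mathbf{\Theta}\) is a single affine map, \(\gamma_\mathbf{\Theta}\) is the identity, and the softmax runs over the neighbors of each target node independently for every channel. The function name `point_transformer_layer` and its weight arguments are hypothetical. Edges follow the source-to-target convention, so `edge_index[0]` holds \(j\) and `edge_index[1]` holds \(i\).

```python
import numpy as np


def point_transformer_layer(x, pos, edge_index, W1, W2, W3, W_pos, b_pos):
    """Hypothetical NumPy sketch of the Point Transformer update.

    Assumptions: h_Theta(p) = p @ W_pos + b_pos (one affine map),
    gamma_Theta is the identity, and the softmax is taken over the
    neighbours j of each target i, separately for every channel.
    edge_index[0] holds sources j, edge_index[1] holds targets i.
    """
    n = x.shape[0]
    # Realize N(i) ∪ {i}: drop existing self-loops, then add one per node.
    keep = edge_index[0] != edge_index[1]
    loops = np.stack([np.arange(n), np.arange(n)])
    src, dst = np.concatenate([edge_index[:, keep], loops], axis=1)

    delta = (pos[dst] - pos[src]) @ W_pos + b_pos  # delta_{i,j} = h(p_i - p_j)
    logits = x[dst] @ W1 - x[src] @ W2 + delta     # gamma = identity
    values = x[src] @ W3 + delta                    # W_3 x_j + delta_{i,j}

    out = np.zeros((n, W3.shape[1]))
    for i in range(n):
        m = dst == i
        a = np.exp(logits[m] - logits[m].max(axis=0))
        a /= a.sum(axis=0)                          # per-channel softmax over j
        out[i] = (a * values[m]).sum(axis=0)        # sum_j alpha_{i,j} * value
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))
pos = rng.normal(size=(4, 3))
edge_index = np.array([[0, 1, 2], [1, 2, 3]])  # edges 0->1, 1->2, 2->3
W1, W2, W3 = (rng.normal(size=(5, 8)) for _ in range(3))
W_pos, b_pos = rng.normal(size=(3, 8)), rng.normal(size=8)
out = point_transformer_layer(x, pos, edge_index, W1, W2, W3, W_pos, b_pos)
print(out.shape)  # (4, 8)
```

Because the weights of each channel sum to one over \(\mathcal{N}(i) \cup \{ i \}\), a node without incoming edges (node 0 here) attends only to itself and receives \(\mathbf{W}_3 \mathbf{x}_i + h_\mathbf{\Theta}(\mathbf{0})\).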

Parameters:
  • in_channels (int or tuple) – Size of each input sample, or -1 to derive the size from the first input(s) to the forward method. A tuple corresponds to the sizes of source and target dimensionalities.

  • out_channels (int) – Size of each output sample.

  • pos_nn (torch.nn.Module, optional) – A neural network \(h_\mathbf{\Theta}\) which maps relative spatial coordinates pos_i - pos_j of shape [-1, 3] to shape [-1, out_channels]. Will default to a torch.nn.Linear transformation if not further specified. (default: None)

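  • attn_nn (torch.nn.Module, optional) – A neural network \(\gamma_\mathbf{\Theta}\) which maps transformed node features of shape [-1, out_channels] to shape [-1, out_channels]. If not specified, no transformation is applied before the softmax, i.e. \(\gamma_\mathbf{\Theta}\) is the identity. (default: None)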

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, F_{in})\) or \(((|\mathcal{V}_s|, F_{s}), (|\mathcal{V}_t|, F_{t}))\) if bipartite, positions \((|\mathcal{V}|, 3)\) or \(((|\mathcal{V}_s|, 3), (|\mathcal{V}_t|, 3))\) if bipartite, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\) or \((|\mathcal{V}_t|, F_{out})\) if bipartite

forward(x: Union[Tensor, Tuple[Tensor, Tensor]], pos: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor]) → Tensor

Runs the forward pass of the module.

Return type:

Tensor

reset_parameters()

Resets all learnable parameters of the module.