torch_geometric.nn.conv.DNAConv

class DNAConv(channels: int, heads: int = 1, groups: int = 1, dropout: float = 0.0, cached: bool = False, normalize: bool = True, add_self_loops: bool = True, bias: bool = True, **kwargs)[source]

Bases: MessagePassing

The dynamic neighborhood aggregation operator from the "Just Jump: Towards Dynamic Neighborhood Aggregation in Graph Neural Networks" paper.

\[\mathbf{x}_v^{(t)} = h_{\mathbf{\Theta}}^{(t)} \left( \mathbf{x}_{v \leftarrow v}^{(t)}, \left\{ \mathbf{x}_{v \leftarrow w}^{(t)} : w \in \mathcal{N}(v) \right\} \right)\]

based on (multi-head) dot-product attention

\[\mathbf{x}_{v \leftarrow w}^{(t)} = \textrm{Attention} \left( \mathbf{x}^{(t-1)}_v \, \mathbf{\Theta}_Q^{(t)}, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_K^{(t)}, \, [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_V^{(t)} \right)\]

with \(\mathbf{\Theta}_Q^{(t)}, \mathbf{\Theta}_K^{(t)}, \mathbf{\Theta}_V^{(t)}\) denoting (grouped) projection matrices for query, key and value information, respectively. \(h^{(t)}_{\mathbf{\Theta}}\) is implemented as a non-trainable version of torch_geometric.nn.conv.GCNConv.
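The attention step above can be sketched for a single edge \(v \leftarrow w\). This is a minimal numpy illustration, not the library's implementation: the (grouped) projections \(\mathbf{\Theta}_Q, \mathbf{\Theta}_K, \mathbf{\Theta}_V\) are taken as identity, and the function name is hypothetical. The query comes from the current representation of \(v\), while keys and values range over all previous representations of \(w\):

```python
import numpy as np

def dot_product_attention(q, K, V):
    """Single-query scaled dot-product attention.

    q: (channels,)        -- query, from x_v^{(t-1)}
    K, V: (t-1, channels) -- keys/values, from [x_w^{(1)}, ..., x_w^{(t-1)}]
    """
    scores = K @ q / np.sqrt(q.shape[0])      # one score per previous layer
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                  # softmax over layers 1..t-1
    return weights @ V                        # convex combination of values

# Toy setup: t - 1 = 3 previous layers, channels = 4.
rng = np.random.default_rng(0)
x_v = rng.normal(size=4)            # x_v^{(t-1)}
x_w_hist = rng.normal(size=(3, 4))  # stacked history of node w

msg = dot_product_attention(x_v, x_w_hist, x_w_hist)
print(msg.shape)  # message x_{v <- w}^{(t)} has `channels` entries
```

With `heads > 1` or `groups > 1`, the same computation is applied per head/group on channel slices; the attended messages are then aggregated by the non-trainable GCNConv-style propagation.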

Note

In contrast to other layers, this operator expects node features of shape [num_nodes, num_layers, channels].

Parameters:
  • channels (int) – Size of each input/output sample.

  • heads (int, optional) – Number of multi-head-attentions. (default: 1)

  • groups (int, optional) – Number of groups to use for all linear projections. (default: 1)

  • dropout (float, optional) – Dropout probability of attention coefficients. (default: 0.0)

  • cached (bool, optional) – If set to True, the layer will cache the computation of \(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2}\) on first execution, and will use the cached version for further executions. This parameter should only be set to True in transductive learning scenarios. (default: False)

  • normalize (bool, optional) – Whether to add self-loops and apply symmetric normalization. (default: True)

  • add_self_loops (bool, optional) – If set to False, will not add self-loops to the input graph. (default: True)

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • input: node features \((|\mathcal{V}|, L, F)\) where \(L\) is the number of layers, edge indices \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F)\)

forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]

Runs the forward pass of the module.

Parameters:
  • x (torch.Tensor) – The input node features of shape [num_nodes, num_layers, channels].

  • edge_index (torch.Tensor or SparseTensor) – The edge indices.

  • edge_weight (torch.Tensor, optional) – The edge weights. (default: None)

Return type:

Tensor

reset_parameters()[source]

Resets all learnable parameters of the module.