torch_geometric.nn.conv.GatedGraphConv

class GatedGraphConv(out_channels: int, num_layers: int, aggr: str = 'add', bias: bool = True, **kwargs)[source]

Bases: MessagePassing

来自“Gated Graph Sequence Neural Networks”论文的门控图卷积操作符。

\[ \begin{align}\begin{aligned}\mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0}\\\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)}\\\mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)})\end{aligned}\end{align} \]

直到表示 \(\mathbf{h}_i^{(L)}\)\(\mathbf{x}_i\) 的输入通道数需要小于或等于 out_channels\(e_{j,i}\) 表示从源节点 j 到目标节点 i 的边权重(默认:1

Parameters:
  • out_channels (int) – Size of each output sample.

  • num_layers (int) – 序列长度 \(L\)

  • aggr (str, 可选) – 使用的聚合方案 ("add", "mean", "max"). (默认: "add")

  • bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • 输入: 节点特征 \((|\mathcal{V}|, F_{in})\), 边索引 \((2, |\mathcal{E}|)\)

  • output: node features \((|\mathcal{V}|, F_{out})\)

forward(x: Tensor, edge_index: Union[Tensor, SparseTensor], edge_weight: Optional[Tensor] = None) Tensor[source]

运行模块的前向传播。

Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。