torch_geometric.nn.conv.GatedGraphConv
- class GatedGraphConv(out_channels: int, num_layers: int, aggr: str = 'add', bias: bool = True, **kwargs)[source]
Bases:
MessagePassing来自“Gated Graph Sequence Neural Networks”论文的门控图卷积操作符。
\[ \begin{align}\begin{aligned}\mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0}\\\mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)}\\\mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, \mathbf{h}_i^{(l)})\end{aligned}\end{align} \]直到表示 \(\mathbf{h}_i^{(L)}\)。 \(\mathbf{x}_i\) 的输入通道数需要小于或等于
out_channels。 \(e_{j,i}\) 表示从源节点j到目标节点i的边权重(默认:1)- Parameters:
out_channels (int) – Size of each output sample.
num_layers (int) – 序列长度 \(L\)。
aggr (str, 可选) – 使用的聚合方案 (
"add","mean","max"). (默认:"add")bias (bool, optional) – If set to
False, the layer will not learn an additive bias. (default:True)**kwargs (optional) – Additional arguments of
torch_geometric.nn.conv.MessagePassing.
- Shapes:
输入: 节点特征 \((|\mathcal{V}|, F_{in})\), 边索引 \((2, |\mathcal{E}|)\)
output: node features \((|\mathcal{V}|, F_{out})\)