torch_geometric.nn.models.GroupAddRev
- class GroupAddRev(conv: Union[Module, ModuleList], split_dim: int = -1, num_groups: Optional[int] = None, disable: bool = False, num_bwd_passes: int = 1)[source]
基础类:
InvertibleModule来自“Graph Neural Networks with 1000 Layers”论文的分组可逆GNN模块。该模块使得训练任意深度的GNN成为可能,且内存复杂度与层数无关。
它通过将输入节点特征 \(\mathbf{X}\) 在特征维度上划分为 \(C\) 组来实现。然后,一个分组的可逆 GNN 块 \(f_{\theta(i)}\) 对一组输入进行操作并生成一组输出:
\[ \begin{align}\begin{aligned}\mathbf{X}^{\prime}_0 &= \sum_{i=2}^C \mathbf{X}_i\\\mathbf{X}^{\prime}_i &= f_{\theta(i)} ( \mathbf{X}^{\prime}_{i - 1}, \mathbf{A}) + \mathbf{X}_i\end{aligned}\end{align} \]对于所有 \(i \in \{ 1, \ldots, C \}\)。
注意
有关使用
GroupAddRev的示例,请参见 examples/rev_gnn.py。- Parameters:
conv (torch.nn.Module 或 torch.nn.ModuleList]) – 一个种子GNN。输入和输出的特征维度需要匹配。
split_dim (int, optional) – 用于分割组的维度。 (default:
-1)disable (bool, 可选) – 如果设置为
True,将禁用InvertibleFunction的使用,并且将在没有内存节省的情况下执行模块。(默认值:False)num_bwd_passes (int, optional) – 保留与输出链接的反向传递次数。在最后一次反向传递后,输出将被丢弃并释放内存。(默认值:
1)
- forward(*args)