torch_geometric.nn.conv.CGConv
- class CGConv(channels: Union[int, Tuple[int, int]], dim: int = 0, aggr: str = 'add', batch_norm: bool = False, bias: bool = True, **kwargs)[source]
Bases:
MessagePassing来自“晶体图卷积神经网络用于材料特性的准确和可解释预测”论文的晶体图卷积算子。
\[\mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \sigma \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right) \odot g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right)\]其中 \(\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j} ]\) 表示中心节点特征、相邻节点特征和边特征的连接。 此外,\(\sigma\) 和 \(g\) 分别表示 sigmoid 和 softplus 函数。
- Parameters:
dim (int, optional) – 边缘特征的维度。(默认值:
0)aggr (str, optional) – 使用的聚合运算符 (
"add","mean","max"). (默认:"add")**kwargs (optional) – Additional arguments of
torch_geometric.nn.conv.MessagePassing.
- Shapes:
输入: 节点特征 \((|\mathcal{V}|, F)\) 或 \(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) 如果是二分图, 边索引 \((2, |\mathcal{E}|)\), 边特征 \((|\mathcal{E}|, D)\) (可选)
输出: 节点特征 \((|\mathcal{V}|, F)\) 或 \((|\mathcal{V_t}|, F_{t})\) 如果是二分图