torch_geometric.nn.conv.CGConv

class CGConv(channels: Union[int, Tuple[int, int]], dim: int = 0, aggr: str = 'add', batch_norm: bool = False, bias: bool = True, **kwargs)[source]

Bases: MessagePassing

来自“晶体图卷积神经网络用于材料特性的准确和可解释预测”论文的晶体图卷积算子。

\[\mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \sigma \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right) \odot g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right)\]

其中 \(\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, \mathbf{e}_{i,j} ]\) 表示中心节点特征、相邻节点特征和边特征的连接。 此外,\(\sigma\)\(g\) 分别表示 sigmoid 和 softplus 函数。

Parameters:
  • channels (inttuple) – 每个输入样本的大小。一个元组对应于源和目标维度的大小。

  • dim (int, optional) – 边缘特征的维度。(默认值:0

  • aggr (str, optional) – 使用的聚合运算符 ("add", "mean", "max"). (默认: "add")

  • batch_norm (bool, 可选) – 如果设置为 True,将使用批量归一化。(默认值:False

  • bias (bool, 可选) – 如果设置为 False,该层将不会学习 一个加性偏差。(默认值:True

  • **kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.

Shapes:
  • 输入: 节点特征 \((|\mathcal{V}|, F)\)\(((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))\) 如果是二分图, 边索引 \((2, |\mathcal{E}|)\), 边特征 \((|\mathcal{E}|, D)\) (可选)

  • 输出: 节点特征 \((|\mathcal{V}|, F)\)\((|\mathcal{V_t}|, F_{t})\) 如果是二分图

forward(x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, SparseTensor], edge_attr: Optional[Tensor] = None) Tensor[source]

运行模块的前向传播。

Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。