torch_geometric.nn.models.SchNet

class SchNet(hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, interaction_graph: Optional[Callable] = None, max_num_neighbors: int = 32, readout: str = 'add', dipole: bool = False, mean: Optional[float] = None, std: Optional[float] = None, atomref: Optional[Tensor] = None)[source]

Bases: Module

连续滤波器卷积神经网络 SchNet 来自 “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions” 论文,该论文使用了 以下形式的交互块。

\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))),\]

这里 \(h_{\mathbf{\Theta}}\) 表示一个MLP,而 \(\mathbf{e}_{j,i}\) 表示原子之间的原子间距离。

注意

有关使用预训练SchNet变体的示例,请参见 examples/qm9_pretrained_schnet.py

Parameters:
  • hidden_channels (int, optional) – 隐藏嵌入大小。 (默认值: 128)

  • num_filters (int, optional) – 使用的过滤器数量。 (默认: 128)

  • num_interactions (int, optional) – 交互块的数量。 (默认: 6)

  • num_gaussians (int, optional) – 高斯分布的数量 \(\mu\)。 (default: 50)

  • interaction_graph (可调用的, 可选的) – 用于计算成对交互图和原子间距离的函数。如果设置为 None,将基于cutoffmax_num_neighbors属性构建图。 如果提供,此方法接收posbatch 张量,并应返回(edge_index, edge_weight)张量。 (默认 None)

  • cutoff (float, optional) – 原子间相互作用的截止距离。 (默认值: 10.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • readout (str, optional) – 是否应用 "add""mean" 全局聚合。(默认值:"add"

  • dipole (bool, 可选) – 如果设置为 True,将使用偶极矩的大小来进行最终预测,例如,对于 torch_geometric.datasets.QM9 的目标 0。 (默认: False)

  • mean (float, optional) – 要预测的属性的平均值。 (默认值: None)

  • std (float, optional) – 要预测的属性的标准差。(默认值:None

  • atomref (torch.Tensor, optional) – 单原子属性的参考。 期望一个形状为 (max_atomic_number, ) 的向量。

forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) Tensor[source]

前向传播。

Parameters:
  • z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].

  • pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].

  • batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)

Return type:

Tensor

reset_parameters()[source]

重置模块的所有可学习参数。

static from_qm9_pretrained(root: str, dataset: 数据集, target: int) Tuple[SchNet, 数据集, 数据集, 数据集][source]

返回一个在QM9数据集上预训练的SchNet模型,该模型在指定的目标target上进行训练。

Return type:

Tuple[SchNet, Dataset, Dataset, Dataset]