torch_geometric.nn.models.SchNet
- class SchNet(hidden_channels: int = 128, num_filters: int = 128, num_interactions: int = 6, num_gaussians: int = 50, cutoff: float = 10.0, interaction_graph: Optional[Callable] = None, max_num_neighbors: int = 32, readout: str = 'add', dipole: bool = False, mean: Optional[float] = None, std: Optional[float] = None, atomref: Optional[Tensor] = None)[source]
Bases:
Module连续滤波器卷积神经网络 SchNet 来自 “SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions” 论文,该论文使用了 以下形式的交互块。
\[\mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \odot h_{\mathbf{\Theta}} ( \exp(-\gamma(\mathbf{e}_{j,i} - \mathbf{\mu}))),\]这里 \(h_{\mathbf{\Theta}}\) 表示一个MLP,而 \(\mathbf{e}_{j,i}\) 表示原子之间的原子间距离。
注意
有关使用预训练SchNet变体的示例,请参见 examples/qm9_pretrained_schnet.py。
- Parameters:
hidden_channels (int, optional) – 隐藏嵌入大小。 (默认值:
128)num_filters (int, optional) – 使用的过滤器数量。 (默认:
128)num_interactions (int, optional) – 交互块的数量。 (默认:
6)num_gaussians (int, optional) – 高斯分布的数量 \(\mu\)。 (default:
50)interaction_graph (可调用的, 可选的) – 用于计算成对交互图和原子间距离的函数。如果设置为
None,将基于cutoff和max_num_neighbors属性构建图。 如果提供,此方法接收pos和batch张量,并应返回(edge_index, edge_weight)张量。 (默认None)cutoff (float, optional) – 原子间相互作用的截止距离。 (默认值:
10.0)max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the
cutoffdistance. (default:32)readout (str, optional) – 是否应用
"add"或"mean"全局聚合。(默认值:"add")dipole (bool, 可选) – 如果设置为
True,将使用偶极矩的大小来进行最终预测,例如,对于torch_geometric.datasets.QM9的目标 0。 (默认:False)atomref (torch.Tensor, optional) – 单原子属性的参考。 期望一个形状为
(max_atomic_number, )的向量。
- forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) Tensor[source]
前向传播。
- Parameters:
z (torch.Tensor) – Atomic number of each atom with shape
[num_atoms].pos (torch.Tensor) – Coordinates of each atom with shape
[num_atoms, 3].batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape
[num_atoms]. (default:None)
- Return type: