torch_geometric.nn.models.DimeNetPlusPlus
- class DimeNetPlusPlus(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish', output_initializer: str = 'zeros')[source]
基础类:
DimeNet来自“快速且不确定性感知的非平衡分子定向消息传递”论文的DimeNet++。
DimeNetPlusPlus是对DimeNet模型的升级,速度提高了8倍,准确率提高了10%。- Parameters:
hidden_channels (int) – Hidden embedding size.
out_channels (int) – Size of each output sample.
num_blocks (int) – Number of building blocks.
int_emb_size (int) – 交互块中嵌入的大小。
basis_emb_size (int) – 交互块中基础嵌入的大小。
out_emb_channels (int) – 输出块中嵌入的大小。
num_spherical (int) – Number of spherical harmonics.
num_radial (int) – Number of radial basis functions.
cutoff (
float, 默认:5.0) – (float, 可选): 原子间相互作用的截止距离。(默认:5.0)max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the
cutoffdistance. (default:32)envelope_exponent (int, optional) – Shape of the smooth cutoff. (default:
5)num_before_skip (
int, 默认:1) – (int, 可选): 在跳跃连接之前的交互块中的残差层数。(默认:1)num_after_skip (
int, 默认:2) – (int, 可选): 跳过连接后交互块中的残差层数。(默认:2)num_output_layers (
int, 默认:3) – (int, 可选): 输出块的线性层数量。(默认:3)act (
Union[str,Callable], 默认:'swish') – (str 或 Callable, 可选): 激活函数。 (默认:"swish")output_initializer (str, optional) – The initialization method for the output layer (
"zeros","glorot_orthogonal"). (default:"zeros")
- forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) Tensor
前向传播。
- Parameters:
z (torch.Tensor) – Atomic number of each atom with shape
[num_atoms].pos (torch.Tensor) – Coordinates of each atom with shape
[num_atoms, 3].batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape
[num_atoms]. (default:None)
- Return type:
- reset_parameters()
重置模块的所有可学习参数。