torch_geometric.nn.models.DimeNetPlusPlus

class DimeNetPlusPlus(hidden_channels: int, out_channels: int, num_blocks: int, int_emb_size: int, basis_emb_size: int, out_emb_channels: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish', output_initializer: str = 'zeros')[source]

基础类:DimeNet

来自“快速且不确定性感知的非平衡分子定向消息传递”论文的DimeNet++。

DimeNetPlusPlus 是对 DimeNet 模型的升级,速度提高了8倍,准确率提高了10%。

Parameters:
  • hidden_channels (int) – Hidden embedding size.

  • out_channels (int) – Size of each output sample.

  • num_blocks (int) – Number of building blocks.

  • int_emb_size (int) – 交互块中嵌入的大小。

  • basis_emb_size (int) – 交互块中基础嵌入的大小。

  • out_emb_channels (int) – 输出块中嵌入的大小。

  • num_spherical (int) – Number of spherical harmonics.

  • num_radial (int) – Number of radial basis functions.

  • cutoff (float, 默认: 5.0) – (float, 可选): 原子间相互作用的截止距离。(默认: 5.0)

  • max_num_neighbors (int, optional) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • envelope_exponent (int, optional) – Shape of the smooth cutoff. (default: 5)

  • num_before_skip (int, 默认: 1) – (int, 可选): 在跳跃连接之前的交互块中的残差层数。(默认: 1)

  • num_after_skip (int, 默认: 2) – (int, 可选): 跳过连接后交互块中的残差层数。(默认: 2)

  • num_output_layers (int, 默认: 3) – (int, 可选): 输出块的线性层数量。(默认: 3)

  • act (Union[str, Callable], 默认: 'swish') – (str 或 Callable, 可选): 激活函数。 (默认: "swish")

  • output_initializer (str, optional) – The initialization method for the output layer ("zeros", "glorot_orthogonal"). (default: "zeros")

forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) Tensor

前向传播。

Parameters:
  • z (torch.Tensor) – Atomic number of each atom with shape [num_atoms].

  • pos (torch.Tensor) – Coordinates of each atom with shape [num_atoms, 3].

  • batch (torch.Tensor, optional) – Batch indices assigning each atom to a separate molecule with shape [num_atoms]. (default: None)

Return type:

Tensor

reset_parameters()

重置模块的所有可学习参数。

classmethod from_qm9_pretrained(root: str, dataset: 数据集, target: int) Tuple[DimeNetPlusPlus, 数据集, 数据集, 数据集][source]

返回一个在QM9数据集上预训练的DimeNetPlusPlus模型,该模型在指定的目标target上进行训练。

Return type:

Tuple[DimeNetPlusPlus, Dataset, Dataset, Dataset]