torch_geometric.nn.models.DimeNet
- class DimeNet(hidden_channels: int, out_channels: int, num_blocks: int, num_bilinear: int, num_spherical: int, num_radial: int, cutoff: float = 5.0, max_num_neighbors: int = 32, envelope_exponent: int = 5, num_before_skip: int = 1, num_after_skip: int = 2, num_output_layers: int = 3, act: Union[str, Callable] = 'swish', output_initializer: str = 'zeros')[source]
Bases:
Module方向性消息传递神经网络(DimeNet)来自 “分子图的方向性消息传递”论文。 DimeNet以旋转等变的方式根据消息之间的角度转换消息。
注意
有关使用预训练DimeNet变体的示例,请参见 examples/qm9_pretrained_dimenet.py。
- Parameters:
hidden_channels (int) – 隐藏嵌入大小。
out_channels (int) – Size of each output sample.
num_blocks (int) – 构建块的数量。
num_bilinear (int) – 双线性层张量的大小。
num_spherical (int) – 球谐函数的数量。
num_radial (int) – 径向基函数的数量。
cutoff (float, optional) – 原子间相互作用的截止距离。(默认值:
5.0)max_num_neighbors (int, 可选) – 在
cutoff距离内为每个节点收集的最大邻居数。 (默认:32)envelope_exponent (int, optional) – 平滑截止的形状。 (默认:
5)num_before_skip (int, optional) – 在跳过连接之前的交互块中的残差层数。(默认值:
1)num_after_skip (int, optional) – 跳过连接后交互块中的残差层数。(默认值:
2)num_output_layers (int, optional) – 输出块的线性层数量。(默认:
3)act (str 或 Callable, 可选) – 激活函数。 (默认:
"swish")output_initializer (str, 可选) – 输出层的初始化方法 (
"zeros","glorot_orthogonal"). (默认:"zeros")
- forward(z: Tensor, pos: Tensor, batch: Optional[Tensor] = None) Tensor[source]
前向传播。
- Parameters:
z (torch.Tensor) – 每个原子的原子序数,形状为
[num_atoms]。pos (torch.Tensor) – 每个原子的坐标,形状为
[num_atoms, 3]。batch (torch.Tensor, optional) – 将每个原子分配到单独分子的批次索引,形状为
[num_atoms]。 (默认:None)
- Return type: