AtomicConv
- class dgl.nn.pytorch.conv.AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling, features_to_use=None)[source]
Bases:
Module
原子卷积层来自用于预测蛋白质-配体结合亲和力的原子卷积网络
表示原子\(i\)的类型为\(z_i\),原子\(i\)和原子\(j\)之间的距离为\(r_{ij}\)。
距离变换
原子卷积层首先使用径向滤波器转换距离,然后执行池化操作。
对于由\(k\)索引的径向过滤器,它通过以下方式投影边缘距离
\[h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)\]如果 \(r_{ij} < c_k\),
\[f_{ij}^{k} = 0.5 * \cos(\frac{\pi r_{ij}}{c_k} + 1),\]否则,
\[f_{ij}^{k} = 0.\]最后,
\[e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}\]聚合
对于每种类型 \(t\),每个原子收集来自所有类型 \(t\) 的邻居原子的距离信息:
\[p_{i, t}^{k} = \sum_{j\in N(i)} e_{ij}^{k} * 1(z_j == t)\]然后连接所有RBF核和原子类型的结果。
- Parameters:
interaction_cutoffs (float32 tensor of shape (K)) – \(c_k\) 在上述方程中。大致上,它们可以被视为可学习的截断值,如果两个原子之间的距离小于截断值,则它们被认为是连接的。K 表示径向滤波器的数量。
rbf_kernel_means (float32 tensor of shape (K)) – \(r_k\) 在上述方程中。K 表示径向滤波器的数量。
rbf_kernel_scaling (float32 tensor of shape (K)) – \(\gamma_k\) 在上述方程中。K 表示径向滤波器的数量。
features_to_use (None 或 float tensor 的 shape (T)) – 在原始论文中,这些是要考虑的原子序数,代表原子的类型。T 表示原子序数的类型数量。默认为 None。
注意
这种卷积操作是为化学中的分子图设计的,但有可能将其扩展到更一般的图。
关于论文中\(e_{ij}^{k}\)的定义与作者的实现似乎存在不一致。我们遵循作者的实现。在论文中,\(e_{ij}^{k}\)被定义为\(\exp(-\gamma_{k}|r_{ij}-r_{k}|^2 * f_{ij}^{k})\)。
\(\gamma_{k}\), \(r_k\) 和 \(c_k\) 都是可学习的。
示例
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import AtomicConv
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3])) >>> feat = th.ones(6, 1) >>> edist = th.ones(6, 1) >>> interaction_cutoffs = th.ones(3).float() * 2 >>> rbf_kernel_means = th.ones(3).float() >>> rbf_kernel_scaling = th.ones(3).float() >>> conv = AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling) >>> res = conv(g, feat, edist) >>> res tensor([[0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000], [0.5000, 0.5000, 0.5000], [1.0000, 1.0000, 1.0000], [0.5000, 0.5000, 0.5000], [0.0000, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
- forward(graph, feat, distances)[source]
Description
应用原子卷积层。
- param graph:
基于执行消息传递的拓扑结构。
- type graph:
DGLGraph
- param feat:
初始节点特征,即论文中的原子数。 \(V\) 表示节点数量。
- type feat:
形状为 \((V, 1)\) 的 Float32 张量
- param distances:
边的末端节点之间的距离。E 表示边的数量。
- type distances:
形状为 \((E, 1)\) 的 Float32 张量
- returns:
更新节点表示。\(V\) 表示节点数量,\(K\) 表示径向滤波器的数量,\(T\) 表示原子序数的类型数量。
- rtype:
形状为 \((V, K * T)\) 的 Float32 张量