AtomicConv

class dgl.nn.pytorch.conv.AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling, features_to_use=None)[source]

Bases: Module

原子卷积层来自用于预测蛋白质-配体结合亲和力的原子卷积网络

表示原子\(i\)的类型为\(z_i\),原子\(i\)和原子\(j\)之间的距离为\(r_{ij}\)

距离变换

原子卷积层首先使用径向滤波器转换距离,然后执行池化操作。

对于由\(k\)索引的径向过滤器,它通过以下方式投影边缘距离

\[h_{ij}^{k} = \exp(-\gamma_{k}|r_{ij}-r_{k}|^2)\]

如果 \(r_{ij} < c_k\),

\[f_{ij}^{k} = 0.5 * \cos(\frac{\pi r_{ij}}{c_k} + 1),\]

否则,

\[f_{ij}^{k} = 0.\]

最后,

\[e_{ij}^{k} = h_{ij}^{k} * f_{ij}^{k}\]

聚合

对于每种类型 \(t\),每个原子收集来自所有类型 \(t\) 的邻居原子的距离信息:

\[p_{i, t}^{k} = \sum_{j\in N(i)} e_{ij}^{k} * 1(z_j == t)\]

然后连接所有RBF核和原子类型的结果。

Parameters:
  • interaction_cutoffs (float32 tensor of shape (K)) – \(c_k\) 在上述方程中。大致上,它们可以被视为可学习的截断值,如果两个原子之间的距离小于截断值,则它们被认为是连接的。K 表示径向滤波器的数量。

  • rbf_kernel_means (float32 tensor of shape (K)) – \(r_k\) 在上述方程中。K 表示径向滤波器的数量。

  • rbf_kernel_scaling (float32 tensor of shape (K)) – \(\gamma_k\) 在上述方程中。K 表示径向滤波器的数量。

  • features_to_use (Nonefloat tensorshape (T)) – 在原始论文中,这些是要考虑的原子序数,代表原子的类型。T 表示原子序数的类型数量。默认为 None。

注意

  • 这种卷积操作是为化学中的分子图设计的,但有可能将其扩展到更一般的图。

  • 关于论文中\(e_{ij}^{k}\)的定义与作者的实现似乎存在不一致。我们遵循作者的实现。在论文中,\(e_{ij}^{k}\)被定义为\(\exp(-\gamma_{k}|r_{ij}-r_{k}|^2 * f_{ij}^{k})\)

  • \(\gamma_{k}\), \(r_k\)\(c_k\) 都是可学习的。

示例

>>> import dgl
>>> import numpy as np
>>> import torch as th
>>> from dgl.nn import AtomicConv
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = th.ones(6, 1)
>>> edist = th.ones(6, 1)
>>> interaction_cutoffs = th.ones(3).float() * 2
>>> rbf_kernel_means = th.ones(3).float()
>>> rbf_kernel_scaling = th.ones(3).float()
>>> conv = AtomicConv(interaction_cutoffs, rbf_kernel_means, rbf_kernel_scaling)
>>> res = conv(g, feat, edist)
>>> res
tensor([[0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000],
            [0.5000, 0.5000, 0.5000],
            [1.0000, 1.0000, 1.0000],
            [0.5000, 0.5000, 0.5000],
            [0.0000, 0.0000, 0.0000]], grad_fn=<ViewBackward>)
forward(graph, feat, distances)[source]

Description

应用原子卷积层。

param graph:

基于执行消息传递的拓扑结构。

type graph:

DGLGraph

param feat:

初始节点特征,即论文中的原子数。 \(V\) 表示节点数量。

type feat:

形状为 \((V, 1)\) 的 Float32 张量

param distances:

边的末端节点之间的距离。E 表示边的数量。

type distances:

形状为 \((E, 1)\) 的 Float32 张量

returns:

更新节点表示。\(V\) 表示节点数量,\(K\) 表示径向滤波器的数量,\(T\) 表示原子序数的类型数量。

rtype:

形状为 \((V, K * T)\) 的 Float32 张量