空间编码器3D
- class dgl.nn.pytorch.gt.SpatialEncoder3d(num_kernels, num_heads=1, max_node_type=100)[source]
Bases:
Module
3D 空间编码器,如一个Transformer可以理解2D和3D分子数据中介绍的那样
该模块根据高斯基础核函数对三维几何空间中节点对 \((i,j)\) 之间的成对关系进行编码:
\(\psi _{(i,j)} ^k = \frac{1}{\sqrt{2\pi} \lvert \sigma^k \rvert} \exp{\left ( -\frac{1}{2} \left( \frac{\gamma_{(i,j)} \lvert \lvert r_i - r_j \rvert \rvert + \beta_{(i,j)} - \mu^k}{\lvert \sigma^k \rvert} \right) ^2 \right)},k=1,...,K,\)
其中 \(K\) 是高斯基核的数量。\(r_i\) 是节点 \(i\) 的笛卡尔坐标。 \(\gamma_{(i,j)}, \beta_{(i,j)}\) 是由节点类型决定的可学习的缩放因子和偏置。\(\mu^k, \sigma^k\) 是高斯基核的可学习中心和标准差。
- Parameters:
示例
>>> import torch as th >>> import dgl >>> from dgl.nn import SpatialEncoder3d
>>> coordinate = th.rand(1, 4, 3) >>> node_type = th.tensor([[1, 0, 2, 1]]) >>> spatial_encoder = SpatialEncoder3d(num_kernels=4, ... num_heads=8, ... max_node_type=3) >>> out = spatial_encoder(coordinate, node_type=node_type) >>> print(out.shape) torch.Size([1, 4, 4, 8])
- forward(coord, node_type=None)[source]
- Parameters:
coord (torch.Tensor) – 节点的3D坐标,形状为 \((B, N, 3)\),其中 \(B\) 是批量大小,\(N\) 是节点的最大数量。
node_type (torch.Tensor, optional) –
节点的节点类型ID。默认值:无。
如果指定,
node_type
应该是一个形状为 \((B, N,)\) 的张量。每对节点的高斯核中的缩放因子由它们的节点类型决定。否则,
node_type
将默认设置为相同形状的零张量。
- Returns:
返回注意力偏差作为形状的3D空间编码 \((B, N, N, H)\),其中 \(H\) 是
num_heads
。- Return type:
torch.Tensor