LapPosEncoder
- class dgl.nn.pytorch.gt.LapPosEncoder(model_type, num_layer, k, dim, n_head=1, batch_norm=False, num_post_layer=0)[source]
Bases:
Module
拉普拉斯位置编码器(LPE),如GraphGPS: General Powerful Scalable Graph Transformers中介绍
该模块是一个使用Transformer或DeepSet学习的拉普拉斯位置编码模块。
- Parameters:
model_type (str) – LPE的编码器模型类型,只能是“Transformer”或“DeepSet”。
num_layer (int) – Transformer/DeepSet 编码器中的层数。
k (int) – 最小的非平凡特征向量的数量。
dim (int) – 最终拉普拉斯编码的输出大小。
n_head (int, optional) – Transformer Encoder 中的头数。 默认值:1。
batch_norm (bool, optional) – 如果为True,则在原始拉普拉斯位置编码上应用批量归一化。默认值:False。
num_post_layer (int, optional) – 如果 num_post_layer > 0,在池化后应用一个具有
num_post_layer
层的 MLP。默认值:0。
示例
>>> import dgl >>> from dgl import LapPE >>> from dgl.nn import LapPosEncoder
>>> transform = LapPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True) >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4])) >>> g = transform(g) >>> eigvals, eigvecs = g.ndata['eigval'], g.ndata['eigvec'] >>> transformer_encoder = LapPosEncoder( model_type="Transformer", num_layer=3, k=5, dim=16, n_head=4 ) >>> pos_encoding = transformer_encoder(eigvals, eigvecs) >>> deepset_encoder = LapPosEncoder( model_type="DeepSet", num_layer=3, k=5, dim=16, num_post_layer=2 ) >>> pos_encoding = deepset_encoder(eigvals, eigvecs)