SVDPE

class dgl.transforms.SVDPE(k, feat_name='svd_pe', padding=False, random_flip=True)[source]

Bases: BaseTransform

基于SVD的位置编码,如用全局自注意力替代图卷积中介绍

此函数计算最大的\(k\)奇异值及相应的左右奇异向量,以形成位置编码,这些编码可以存储在ndata中。

Parameters:
  • k (int) – 用于位置编码的最大奇异值及其对应的奇异向量的数量。

  • feat_name (str, optional) – 用于存储在ndata中计算的位置编码的名称。 默认值:svd_pe

  • padding (bool, optional) – 如果为False,当\(k > N\)时,会引发错误, 其中\(N\)g中的节点数。 如果为True,当\(k > N\)时,在编码的末尾添加零填充。 默认值:False。

  • random_flip (bool, optional) – 如果为True,随机翻转编码向量的符号。 建议在训练期间激活以获得更好的泛化能力。 默认值:True。

示例

>>> import dgl
>>> from dgl import SVDPE
>>> transform = SVDPE(k=2, feat_name="svd_pe")
>>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
>>> g_ = transform(g)
>>> print(g_.ndata['svd_pe'])
tensor([[-6.3246e-01, -1.1373e-07, -6.3246e-01,  0.0000e+00],
        [-6.3246e-01,  7.6512e-01, -6.3246e-01, -7.6512e-01],
        [ 6.3246e-01,  4.7287e-01,  6.3246e-01, -4.7287e-01],
        [-6.3246e-01, -7.6512e-01, -6.3246e-01,  7.6512e-01],
        [ 6.3246e-01, -4.7287e-01,  6.3246e-01,  4.7287e-01]])