空间编码器

class dgl.nn.pytorch.gt.SpatialEncoder(max_dist, num_heads=1)[source]

Bases: Module

空间编码器,如《Do Transformers Really Perform Bad for Graph Representation?》中介绍的那样。

该模块是一个可学习的空间嵌入模块,它为注意力偏差编码了每对节点之间的最短距离。

Parameters:
  • max_dist (int) – 每个节点对之间要编码的最短路径距离的上限。 所有距离将被限制在范围 [0, max_dist] 内。

  • num_heads (int, optional) – 如果应用了多头注意力机制,则为注意力头的数量。 默认值:1。

示例

>>> import torch as th
>>> import dgl
>>> from dgl.nn import SpatialEncoder
>>> from dgl import shortest_dist
>>> g1 = dgl.graph(([0,0,0,1,1,2,3,3], [1,2,3,0,3,0,0,1]))
>>> g2 = dgl.graph(([0,1], [1,0]))
>>> n1, n2 = g1.num_nodes(), g2.num_nodes()
>>> # use -1 padding since shortest_dist returns -1 for unreachable node pairs
>>> dist = -th.ones((2, 4, 4), dtype=th.long)
>>> dist[0, :n1, :n1] = shortest_dist(g1, root=None, return_paths=False)
>>> dist[1, :n2, :n2] = shortest_dist(g2, root=None, return_paths=False)
>>> spatial_encoder = SpatialEncoder(max_dist=2, num_heads=8)
>>> out = spatial_encoder(dist)
>>> print(out.shape)
torch.Size([2, 4, 4, 8])
forward(dist)[source]
Parameters:

dist (Tensor) – 带有-1填充的批处理图的最短路径距离,一个形状为 \((B, N, N)\) 的张量,其中 \(B\) 是批处理图的批量大小,\(N\) 是最大节点数。

Returns:

返回注意力偏差作为形状的空间编码 \((B, N, N, H)\),其中 \(H\)num_heads

Return type:

torch.Tensor