EGTLayer

class dgl.nn.pytorch.gt.EGTLayer(feat_size, edge_feat_size, num_heads, num_virtual_nodes, dropout=0, attn_dropout=0, activation=ELU(alpha=1.0), edge_update=True)[source]

Bases: Module

EGTLayer implements the Edge-augmented Graph Transformer (EGT), as introduced in `Global Self-Attention as a Replacement for Graph Convolution`.

Parameters:
  • feat_size (int) – Node feature size.

  • edge_feat_size (int) – Edge feature size.

  • num_heads (int) – Number of attention heads; :attr:`feat_size` must be divisible by it.

  • num_virtual_nodes (int) – Number of virtual nodes.

  • dropout (float, optional) – Dropout probability. Default: 0.0.

  • attn_dropout (float, optional) – Dropout probability of the attention mechanism. Default: 0.0.

  • activation (callable activation layer, optional) – Activation function. Default: nn.ELU().

  • edge_update (bool, optional) – Whether to update the edge embeddings. Default: True.

Examples

>>> import torch as th
>>> from dgl.nn import EGTLayer
>>> batch_size = 16
>>> num_nodes = 100
>>> feat_size, edge_feat_size = 128, 32
>>> nfeat = th.rand(batch_size, num_nodes, feat_size)
>>> efeat = th.rand(batch_size, num_nodes, num_nodes, edge_feat_size)
>>> net = EGTLayer(
...     feat_size=feat_size,
...     edge_feat_size=edge_feat_size,
...     num_heads=8,
...     num_virtual_nodes=4,
... )
>>> out = net(nfeat, efeat)
forward(nfeat, efeat, mask=None)[source]

Forward computation. Note: If num_virtual_nodes > 0, nfeat and efeat should be padded with the embeddings of the virtual nodes, while mask should be padded with 0 values at the virtual-node positions. The padding should be placed at the beginning.

Parameters:
  • nfeat (torch.Tensor) – A 3D input tensor. Shape: (batch_size, N, feat_size), where N is the sum of the maximum number of nodes and the number of virtual nodes.

  • efeat (torch.Tensor) – Edge embeddings used in attention computation and self-update. Shape: (batch_size, N, N, edge_feat_size).

  • mask (torch.Tensor, optional) – Attention mask for avoiding computation at invalid positions, where valid positions are indicated by 0 and invalid positions by -inf. Shape: (batch_size, N, N). Default: None.

Returns:

  • nfeat (torch.Tensor) – The output node embeddings. Shape: (batch_size, N, feat_size).

  • efeat (torch.Tensor, optional) – The output edge embeddings. Shape: (batch_size, N, N, edge_feat_size). It is returned only if edge_update is True.
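As a minimal sketch of the mask convention described above: virtual-node and real-node positions are marked valid with 0, and padding positions (graphs smaller than the batch maximum) are marked invalid with -inf. The batch size, node counts, and variable names below are made-up illustration values, not part of the EGTLayer API:

```python
import torch as th

batch_size, max_nodes, num_virtual_nodes = 2, 5, 4
N = max_nodes + num_virtual_nodes  # total positions per graph

# Hypothetical per-graph real-node counts for this batch.
num_nodes_per_graph = [5, 3]

# Start with every position masked out (-inf), then unmask valid ones (0).
# Virtual nodes are padded at the beginning, so for each graph the first
# (num_virtual_nodes + n) positions are valid.
mask = th.full((batch_size, N, N), float("-inf"))
for i, n in enumerate(num_nodes_per_graph):
    valid = num_virtual_nodes + n
    mask[i, :valid, :valid] = 0.0
```

The resulting tensor has shape (batch_size, N, N) and can be passed as the mask argument of forward().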