BiasedMHA
- class dgl.nn.pytorch.gt.BiasedMHA(feat_size, num_heads, bias=True, attn_bias_type='add', attn_drop=0.1)[source]
Bases:
Module
带有图注意力偏置的密集多头注意力模块。
计算节点之间的注意力,使用从图结构中获得的注意力偏差,如Do Transformers Really Perform Bad for Graph Representation?中介绍的那样。
\[\text{Attn}=\text{softmax}(\dfrac{QK^T}{\sqrt{d}} \circ b)\]\(Q\) 和 \(K\) 是节点的特征表示。\(d\) 是对应的
feat_size
。\(b\) 是注意力偏差,根据运算符 \(\circ\) 可以是加性或乘性的。- Parameters:
feat_size (int) – Feature size.
num_heads (int) – Number of attention heads, by which
feat_size
is divisible.bias (bool, optional) – 如果为True,则在线性投影中使用偏置。默认值:True。
attn_bias_type (str, optional) –
The type of attention bias used for modifying attention. Selected from ‘add’ or ‘mul’. Default: ‘add’.
’add’ is for additive attention bias.
’mul’ is for multiplicative attention bias.
attn_drop (float, optional) – 注意力权重的丢弃概率。默认值:0.1。
示例
>>> import torch as th >>> from dgl.nn import BiasedMHA
>>> ndata = th.rand(16, 100, 512) >>> bias = th.rand(16, 100, 100, 8) >>> net = BiasedMHA(feat_size=512, num_heads=8) >>> out = net(ndata, bias)
- forward(ndata, attn_bias=None, attn_mask=None)[source]
前向计算。
- Parameters:
ndata (torch.Tensor) – 一个3D输入张量。形状:(batch_size, N,
feat_size
),其中 N是节点的最大数量。attn_bias (torch.Tensor, optional) – The attention bias used for attention modification. Shape: (batch_size, N, N,
num_heads
).attn_mask (torch.Tensor, optional) – The attention mask used for avoiding computation on invalid positions, where invalid positions are indicated by True values. Shape: (batch_size, N, N). Note: For rows corresponding to unexisting nodes, make sure at least one entry is set to False to prevent obtaining NaNs with softmax.
- Returns:
y – The output tensor. Shape: (batch_size, N,
feat_size
)- Return type:
torch.Tensor