EdgeGATConv
- class dgl.nn.pytorch.conv.EdgeGATConv(in_feats, edge_feats, out_feats, num_heads, feat_drop=0.0, attn_drop=0.0, negative_slope=0.2, residual=True, activation=None, allow_zero_in_degree=False, bias=True)[source]
Bases:
Module
带有边缘特征的图注意力层来自 SCENE
\[\mathbf{v}_i^\prime = \mathbf{\Theta}_\mathrm{s} \cdot \mathbf{v}_i + \sum\limits_{j \in \mathcal{N}(v_i)} \alpha_{j, i} \left( \mathbf{\Theta}_\mathrm{n} \cdot \mathbf{v}_j + \mathbf{\Theta}_\mathrm{e} \cdot \mathbf{e}_{j,i} \right)\]其中 \(\mathbf{\Theta}\) 用于表示可学习的权重矩阵,用于将节点的特征转换为更新(s=自身)、邻近节点(n=邻居)和边特征(e=边)。注意力权重通过以下方式获得:
\[\alpha_{j, i} = \mathrm{softmax}_i \Big( \mathrm{LeakyReLU} \big( \mathbf{a}^T [ \mathbf{\Theta}_\mathrm{n} \cdot \mathbf{v}_i || \mathbf{\Theta}_\mathrm{n} \cdot \mathbf{v}_j || \mathbf{\Theta}_\mathrm{e} \cdot \mathbf{e}_{j,i} ] \big) \Big)\]使用\(\mathbf{a}\)对应一个可学习的向量。 \(\mathrm{softmax_i}\)表示通过节点\(i\)的所有传入边进行归一化。
- Parameters:
in_feats (int, 或 一对 整数) – 输入特征大小;即 \(\mathbf{v}_i\) 的维度数。 GATConv 可以应用于同构图和单向 二分图。 如果该层要应用于单向二分图,
in_feats
指定源节点和目标节点的输入特征大小。如果 给定一个标量,源节点和目标节点的特征大小将取相同的值。edge_feats (int) – 边特征大小;即 :math:mathbf{e}_{j,i}` 的维度数。
out_feats (int) – 输出特征大小;即\(\mathbf{v}_i^\prime\)的维度数。
num_heads (int) – 多头注意力机制中的头数。
feat_drop (float, 可选) – 特征的丢弃率。默认值:
0
。attn_drop (float, optional) – 注意力权重的丢弃率。默认值:
0
。negative_slope (float, optional) – LeakyReLU 负斜率的角度。默认值:
0.2
。残差 (bool, 可选) – 如果为True,则使用残差连接。默认值:
False
。activation (可调用的激活函数/层 或 None, 可选.) – 如果不为None,则对更新的节点特征应用激活函数。 默认值:
None
.allow_zero_in_degree (bool, optional) – 如果图中存在0入度的节点,这些节点的输出将无效,因为没有消息会传递给这些节点。这对某些应用程序是有害的,可能导致无声的性能退化。如果检测到输入图中存在0入度的节点,此模块将引发DGLError。通过设置为
True
,它将抑制检查并让用户自行处理。默认值:False
。偏差 (bool, 可选) – 如果为True,则学习一个偏差项。默认值:
True
。
注意
零入度节点将导致无效的输出值。这是因为没有消息会传递到这些节点,聚合函数将在空输入上应用。避免这种情况的常见做法是如果图是同质的,则为每个节点添加自环,这可以通过以下方式实现:
>>> g = ... # a DGLGraph >>> g = dgl.add_self_loop(g)
调用
add_self_loop
在某些图中可能不起作用,例如异构图,因为无法为自环边决定边类型。在这些情况下,将allow_zero_in_degree
设置为True
以解除代码阻塞并手动处理零入度节点。处理此问题的常见做法是在使用卷积后过滤掉零入度的节点。示例
>>> import dgl >>> import numpy as np >>> import torch as th >>> from dgl.nn import EdgeGATConv
>>> # Case 1: Homogeneous graph. >>> num_nodes, num_edges = 8, 30 >>> # Generate a graph. >>> graph = dgl.rand_graph(num_nodes,num_edges) >>> node_feats = th.rand((num_nodes, 20)) >>> edge_feats = th.rand((num_edges, 12)) >>> edge_gat = EdgeGATConv( ... in_feats=20, ... edge_feats=12, ... out_feats=15, ... num_heads=3, ... ) >>> # Forward pass. >>> new_node_feats = edge_gat(graph, node_feats, edge_feats) >>> new_node_feats.shape torch.Size([8, 3, 15]) torch.Size([30, 3, 10])
>>> # Case 2: Unidirectional bipartite graph. >>> u = [0, 1, 0, 0, 1] >>> v = [0, 1, 2, 3, 2] >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)}) >>> u_feat = th.tensor(np.random.rand(2, 25).astype(np.float32)) >>> v_feat = th.tensor(np.random.rand(4, 30).astype(np.float32)) >>> nfeats = (u_feat,v_feat) >>> efeats = th.tensor(np.random.rand(5, 15).astype(np.float32)) >>> in_feats = (25,30) >>> edge_feats = 15 >>> out_feats = 10 >>> num_heads = 3 >>> egat_model = EdgeGATConv( ... in_feats, ... edge_feats, ... out_feats, ... num_heads, ... ) >>> # Forward pass. >>> new_node_feats, attention_weights = egat_model(g, nfeats, efeats, get_attention=True) >>> new_node_feats.shape, attention_weights.shape (torch.Size([4, 3, 10]), torch.Size([5, 3, 1]))
- forward(graph, feat, edge_feat, get_attention=False)[source]
Description
计算图注意力网络层。
- param graph:
图表。
- type graph:
DGLGraph
- param feat:
如果给定一个torch.Tensor,输入特征的形状为\((N, *, D_{in})\),其中 \(D_{in}\)是输入特征的大小,\(N\)是节点的数量。 如果给定一对torch.Tensor,这对张量必须包含两个形状为 \((N_{in}, *, D_{in_{src}})\)和\((N_{out}, *, D_{in_{dst}})\)的张量。
- type feat:
torch.Tensor 或一对 torch.Tensor
- param edge_feat:
输入边的特征形状为 \((E, D_{in_{edge}})\), 其中 \(E\) 是边的数量,\(D_{in_{edge}}\) 是边特征的大小。
- type edge_feat:
torch.Tensor
- param get_attention:
是否返回注意力值。默认为False。
- type get_attention:
布尔值,可选
- returns:
torch.Tensor – 输出特征的形状为 \((N, *, H, D_{out})\),其中 \(H\) 是头的数量,\(D_{out}\) 是输出特征的大小。
torch.Tensor, 可选 – 形状为 \((E, *, H, 1)\) 的注意力值。仅当
get_attention
为True
时返回。
- raises DGLError:
如果输入图中存在0入度的节点,将会引发DGLError,因为没有消息会传递给这些节点。这将导致无效的输出。可以通过将
allow_zero_in_degree
参数设置为True
来忽略此错误。