HeteroGNNExplainer

class dgl.nn.pytorch.explain.HeteroGNNExplainer(model, num_hops, lr=0.01, num_epochs=100, *, alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, log=True)[source]

Bases: Module

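GNNExplainer model from GNNExplainer: Generating Explanations for Graph Neural Networks, adapted for heterogeneous graphs.

It identifies compact subgraph structures and small subsets of node features that play a crucial role in GNN-based node classification and graph classification.

To generate an explanation, it learns an edge mask \(M\) and a feature mask \(F\) by optimizing the following objective function.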

\[l(y, \hat{y}) + \alpha_1 \|M\|_1 + \alpha_2 H(M) + \beta_1 \|F\|_1 + \beta_2 H(F)\]

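where \(l\) is the loss function, \(y\) is the original model prediction, \(\hat{y}\) is the model prediction with the edge and feature masks applied, and \(H\) is the entropy function.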
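To make the regularization terms of the objective concrete, here is a minimal pure-Python sketch of \(\alpha_1 \|M\|_1 + \alpha_2 H(M)\) and \(\beta_1 \|F\|_1 + \beta_2 H(F)\). The helper names `binary_entropy` and `mask_penalty` are hypothetical, and averaging the element-wise entropy over the mask entries is an assumption made for illustration; the library may reduce these terms differently.

```python
import math

def binary_entropy(p, eps=1e-15):
    """Element-wise entropy of a mask value p in (0, 1)."""
    return -p * math.log(p + eps) - (1.0 - p) * math.log(1.0 - p + eps)

def mask_penalty(mask, l1_coef, entropy_coef):
    """l1_coef * ||mask||_1 + entropy_coef * H(mask).

    Assumption: H(mask) is the mean of the element-wise entropies.
    """
    l1 = sum(abs(m) for m in mask)
    entropy = sum(binary_entropy(m) for m in mask) / len(mask)
    return l1_coef * l1 + entropy_coef * entropy

# Toy mask values after the sigmoid, so every entry lies in (0, 1).
edge_mask = [0.9, 0.1, 0.8, 0.2]
feat_mask = [0.3, 0.3, 0.3, 0.3, 0.3]

# Coefficients follow the constructor defaults:
# alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1.
regulariser = mask_penalty(edge_mask, 0.005, 1.0) + mask_penalty(feat_mask, 1.0, 0.1)
```

The entropy term is smallest when mask values sit near 0 or 1, so raising `alpha2` or `beta2` pushes the masks toward near-binary decisions, while the L1 term shrinks their overall magnitude.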

Parameters:
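  • model (nn.Module) –

    The GNN model to explain.

    • The required arguments of its forward function are graph and feat. The latter is for input node features.

    • It should also optionally take an eweight argument for edge weights and multiply the messages by it in message passing.

    • The output of its forward function is the logits for the predicted node/graph classes.

    See also the examples in explain_node() and explain_graph().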

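  • num_hops (int) – The number of hops for GNN information aggregation.

  • lr (float, optional) – The learning rate to use. Defaults to 0.01.

  • num_epochs (int, optional) – The number of epochs to train. Defaults to 100.

  • alpha1 (float, optional) – A higher value makes the explanation edge masks sparser by decreasing the sum of the edge mask. Defaults to 0.005.

  • alpha2 (float, optional) – A higher value makes the explanation edge masks sparser by decreasing the entropy of the edge mask. Defaults to 1.0.

  • beta1 (float, optional) – A higher value makes the explanation node feature masks sparser by decreasing the mean of the node feature mask. Defaults to 1.0.

  • beta2 (float, optional) – A higher value makes the explanation node feature masks sparser by decreasing the entropy of the node feature mask. Defaults to 0.1.

  • log (bool, optional) – If True, it will log the computation process. Defaults to True.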
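Because the explainer relies on the model accepting an optional eweight argument, it can help to check the forward signature before constructing the explainer. The following is a minimal sketch using only the standard library; `accepts_optional_eweight` is a hypothetical helper and not part of DGL, and the two stub functions stand in for real `forward` methods.

```python
import inspect

def accepts_optional_eweight(forward_fn):
    """Return True if forward_fn has an `eweight` parameter with a default.

    A default value means the function can be called both as
    forward(graph, feat) and as forward(graph, feat, eweight=...).
    """
    param = inspect.signature(forward_fn).parameters.get("eweight")
    return param is not None and param.default is not inspect.Parameter.empty

# Stub signatures for illustration only; the bodies are intentionally empty.
def good_forward(graph, feat, eweight=None):
    pass

def bad_forward(graph, feat):
    pass

checks = {
    "good_forward": accepts_optional_eweight(good_forward),
    "bad_forward": accepts_optional_eweight(bad_forward),
}
```

With a real model, the same helper can be applied to the bound method, for example `accepts_optional_eweight(model.forward)`. The check only inspects the signature; it cannot confirm that the model actually multiplies messages by the edge weights.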

explain_graph(graph, feat, **kwargs)[source]

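Learn and return node feature masks and edge masks that play a crucial role in explaining the prediction made by the GNN for a graph.

Parameters:
  • graph (DGLGraph) – A heterogeneous graph to be explained.

  • feat (dict[str, Tensor]) – The dictionary that associates input node features (values) with the respective node types (keys) present in the graph. The input features are of shape \((N_t, D_t)\). \(N_t\) is the number of nodes for node type \(t\), and \(D_t\) is the feature size for node type \(t\).

  • kwargs (dict) – Additional arguments passed to the GNN model.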

Returns:

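  • feat_mask (dict[str, Tensor]) – The dictionary that associates the learned node feature importance masks (values) with the respective node types (keys). The masks are of shape \((D_t)\), where \(D_t\) is the node feature size for node type \(t\). The values are within range \((0, 1)\). The higher, the more important.

  • edge_mask (dict[Tuple[str], Tensor]) – The dictionary that associates the learned edge importance masks (values) with the respective canonical edge types (keys). The masks are of shape \((E_t)\), where \(E_t\) is the number of edges for canonical edge type \(t\) in the graph. The values are within range \((0, 1)\). The higher, the more important.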
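One common way to read the returned masks is to rank the entries of each mask and keep the most important ones. The sketch below does this in plain Python for the edge mask; `top_k_indices` is a hypothetical helper, and the scores are copied from the example output on this page as plain lists. With real tensors, converting each mask with `.tolist()` first would give the same input shape.

```python
def top_k_indices(values, k):
    """Indices of the k largest entries, most important first."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    return order[:k]

# Edge importance scores keyed by canonical edge type.
edge_mask = {
    ("game", "rev_plays", "user"): [0.8922, 0.1966, 0.8371, 0.1330],
    ("user", "plays", "game"): [0.1785, 0.1696, 0.8065, 0.2167],
}

# Keep the two highest-scoring edge IDs for every canonical edge type.
top_edges = {etype: top_k_indices(scores, 2) for etype, scores in edge_mask.items()}
# top_edges[("game", "rev_plays", "user")] -> [0, 2]
# top_edges[("user", "plays", "game")]     -> [2, 3]
```

The same helper applies to feat_mask, where the indices refer to feature dimensions rather than edge IDs.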

Examples

>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroGNNExplainer
>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict({
...             '_'.join(c_etype): nn.Linear(in_dim, num_classes)
...             for c_etype in canonical_etypes
...         })
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights['_'.join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f'h_{c_etype}'] = wh
...                 if eweight is None:
...                     c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
...                         fn.mean('m', 'h'))
...                 else:
...                     graph.edges[c_etype].data['w'] = eweight[c_etype]
...                     c_etype_func_dict[c_etype] = (
...                         fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))
...             graph.multi_update_all(c_etype_func_dict, 'sum')
...             hg = 0
...             for ntype in graph.ntypes:
...                 if graph.num_nodes(ntype):
...                     hg = hg + dgl.mean_nodes(graph, 'h', ntype=ntype)
...             return hg
>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)
>>> g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata['h']
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)
...     loss = F.cross_entropy(logits, th.tensor([1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain for the graph
>>> explainer = HeteroGNNExplainer(model, num_hops=1)
>>> feat_mask, edge_mask = explainer.explain_graph(g, feat)
>>> feat_mask
{'game': tensor([0.2684, 0.2597, 0.3135, 0.2976, 0.2607]),
 'user': tensor([0.2216, 0.2908, 0.2644, 0.2738, 0.2663])}
>>> edge_mask
{('game', 'rev_plays', 'user'): tensor([0.8922, 0.1966, 0.8371, 0.1330]),
 ('user', 'plays', 'game'): tensor([0.1785, 0.1696, 0.8065, 0.2167])}

explain_node(ntype, node_id, graph, feat, **kwargs)[source]

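Learn and return node feature masks and a subgraph that play a crucial role in explaining the prediction made by the GNN for node node_id of type ntype.

It requires model to return a dictionary mapping node types to type-specific predictions.

Parameters:
  • ntype (str) – The type of the node to explain. model must be trained to make predictions for this particular node type.

  • node_id (int) – The ID of the node to explain.

  • graph (DGLGraph) – A heterogeneous graph.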

  • feat (dict[str, Tensor]) – The dictionary that associates input node features (values) with the respective node types (keys) present in the graph. The input features are of shape \((N_t, D_t)\). \(N_t\) is the number of nodes for node type \(t\), and \(D_t\) is the feature size for node type \(t\).

  • kwargs (dict) – Additional arguments passed to the GNN model.

Returns:

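  • new_node_id (Tensor) – The new ID of the input center node.

  • sg (DGLGraph) – The subgraph induced on the k-hop in-neighborhood of the input center node.

  • feat_mask (dict[str, Tensor]) – The dictionary that associates the learned node feature importance masks (values) with the respective node types (keys). The masks are of shape \((D_t)\), where \(D_t\) is the node feature size for node type \(t\). The values are within range \((0, 1)\). The higher, the more important.

  • edge_mask (dict[Tuple[str], Tensor]) – The dictionary that associates the learned edge importance masks (values) with the respective canonical edge types (keys). The masks are of shape \((E_t)\), where \(E_t\) is the number of edges for canonical edge type \(t\) in the subgraph. The values are within range \((0, 1)\). The higher, the more important.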
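To illustrate what a k-hop in-neighborhood means on a heterogeneous graph, the following pure-Python sketch collects every node that can reach the center node within `num_hops` edges. It is not DGL's implementation: `k_hop_in_nodes` is a hypothetical helper, the edge dictionary is a hand-written stand-in for a `DGLGraph`, and the sketch keeps the original node IDs, whereas the returned new_node_id indicates that the actual subgraph uses its own node numbering.

```python
def k_hop_in_nodes(edges, ntype, node_id, num_hops):
    """Collect nodes that can reach (ntype, node_id) within num_hops edges.

    `edges` maps (src_type, etype, dst_type) -> (src_ids, dst_ids).
    Returns a dict mapping node type -> sorted list of node IDs.
    """
    visited = {(ntype, node_id)}
    frontier = {(ntype, node_id)}
    for _ in range(num_hops):
        next_frontier = set()
        for (src_type, _etype, dst_type), (src_ids, dst_ids) in edges.items():
            for u, v in zip(src_ids, dst_ids):
                # Follow each edge backwards, from destination to source.
                if (dst_type, v) in frontier and (src_type, u) not in visited:
                    next_frontier.add((src_type, u))
        visited |= next_frontier
        frontier = next_frontier
    result = {}
    for t, i in sorted(visited):
        result.setdefault(t, []).append(i)
    return result

# The user-plays-game graph from the examples, with its reverse edges.
edges = {
    ("user", "plays", "game"): ([0, 1, 1, 2], [0, 0, 1, 1]),
    ("game", "rev_plays", "user"): ([0, 0, 1, 1], [0, 1, 1, 2]),
}
one_hop = k_hop_in_nodes(edges, "user", 0, num_hops=1)
# one_hop -> {'game': [0], 'user': [0]}
```

For user node 0 and one hop, the neighborhood contains one user node and one game node, which matches the node counts of the subgraph in the example output for explain_node.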

Examples

>>> import dgl
>>> import dgl.function as fn
>>> import torch as th
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.nn import HeteroGNNExplainer
>>> class Model(nn.Module):
...     def __init__(self, in_dim, num_classes, canonical_etypes):
...         super(Model, self).__init__()
...         self.etype_weights = nn.ModuleDict({
...             '_'.join(c_etype): nn.Linear(in_dim, num_classes)
...             for c_etype in canonical_etypes
...         })
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             c_etype_func_dict = {}
...             for c_etype in graph.canonical_etypes:
...                 src_type, etype, dst_type = c_etype
...                 wh = self.etype_weights['_'.join(c_etype)](feat[src_type])
...                 graph.nodes[src_type].data[f'h_{c_etype}'] = wh
...                 if eweight is None:
...                     c_etype_func_dict[c_etype] = (fn.copy_u(f'h_{c_etype}', 'm'),
...                         fn.mean('m', 'h'))
...                 else:
...                     graph.edges[c_etype].data['w'] = eweight[c_etype]
...                     c_etype_func_dict[c_etype] = (
...                         fn.u_mul_e(f'h_{c_etype}', 'w', 'm'), fn.mean('m', 'h'))
...             graph.multi_update_all(c_etype_func_dict, 'sum')
...             return graph.ndata['h']
>>> input_dim = 5
>>> num_classes = 2
>>> g = dgl.heterograph({
...     ('user', 'plays', 'game'): ([0, 1, 1, 2], [0, 0, 1, 1])})
>>> g.nodes['user'].data['h'] = th.randn(g.num_nodes('user'), input_dim)
>>> g.nodes['game'].data['h'] = th.randn(g.num_nodes('game'), input_dim)
>>> transform = dgl.transforms.AddReverse()
>>> g = transform(g)
>>> # define and train the model
>>> model = Model(input_dim, num_classes, g.canonical_etypes)
>>> feat = g.ndata['h']
>>> optimizer = th.optim.Adam(model.parameters())
>>> for epoch in range(10):
...     logits = model(g, feat)['user']
...     loss = F.cross_entropy(logits, th.tensor([1, 1, 1]))
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain the prediction for node 0 of type 'user'
>>> explainer = HeteroGNNExplainer(model, num_hops=1)
>>> new_center, sg, feat_mask, edge_mask = explainer.explain_node('user', 0, g, feat)
>>> new_center
tensor([0])
>>> sg
Graph(num_nodes={'game': 1, 'user': 1},
      num_edges={('game', 'rev_plays', 'user'): 1, ('user', 'plays', 'game'): 1,
                 ('user', 'rev_rev_plays', 'game'): 1},
      metagraph=[('game', 'user', 'rev_plays'), ('user', 'game', 'plays'),
                 ('user', 'game', 'rev_rev_plays')])
>>> feat_mask
{'game': tensor([0.2348, 0.2780, 0.2611, 0.2513, 0.2823]),
 'user': tensor([0.2716, 0.2450, 0.2658, 0.2876, 0.2738])}
>>> edge_mask
{('game', 'rev_plays', 'user'): tensor([0.0630]),
 ('user', 'plays', 'game'): tensor([0.1939]),
 ('user', 'rev_rev_plays', 'game'): tensor([0.9166])}

forward(*input: Any) → None

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.