GNNExplainer

class dgl.nn.pytorch.explain.GNNExplainer(model, num_hops, lr=0.01, num_epochs=100, *, alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1, log=True)[source]

Bases: Module

GNNExplainer model from GNNExplainer: Generating Explanations for Graph Neural Networks

It identifies a compact subgraph structure and a small subset of node features that play a critical role in GNN-based node classification and graph classification.

To generate an explanation, it learns an edge mask \(M\) and a feature mask \(F\) by optimizing the following objective function.

\[l(y, \hat{y}) + \alpha_1 \|M\|_1 + \alpha_2 H(M) + \beta_1 \|F\|_1 + \beta_2 H(F)\]

where \(l\) is the loss function, \(y\) is the original model prediction, \(\hat{y}\) is the model prediction with the edge and feature masks applied, and \(H\) is the entropy function.
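The following is a rough sketch of how this objective might be computed, not the library's internal implementation. The tensors logits, target, edge_logits, and feat_logits and the helper explainer_loss are hypothetical names; the masks are kept in \((0, 1)\) by taking a sigmoid over learnable logits.

>>> import torch
>>> import torch.nn.functional as F
>>> def explainer_loss(logits, target, edge_logits, feat_logits,
...                    alpha1=0.005, alpha2=1.0, beta1=1.0, beta2=0.1):
...     M = torch.sigmoid(edge_logits)   # edge mask \(M\) in (0, 1)
...     Fm = torch.sigmoid(feat_logits)  # feature mask \(F\) in (0, 1)
...     eps = 1e-15
...     # mean elementwise binary entropy of a mask
...     def entropy(m):
...         return (-m * (m + eps).log()
...                 - (1 - m) * (1 - m + eps).log()).mean()
...     pred_loss = F.cross_entropy(logits, target)  # l(y, y_hat)
...     return (pred_loss
...             + alpha1 * M.sum()       # ||M||_1, edge mask sparsity
...             + alpha2 * entropy(M)    # H(M)
...             + beta1 * Fm.mean()      # feature mask sparsity
...             + beta2 * entropy(Fm))   # H(F)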

Parameters:
  • model (nn.Module) –

    The GNN model to explain.

    • The required arguments of its forward function are graph and feat. The latter one is for input node features.

    • It should also optionally take an eweight argument for edge weights, and multiply the messages by the edge weights during message passing.

    • The output of its forward function is the logits over the predicted node/graph classes.

    See also the examples in explain_node() and explain_graph().

  • num_hops (int) – The number of hops for GNN information aggregation.

  • lr (float, optional) – The learning rate to use, default to 0.01.

  • num_epochs (int, optional) – The number of epochs to train, default to 100.

  • alpha1 (float, optional) – A higher value will make the explanation edge masks more sparse by decreasing the sum of the edge mask.

  • alpha2 (float, optional) – A higher value will make the explanation edge masks more sparse by decreasing the entropy of the edge mask.

  • beta1 (float, optional) – A higher value will make the explanation node feature masks more sparse by decreasing the mean of the node feature mask.

  • beta2 (float, optional) – A higher value will make the explanation node feature masks more sparse by decreasing the entropy of the node feature mask.

  • log (bool, optional) – If True, it will log the computation process, default to True.

explain_graph(graph, feat, **kwargs)[source]

Learn and return a node feature mask and an edge mask that play a crucial role to explain the prediction made by the GNN for the graph.

Parameters:
  • graph (DGLGraph) – A homogeneous graph.

  • feat (Tensor) – The input feature of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size.

  • kwargs (dict) – Additional arguments passed to the GNN model. Tensors whose first dimension is the number of nodes or edges will be assumed to be node/edge features.

Returns:

  • feat_mask (Tensor) – Learned feature importance mask of shape \((D)\), where \(D\) is the feature size. The values are within range \((0, 1)\). The higher, the more important.

  • edge_mask (Tensor) – Learned importance mask of the edges in the graph, which is a tensor of shape \((E)\), where \(E\) is the number of edges in the graph. The values are within range \((0, 1)\). The higher, the more important.

Examples

>>> import dgl.function as fn
>>> import torch
>>> import torch.nn as nn
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import AvgPooling, GNNExplainer
>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Define a model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super(Model, self).__init__()
...         self.linear = nn.Linear(in_feats, out_feats)
...         self.pool = AvgPooling()
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             feat = self.linear(feat)
...             graph.ndata['h'] = feat
...             if eweight is None:
...                 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
...             else:
...                 graph.edata['w'] = eweight
...                 graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
...             return self.pool(graph, graph.ndata['h'])
>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     logits = model(bg, bg.ndata['attr'])
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain the prediction for graph 0
>>> explainer = GNNExplainer(model, num_hops=1)
>>> g, _ = data[0]
>>> features = g.ndata['attr']
>>> feat_mask, edge_mask = explainer.explain_graph(g, features)
>>> feat_mask
tensor([0.2362, 0.2497, 0.2622, 0.2675, 0.2649, 0.2962, 0.2533])
>>> edge_mask
tensor([0.2154, 0.2235, 0.8325, ..., 0.7787, 0.1735, 0.1847])
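The learned masks can then be used to extract a compact explanation of the prediction. For instance, one might keep only the highest-scoring edges (a usage sketch, not part of the API; the choice of k=10 is arbitrary):

>>> import dgl
>>> # Indices of the 10 most important edges
>>> topk = torch.topk(edge_mask, k=10).indices
>>> # Subgraph induced by the most important edges
>>> explanation_sg = dgl.edge_subgraph(g, topk)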

explain_node(node_id, graph, feat, **kwargs)[source]

Learn and return a node feature mask and a subgraph that play a crucial role to explain the prediction made by the GNN for node node_id.

Parameters:
  • node_id (int) – The node to explain.

  • graph (DGLGraph) – A homogeneous graph.

  • feat (Tensor) – The input feature of shape \((N, D)\), where \(N\) is the number of nodes and \(D\) is the feature size.

  • kwargs (dict) – Additional arguments passed to the GNN model. Tensors whose first dimension is the number of nodes or edges will be assumed to be node/edge features.

Returns:

  • new_node_id (Tensor) – The new ID of the input center node.

  • sg (DGLGraph) – The subgraph induced on the k-hop in-neighborhood of the input center node.

  • feat_mask (Tensor) – Learned node feature importance mask of shape \((D)\), where \(D\) is the feature size. The values are within range \((0, 1)\). The higher, the more important.

  • edge_mask (Tensor) – Learned importance mask of the edges in the subgraph, which is a tensor of shape \((E)\), where \(E\) is the number of edges in the subgraph. The values are within range \((0, 1)\). The higher, the more important.

Examples

>>> import dgl
>>> import dgl.function as fn
>>> import torch
>>> import torch.nn as nn
>>> from dgl.data import CoraGraphDataset
>>> from dgl.nn import GNNExplainer
>>> # Load dataset
>>> data = CoraGraphDataset()
>>> g = data[0]
>>> features = g.ndata['feat']
>>> labels = g.ndata['label']
>>> train_mask = g.ndata['train_mask']
>>> # Define a model
>>> class Model(nn.Module):
...     def __init__(self, in_feats, out_feats):
...         super(Model, self).__init__()
...         self.linear = nn.Linear(in_feats, out_feats)
...
...     def forward(self, graph, feat, eweight=None):
...         with graph.local_scope():
...             feat = self.linear(feat)
...             graph.ndata['h'] = feat
...             if eweight is None:
...                 graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
...             else:
...                 graph.edata['w'] = eweight
...                 graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
...             return graph.ndata['h']
>>> # Train the model
>>> model = Model(features.shape[1], data.num_classes)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for epoch in range(10):
...     logits = model(g, features)
...     loss = criterion(logits[train_mask], labels[train_mask])
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Explain the prediction for node 10
>>> explainer = GNNExplainer(model, num_hops=1)
>>> new_center, sg, feat_mask, edge_mask = explainer.explain_node(10, g, features)
>>> new_center
tensor([1])
>>> sg.num_edges()
12
>>> # Old IDs of the nodes in the subgraph
>>> sg.ndata[dgl.NID]
tensor([ 9, 10, 11, 12])
>>> # Old IDs of the edges in the subgraph
>>> sg.edata[dgl.EID]
tensor([51, 53, 56, 48, 52, 57, 47, 50, 55, 46, 49, 54])
>>> feat_mask
tensor([0.2638, 0.2738, 0.3039,  ..., 0.2794, 0.2643, 0.2733])
>>> edge_mask
tensor([0.0937, 0.1496, 0.8287, 0.8132, 0.8825, 0.8515, 0.8146, 0.0915, 0.1145,
        0.9011, 0.1311, 0.8437])
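Because the masks are defined over the extracted subgraph, sg.edata[dgl.EID] maps important edges back to edge IDs in the original graph. For instance, with the values above and an arbitrary illustrative threshold of 0.8:

>>> # Original IDs of the subgraph edges with mask value above 0.8
>>> sg.edata[dgl.EID][edge_mask > 0.8]
tensor([56, 48, 52, 57, 47, 46, 54])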

forward(*input: Any) → None

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.