子图X

class dgl.nn.pytorch.explain.SubgraphX(model, num_hops, coef=10.0, high2low=True, num_child=12, num_rollouts=20, node_min=3, shapley_steps=100, log=False)[source]

Bases: Module

SubgraphX 来自 关于通过子图探索解释图神经网络的可解释性

它从原始图中识别出最重要的子图,该子图在基于GNN的图分类中起着关键作用。

它采用蒙特卡洛树搜索(MCTS)来高效探索不同的子图以进行解释,并使用Shapley值作为子图重要性的度量。

Parameters:
  • model (nn.Module) –

    用于解释多类图分类的GNN模型

    • 它的前向函数必须具有以下形式 forward(self, graph, nfeat).

    • 它的前向函数的输出是logits。

  • num_hops (int) – 模型中消息传递层的数量

  • coef (float, optional) – 这个超参数控制探索和利用之间的权衡。较高的值鼓励算法探索相对未被访问的节点。默认值:10.0

  • high2low (bool, 可选) – 如果为 True,它将使用“High2low”策略进行剪枝操作, 在搜索树中扩展子节点时,从高到低扩展子节点。否则,它将使用 “Low2high”策略。默认值:True

  • num_child (int, optional) – 这是在搜索树中扩展子节点时要展开的子节点数量。默认值:12

  • num_rollouts (int, optional) – 这是MCTS的模拟次数。默认值:20

  • node_min (int, optional) – 这是基于子图中节点数量定义叶节点的阈值。默认值:3

  • shapley_steps (int, optional) – 这是在估计Shapley值时用于蒙特卡洛采样的步骤数。默认值:100

  • log (bool, 可选) – 如果为True,它将记录进度。默认值:False

explain_graph(graph, feat, target_class, **kwargs)[source]

从原始图中找到最重要的子图,以便模型将图分类为目标类别。

Parameters:
  • graph (DGLGraph) – 一个同构图

  • feat (Tensor) – 输入节点特征的形状为 \((N, D)\)\(N\) 是 节点数量,\(D\) 是特征大小

  • target_class (int) – 要解释的目标类别

  • kwargs (dict) – 传递给GNN模型的额外参数

Returns:

表示最重要子图的节点

Return type:

张量

示例

>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.functional as F
>>> from dgl.data import GINDataset
>>> from dgl.dataloading import GraphDataLoader
>>> from dgl.nn import GraphConv, AvgPooling, SubgraphX
>>> # Define the model
>>> class Model(nn.Module):
...     def __init__(self, in_dim, n_classes, hidden_dim=128):
...         super().__init__()
...         self.conv1 = GraphConv(in_dim, hidden_dim)
...         self.conv2 = GraphConv(hidden_dim, n_classes)
...         self.pool = AvgPooling()
...
...     def forward(self, g, h):
...         h = F.relu(self.conv1(g, h))
...         h = self.conv2(g, h)
...         return self.pool(g, h)
>>> # Load dataset
>>> data = GINDataset('MUTAG', self_loop=True)
>>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Train the model
>>> feat_size = data[0][0].ndata['attr'].shape[1]
>>> model = Model(feat_size, data.gclasses)
>>> criterion = nn.CrossEntropyLoss()
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
>>> for bg, labels in dataloader:
...     logits = model(bg, bg.ndata['attr'])
...     loss = criterion(logits, labels)
...     optimizer.zero_grad()
...     loss.backward()
...     optimizer.step()
>>> # Initialize the explainer
>>> explainer = SubgraphX(model, num_hops=2)
>>> # Explain the prediction for graph 0
>>> graph, l = data[0]
>>> graph_feat = graph.ndata.pop("attr")
>>> g_nodes_explain = explainer.explain_graph(graph, graph_feat,
...                                           target_class=l)
forward(*input: Any) None

定义每次调用时执行的计算。

应该由所有子类覆盖。

注意

尽管前向传递的配方需要在此函数内定义,但之后应该调用Module实例而不是这个,因为前者负责运行已注册的钩子,而后者则默默地忽略它们。