子图X
- class dgl.nn.pytorch.explain.SubgraphX(model, num_hops, coef=10.0, high2low=True, num_child=12, num_rollouts=20, node_min=3, shapley_steps=100, log=False)[source]
Bases:
Module
SubgraphX 来自 关于通过子图探索解释图神经网络的可解释性
它从原始图中识别出最重要的子图,该子图在基于GNN的图分类中起着关键作用。
它采用蒙特卡洛树搜索(MCTS)来高效探索不同的子图以进行解释,并使用Shapley值作为子图重要性的度量。
- Parameters:
model (nn.Module) –
用于解释多类图分类的GNN模型
它的前向函数必须具有以下形式
forward(self, graph, nfeat)
.它的前向函数的输出是logits。
num_hops (int) – 模型中消息传递层的数量
coef (float, optional) – 这个超参数控制探索和利用之间的权衡。较高的值鼓励算法探索相对未被访问的节点。默认值:10.0
high2low (bool, 可选) – 如果为 True,它将使用“High2low”策略进行剪枝操作, 在搜索树中扩展子节点时,从高到低扩展子节点。否则,它将使用 “Low2high”策略。默认值:True
num_child (int, optional) – 这是在搜索树中扩展子节点时要展开的子节点数量。默认值:12
num_rollouts (int, optional) – 这是MCTS的模拟次数。默认值:20
node_min (int, optional) – 这是基于子图中节点数量定义叶节点的阈值。默认值:3
shapley_steps (int, optional) – 这是在估计Shapley值时用于蒙特卡洛采样的步骤数。默认值:100
log (bool, 可选) – 如果为True,它将记录进度。默认值:False
- explain_graph(graph, feat, target_class, **kwargs)[source]
从原始图中找到最重要的子图,以便模型将图分类为目标类别。
- Parameters:
- Returns:
表示最重要子图的节点
- Return type:
张量
示例
>>> import torch >>> import torch.nn as nn >>> import torch.nn.functional as F >>> from dgl.data import GINDataset >>> from dgl.dataloading import GraphDataLoader >>> from dgl.nn import GraphConv, AvgPooling, SubgraphX
>>> # Define the model >>> class Model(nn.Module): ... def __init__(self, in_dim, n_classes, hidden_dim=128): ... super().__init__() ... self.conv1 = GraphConv(in_dim, hidden_dim) ... self.conv2 = GraphConv(hidden_dim, n_classes) ... self.pool = AvgPooling() ... ... def forward(self, g, h): ... h = F.relu(self.conv1(g, h)) ... h = self.conv2(g, h) ... return self.pool(g, h)
>>> # Load dataset >>> data = GINDataset('MUTAG', self_loop=True) >>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Train the model >>> feat_size = data[0][0].ndata['attr'].shape[1] >>> model = Model(feat_size, data.gclasses) >>> criterion = nn.CrossEntropyLoss() >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) >>> for bg, labels in dataloader: ... logits = model(bg, bg.ndata['attr']) ... loss = criterion(logits, labels) ... optimizer.zero_grad() ... loss.backward() ... optimizer.step()
>>> # Initialize the explainer >>> explainer = SubgraphX(model, num_hops=2)
>>> # Explain the prediction for graph 0 >>> graph, l = data[0] >>> graph_feat = graph.ndata.pop("attr") >>> g_nodes_explain = explainer.explain_graph(graph, graph_feat, ... target_class=l)