PGExplainer
- class dgl.nn.pytorch.explain.PGExplainer(model, num_features, num_hops=None, explain_graph=True, coff_budget=0.01, coff_connect=0.0005, sample_bias=0.0)[source]
Bases:
Module
PGExplainer 来自 图神经网络的参数化解释器
PGExplainer采用深度神经网络(解释网络)来参数化解释的生成过程,这使得它能够集体解释多个实例。PGExplainer将底层结构建模为边缘分布,从中采样出解释图。
- Parameters:
model (nn.Module) –
用于解释多类图分类的GNN模型
它的前向函数必须具有以下形式
forward(self, graph, nfeat, embed, edge_weight)
.它的前向函数的输出是logits,如果embed=False,否则是中间节点嵌入。
num_features (int) – 由
model
使用的节点嵌入大小。num_hops (int, optional) – GNN信息聚合的跳数,必须与要解释的GNN所使用的消息传递层数相匹配。
explain_graph (bool, optional) – 是否初始化模型以进行图级或节点级预测。
coff_budget (float, optional) – 用于约束解释大小的正则化参数。默认值:0.01。
coff_connect (float, 可选) – 用于约束解释连通性的熵正则化。默认值:5e-4。
sample_bias (float, optional) – 某些群体成员在样本中被系统性地选择的可能性比其他成员更高。默认值:0.0。
- explain_graph(graph, feat, temperature=1.0, training=False, **kwargs)[source]
学习并返回一个边缘掩码,该掩码在解释GNN对图做出的预测中起着至关重要的作用。同时,返回基于边缘掩码选择的边缘所做出的预测。
- Parameters:
- Returns:
Tensor – 给定掩码图的分类概率。它是一个形状为 \((B, L)\) 的张量,其中 \(L\) 是数据集中不同类型的标签,\(B\) 是批量大小。
Tensor – 边权重,它是一个形状为 \((E)\) 的张量,其中 \(E\) 是图中的边数。较高的权重表示边的贡献较大。
示例
>>> import torch as th >>> import torch.nn as nn >>> import dgl >>> from dgl.data import GINDataset >>> from dgl.dataloading import GraphDataLoader >>> from dgl.nn import GraphConv, PGExplainer >>> import numpy as np
>>> # Define the model >>> class Model(nn.Module): ... def __init__(self, in_feats, out_feats): ... super().__init__() ... self.conv = GraphConv(in_feats, out_feats) ... self.fc = nn.Linear(out_feats, out_feats) ... nn.init.xavier_uniform_(self.fc.weight) ... ... def forward(self, g, h, embed=False, edge_weight=None): ... h = self.conv(g, h, edge_weight=edge_weight) ... ... if embed: ... return h ... ... with g.local_scope(): ... g.ndata['h'] = h ... hg = dgl.mean_nodes(g, 'h') ... return self.fc(hg)
>>> # Load dataset >>> data = GINDataset('MUTAG', self_loop=True) >>> dataloader = GraphDataLoader(data, batch_size=64, shuffle=True)
>>> # Train the model >>> feat_size = data[0][0].ndata['attr'].shape[1] >>> model = Model(feat_size, data.gclasses) >>> criterion = nn.CrossEntropyLoss() >>> optimizer = th.optim.Adam(model.parameters(), lr=1e-2) >>> for bg, labels in dataloader: ... preds = model(bg, bg.ndata['attr']) ... loss = criterion(preds, labels) ... optimizer.zero_grad() ... loss.backward() ... optimizer.step()
>>> # Initialize the explainer >>> explainer = PGExplainer(model, data.gclasses)
>>> # Train the explainer >>> # Define explainer temperature parameter >>> init_tmp, final_tmp = 5.0, 1.0 >>> optimizer_exp = th.optim.Adam(explainer.parameters(), lr=0.01) >>> for epoch in range(20): ... tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / 20)) ... for bg, labels in dataloader: ... loss = explainer.train_step(bg, bg.ndata['attr'], tmp) ... optimizer_exp.zero_grad() ... loss.backward() ... optimizer_exp.step()
>>> # Explain the prediction for graph 0 >>> graph, l = data[0] >>> graph_feat = graph.ndata.pop("attr") >>> probs, edge_weight = explainer.explain_graph(graph, graph_feat)
- explain_node(nodes, graph, feat, temperature=1.0, training=False, **kwargs)[source]
学习并返回一个边缘掩码,该掩码在解释GNN为提供的节点ID集所做的预测中起着至关重要的作用。 同时,返回使用图和边缘掩码所做的预测。
- Parameters:
- Returns:
Tensor – 给定掩码图的分类概率。它是一个形状为 \((N, L)\) 的张量,其中 \(L\) 是数据集中不同类型的节点标签,\(N\) 是图中的节点数量。
Tensor – 边权重,它是一个形状为 \((E)\) 的张量,其中 \(E\) 是图中的边数量。较高的权重表示该边的贡献较大。
DGLGraph – 输入中心节点的 k-hop 内邻居上诱导的子图的批处理集。
Tensor – 子图中心节点的新 ID。
示例
>>> import dgl >>> import numpy as np >>> import torch
>>> # Define the model >>> class Model(torch.nn.Module): ... def __init__(self, in_feats, out_feats): ... super().__init__() ... self.conv1 = dgl.nn.GraphConv(in_feats, out_feats) ... self.conv2 = dgl.nn.GraphConv(out_feats, out_feats) ... ... def forward(self, g, h, embed=False, edge_weight=None): ... h = self.conv1(g, h, edge_weight=edge_weight) ... if embed: ... return h ... return self.conv2(g, h)
>>> # Load dataset >>> data = dgl.data.CoraGraphDataset(verbose=False) >>> g = data[0] >>> features = g.ndata["feat"] >>> labels = g.ndata["label"]
>>> # Train the model >>> model = Model(features.shape[1], data.num_classes) >>> criterion = torch.nn.CrossEntropyLoss() >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-2) >>> for epoch in range(20): ... logits = model(g, features) ... loss = criterion(logits, labels) ... optimizer.zero_grad() ... loss.backward() ... optimizer.step()
>>> # Initialize the explainer >>> explainer = dgl.nn.PGExplainer( ... model, data.num_classes, num_hops=2, explain_graph=False ... )
>>> # Train the explainer >>> # Define explainer temperature parameter >>> init_tmp, final_tmp = 5.0, 1.0 >>> optimizer_exp = torch.optim.Adam(explainer.parameters(), lr=0.01) >>> epochs = 10 >>> for epoch in range(epochs): ... tmp = float(init_tmp * np.power(final_tmp / init_tmp, epoch / epochs)) ... loss = explainer.train_step_node(g.nodes(), g, features, tmp) ... optimizer_exp.zero_grad() ... loss.backward() ... optimizer_exp.step()
>>> # Explain the prediction for graph 0 >>> probs, edge_weight, bg, inverse_indices = explainer.explain_node( ... 0, g, features ... )
- forward(*input: Any) None
定义每次调用时执行的计算。
应该由所有子类覆盖。
注意
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- train_step(graph, feat, temperature, **kwargs)[source]
计算图分类解释网络的损失
- Parameters:
- Returns:
表示损失的标量张量。
- Return type:
张量
- train_step_node(nodes, graph, feat, temperature, **kwargs)[source]
计算节点分类的解释网络的损失
- Parameters:
nodes (int, iterable[int], tensor) – 用于训练解释网络的图中的节点,不能有任何重复值。
graph (DGLGraph) – 输入的同构图。
feat (Tensor) – The input feature of shape \((N, D)\). \(N\) is the number of nodes, and \(D\) is the feature size.
temperature (float) – The temperature parameter fed to the sampling procedure.
kwargs (dict) – Additional arguments passed to the GNN model.
- Returns:
表示损失的标量张量。
- Return type:
张量