解释图神经网络

解释GNN模型对于许多用例至关重要。 (2.3及以上版本)提供了torch_geometric.explain包,用于一流的GNN可解释性支持,目前包括

  1. 一个灵活的接口,用于通过Explainer类生成各种解释,

  2. 几种基础的解释算法,包括,例如GNNExplainer, PGExplainerCaptumExplainer,

  3. 支持通过ExplanationHeteroExplanation类可视化解释,

  4. 并通过metric包来评估解释的指标。

警告

这里讨论的解释API可能会在未来发生变化,因为我们不断努力提高它们的易用性和通用性。

解释器界面

torch_geometric.explain.Explainer 类旨在处理所有可解释性参数(更多详情请参见 ExplainerConfig 类):

  1. 使用torch_geometric.explain.algorithm模块中的哪个算法(例如GNNExplainer

  2. 要计算的解释类型, explanation_type="phenomenon" 用于解释数据集的潜在现象,以及 explanation_type="model" 用于解释GNN模型的预测(更多详情请参见“GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks”论文)。

  3. 节点和边的不同类型掩码(例如mask="object"mask="attributes"

  4. 对掩码的任何后处理(例如threshold_type="topk"threshold_type="hard"

该类允许用户轻松比较不同的可解释性方法,并轻松切换不同类型的掩码,同时确保高级框架保持一致。Explainer生成一个ExplanationHeteroExplanation对象,该对象包含关于哪些节点、边和特征对于解释GNN模型至关重要的最终信息。

注意

你可以在这篇博客文章中了解更多关于torch_geometric.explain包的信息。

Examples

接下来,我们将讨论一些使用案例及其相应的代码示例。

解释同质图上的节点分类

假设我们有一个GNN model,它在同质图上进行节点分类。 我们可以使用torch_geometric.explain.algorithm.GNNExplainer算法来生成一个Explanation。 我们配置Explainer以同时使用node_mask_typeedge_mask_type,使得最终的Explanation对象包含(1)一个node_mask(指示哪些节点和特征对预测至关重要),以及(2)一个edge_mask(指示哪些边对预测至关重要)。

from torch_geometric.data import Data
from torch_geometric.explain import Explainer, GNNExplainer

data = Data(...)  # A homogeneous graph data object.

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',  # Model returns log probabilities.
    ),
)

# Generate explanation for the node at index `10`:
explanation = explainer(data.x, data.edge_index, index=10)
print(explanation.edge_mask)
print(explanation.node_mask)

最后,我们可以可视化特征重要性和解释的关键子图:

explanation.visualize_feature_importance(top_k=10)

explanation.visualize_graph()

为了评估来自GNNExplainer的解释,我们可以利用torch_geometric.explain.metric模块。 例如,要计算unfaithfulness()的解释,运行:

from torch_geometric.explain import unfaithfulness

metric = unfaithfulness(explainer, explanation)
print(metric)

解释异质图上的节点分类

假设我们有一个异构GNN model,它在异构图上进行节点分类。 我们可以通过torch_geometric.explain.algorithm.CaptumExplainer算法使用 Captum中的IntegratedGradient归因方法来生成HeteroExplanation

注意

CaptumExplainer 是一个围绕 Captum 库的封装器,支持大多数归因方法,用于解释任何同质或异质的 模型。

我们配置Explainer以使用node_mask_typeedge_mask_type,使得最终的HeteroExplanation对象包含(1)每个节点类型的node_mask(指示每个节点类型中哪些节点和特征对预测至关重要),以及(2)每个边类型的edge_mask(指示每个边类型中哪些边对预测至关重要)。

from torch_geometric.data import HeteroData
from torch_geometric.explain import Explainer, CaptumExplainer

hetero_data = HeteroData(...)  # A heterogeneous graph data object.

explainer = Explainer(
    model,  # It is assumed that model outputs a single tensor.
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config = dict(
        mode='multiclass_classification',
        task_level=task_level,
        return_type='probs',  # Model returns probabilities.
    ),
)

# Generate batch-wise heterogeneous explanations for
# the nodes at index `1` and `3`:
hetero_explanation = explainer(
    hetero_data.x_dict,
    hetero_data.edge_index_dict,
    index=torch.tensor([1, 3]),
)
print(hetero_explanation.edge_mask_dict)
print(hetero_explanation.node_mask_dict)

解释同质图上的图回归

假设我们有一个GNN model,它在一个同质图上进行图回归。 我们可以使用torch_geometric.explain.algorithm.PGExplainer算法来生成一个Explanation。 我们配置Explainer使用一个edge_mask_type,使得最终的Explanation对象包含一个edge_mask(指示哪些边对预测至关重要)。 重要的是,传递一个node_mask_typeExplainer将会抛出错误,因为PGExplainer无法解释节点的重要性:

from torch_geometric.data import Data
from torch_geometric.explain import Explainer, PGExplainer

dataset = ...
loader = DataLoader(dataset, batch_size=1, shuffle=True)

explainer = Explainer(
    model=model,
    algorithm=PGExplainer(epochs=30, lr=0.003),
    explanation_type='phenomenon',
    edge_mask_type='object',
    model_config=dict(
        mode='regression',
        task_level='graph',
        return_type='raw',
    ),
    # Include only the top 10 most important edges:
    threshold_config=dict(threshold_type='topk', value=10),
)

# PGExplainer needs to be trained separately since it is a parametric
# explainer i.e it uses a neural network to generate explanations:
for epoch in range(30):
    for batch in loader:
        loss = explainer.algorithm.train(
            epoch, model, batch.x, batch.edge_index, target=batch.target)

# Generate the explanation for a particular graph:
explanation = explainer(dataset[0].x, dataset[0].edge_index)
print(explanation.edge_mask)

Since this feature is still undergoing heavy development, please feel free to reach out to the core team either on GitHub or Slack if you have any questions, comments or concerns.