torch_geometric.explain
警告
该模块正在积极开发中,可能不稳定。 访问需要从主分支安装PyG。
哲学
该模块提供了一套工具,用于解释PyG模型的预测或解释数据集的潜在现象(更多详情请参见“GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks”论文)。
我们使用torch_geometric.explain.Explanation类来表示解释,这是一个Data对象,包含节点、边、特征和数据任何属性的掩码。
torch_geometric.explain.Explainer 类旨在处理所有可解释性参数(有关更多详细信息,请参见 torch_geometric.explain.config.ExplainerConfig 类):
which algorithm from the
torch_geometric.explain.algorithmmodule to use (e.g.,GNNExplainer)要计算的解释类型(例如,
explanation_type="phenomenon"或explanation_type="model")the different type of masks for node and edges (e.g.,
mask="object"ormask="attributes")any postprocessing of the masks (e.g.,
threshold_type="topk"orthreshold_type="hard")
该类允许用户轻松比较不同的可解释性方法,并轻松在不同类型的掩码之间切换,同时确保高级框架保持一致。
解释器
- class Explainer(model: Module, algorithm: ExplainerAlgorithm, explanation_type: Union[ExplanationType, str], model_config: Union[ModelConfig, Dict[str, Any]], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None, threshold_config: Optional[ThresholdConfig] = None)[source]
Bases:
object一个用于图神经网络实例级解释的解释器类。
- Parameters:
model (torch.nn.Module) – The model to explain.
algorithm (ExplainerAlgorithm) – 解释算法。
explanation_type (ExplanationType 或 str) –
要计算的解释类型。可能的值为:
"model": 解释模型预测。"phenomenon": 解释模型试图预测的现象。
实际上,这意味着解释算法将根据模型输出(
"model")或目标输出("phenomenon")计算其损失。model_config (ModelConfig) – 模型配置。 请参阅
ModelConfig以获取 可用的选项。(默认值:None)node_mask_type (MaskType 或 str, 可选) –
应用于节点的掩码类型。可能的值为(默认值:
None):None: 不会对节点应用任何掩码。"object": 将对每个节点应用掩码。"common_attributes": 将对每个特征应用掩码。"attributes": 将对所有节点的每个特征应用掩码。
edge_mask_type (MaskType 或 str, 可选) – 应用于边的掩码类型。具有与
node_mask_type相同的可能值。 (默认:None)threshold_config (ThresholdConfig, 可选) – 阈值配置。 有关可用选项,请参见
ThresholdConfig。(默认值:None)
- get_prediction(*args, **kwargs) Tensor[source]
返回模型在输入图上的预测。
如果模型模式是
"regression",预测结果将作为标量值返回。 如果模型模式是"multiclass_classification"或"binary_classification",预测结果将作为预测的类别标签返回。- Parameters:
*args – 传递给模型的参数。
**kwargs (可选) – 传递给模型的额外关键字参数。
- Return type:
- get_masked_prediction(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], node_mask: Optional[Union[Tensor, Dict[str, Tensor]]] = None, edge_mask: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs) Tensor[source]
返回模型在应用了节点和边掩码的输入图上的预测。
- Return type:
- __call__(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs) Union[Explanation, HeteroExplanation][source]
计算给定输入和目标下GNN的解释。
注意
如果你收到类似“试图第二次通过图进行反向传播”的错误信息,请确保你提供的目标是使用
torch.no_grad()计算的。- Parameters:
x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.
edge_index (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.
target (torch.Tensor) – 模型的目标。 如果解释类型是
"phenomenon",则必须提供目标。 如果解释类型是"model",则目标应设置为None,并且会自动推断。对于分类任务,目标需要包含类别标签。(默认值:None)index (Union[int, Tensor], optional) – 要解释的模型输出的第一维度的索引。 可以是单个索引或索引的张量。 如果设置为
None,将解释所有模型输出。 (默认值:None)**kwargs – 传递给GNN的额外参数。
- Return type:
- class ExplainerConfig(explanation_type: Union[ExplanationType, str], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None)[source]
配置类用于存储和验证高级解释参数。
- Parameters:
explanation_type (ExplanationType or str) –
The type of explanation to compute. The possible values are:
"model": Explains the model prediction."phenomenon": Explains the phenomenon that the model is trying to predict.
In practice, this means that the explanation algorithm will either compute their losses with respect to the model output (
"model") or the target output ("phenomenon").node_mask_type (MaskType or str, optional) –
The type of mask to apply on nodes. The possible values are (default:
None):None: Will not apply any mask on nodes."object": Will mask each node."common_attributes": Will mask each feature."attributes": Will mask each feature across all nodes.
edge_mask_type (MaskType or str, optional) – The type of mask to apply on edges. Has the sample possible values as
node_mask_type. (default:None)
- class ModelConfig(mode: Union[ModelMode, str], task_level: Union[ModelTaskLevel, str], return_type: Optional[Union[ModelReturnType, str]] = None)[source]
用于存储模型参数的配置类。
- Parameters:
mode (ModelMode 或 str) –
模型的模式。可能的值为:
"binary_classification": 一个二分类模型。"multiclass_classification": 一个多分类模型。"regression": 一个回归模型。
task_level (ModelTaskLevel 或 str) –
模型的任务级别。 可能的值为:
"node": 节点级别的预测模型。"edge": 边级别的预测模型。"graph": 图级别的预测模型。
return_type (ModelReturnType 或 str, 可选) –
模型的返回类型。可能的值为(默认:
None):"raw": 模型返回原始值。"probs": 模型返回概率。"log_probs": 模型返回对数概率。
解释
- class Explanation(x: Optional[Tensor] = None, edge_index: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, y: Optional[Union[Tensor, int, float]] = None, pos: Optional[Tensor] = None, time: Optional[Tensor] = None, **kwargs)[source]
基础类:
Data,ExplanationMixin保存了所有获得的同质图的解释。
解释对象是一个
Data对象,并且可以保存节点属性和边属性。如果需要,它还可以保存原始图。- Parameters:
- validate(raise_on_error: bool = True) bool[source]
验证
Explanation对象的正确性。
- get_explanation_subgraph() Explanation[source]
返回诱导子图,其中所有属性为零的节点和边都被屏蔽掉。
- Return type:
- get_complement_subgraph() Explanation[source]
返回诱导子图,其中所有具有任何属性的节点和边都被屏蔽。
- Return type:
- visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[List[str]] = None, top_k: Optional[int] = None)[source]
通过汇总所有节点的节点掩码,创建节点特征重要性的条形图。
- class HeteroExplanation(_mapping: Optional[Dict[str, Any]] = None, **kwargs)[source]
基础类:
HeteroData,ExplanationMixin保存了异构图的所有获得的解释。
解释对象是一个
HeteroData对象, 并且可以保存节点属性和边属性。 如果需要,它还可以保存原始图。- validate(raise_on_error: bool = True) bool[source]
Validates the correctness of the
Explanationobject.
- get_explanation_subgraph() HeteroExplanation[source]
返回诱导子图,其中所有属性为零的节点和边都被屏蔽掉。
- Return type:
- get_complement_subgraph() HeteroExplanation[source]
返回诱导子图,其中所有具有任何属性的节点和边都被屏蔽。
- Return type:
解释器算法
用于实现解释器算法的抽象基类。 |
|
一个返回随机解释的虚拟解释器(用于测试目的)。 |
|
来自"GNNExplainer: Generating Explanations for Graph Neural Networks"论文的GNN-Explainer模型,用于识别在GNN预测中起关键作用的紧凑子图结构和节点特征。 |
|
一个基于Captum的解释器,用于识别在GNN预测中起关键作用的紧凑子图结构和节点特征。 |
|
来自"Parameterized Explainer for Graph Neural Network"论文的PGExplainer模型。 |
|
一个解释器,使用基于注意力的GNN(例如, |
|
来自"Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking"论文的GraphMask-Explainer模型,用于识别在GNN预测中起关键作用的层次紧凑子图结构和节点特征。 |
解释指标
解释的质量可以通过多种不同的方法来评判。 PyG 支持以下开箱即用的指标:
比较并评估解释掩码与真实解释掩码。 |
|
评估 |
|
返回组件特征评分,如"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks"论文中所述。 |
|
返回如"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks"论文中所述的保真度曲线的AUC。 |
|
评估 |