torch_geometric.explain

警告

该模块正在积极开发中,可能不稳定。 访问需要从主分支安装

哲学

该模块提供了一套工具,用于解释PyG模型的预测或解释数据集的潜在现象(更多详情请参见“GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks”论文)。

我们使用torch_geometric.explain.Explanation类来表示解释,这是一个Data对象,包含节点、边、特征和数据任何属性的掩码。

torch_geometric.explain.Explainer 类旨在处理所有可解释性参数(有关更多详细信息,请参见 torch_geometric.explain.config.ExplainerConfig 类):

  • which algorithm from the torch_geometric.explain.algorithm module to use (e.g., GNNExplainer)

  • 要计算的解释类型(例如explanation_type="phenomenon"explanation_type="model"

  • the different type of masks for node and edges (e.g., mask="object" or mask="attributes")

  • any postprocessing of the masks (e.g., threshold_type="topk" or threshold_type="hard")

该类允许用户轻松比较不同的可解释性方法,并轻松在不同类型的掩码之间切换,同时确保高级框架保持一致。

解释器

class Explainer(model: Module, algorithm: ExplainerAlgorithm, explanation_type: Union[ExplanationType, str], model_config: Union[ModelConfig, Dict[str, Any]], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None, threshold_config: Optional[ThresholdConfig] = None)[source]

Bases: object

一个用于图神经网络实例级解释的解释器类。

Parameters:
  • model (torch.nn.Module) – The model to explain.

  • algorithm (ExplainerAlgorithm) – 解释算法。

  • explanation_type (ExplanationTypestr) –

    要计算的解释类型。可能的值为:

    • "model": 解释模型预测。

    • "phenomenon": 解释模型试图预测的现象。

    实际上,这意味着解释算法将根据模型输出("model")或目标输出("phenomenon")计算其损失。

  • model_config (ModelConfig) – 模型配置。 请参阅 ModelConfig 以获取 可用的选项。(默认值:None

  • node_mask_type (MaskTypestr, 可选) –

    应用于节点的掩码类型。可能的值为(默认值:None):

    • None: 不会对节点应用任何掩码。

    • "object": 将对每个节点应用掩码。

    • "common_attributes": 将对每个特征应用掩码。

    • "attributes": 将对所有节点的每个特征应用掩码。

  • edge_mask_type (MaskTypestr, 可选) – 应用于边的掩码类型。具有与 node_mask_type 相同的可能值。 (默认: None)

  • threshold_config (ThresholdConfig, 可选) – 阈值配置。 有关可用选项,请参见 ThresholdConfig。(默认值:None

get_prediction(*args, **kwargs) Tensor[source]

返回模型在输入图上的预测。

如果模型模式是"regression",预测结果将作为标量值返回。 如果模型模式是"multiclass_classification""binary_classification",预测结果将作为预测的类别标签返回。

Parameters:
  • *args – 传递给模型的参数。

  • **kwargs (可选) – 传递给模型的额外关键字参数。

Return type:

Tensor

get_masked_prediction(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], node_mask: Optional[Union[Tensor, Dict[str, Tensor]]] = None, edge_mask: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs) Tensor[source]

返回模型在应用了节点和边掩码的输入图上的预测。

Return type:

Tensor

__call__(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs) Union[Explanation, HeteroExplanation][source]

计算给定输入和目标下GNN的解释。

注意

如果你收到类似“试图第二次通过图进行反向传播”的错误信息,请确保你提供的目标是使用torch.no_grad()计算的。

Parameters:
  • x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.

  • edge_index (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.

  • target (torch.Tensor) – 模型的目标。 如果解释类型是 "phenomenon",则必须提供目标。 如果解释类型是 "model",则目标应设置为 None,并且会自动推断。对于分类任务,目标需要包含类别标签。(默认值:None

  • index (Union[int, Tensor], optional) – 要解释的模型输出的第一维度的索引。 可以是单个索引或索引的张量。 如果设置为 None,将解释所有模型输出。 (默认值: None)

  • **kwargs – 传递给GNN的额外参数。

Return type:

Union[Explanation, HeteroExplanation]

get_target(prediction: Tensor) Tensor[source]

从给定的预测中返回模型的目标。

如果模型模式是"regression"类型,预测结果将直接返回。 如果模型模式是"multiclass_classification""binary_classification"类型,预测结果将返回为预测的类别标签。

Return type:

Tensor

class ExplainerConfig(explanation_type: Union[ExplanationType, str], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None)[source]

配置类用于存储和验证高级解释参数。

Parameters:
  • explanation_type (ExplanationType or str) –

    The type of explanation to compute. The possible values are:

    • "model": Explains the model prediction.

    • "phenomenon": Explains the phenomenon that the model is trying to predict.

    In practice, this means that the explanation algorithm will either compute their losses with respect to the model output ("model") or the target output ("phenomenon").

  • node_mask_type (MaskType or str, optional) –

    The type of mask to apply on nodes. The possible values are (default: None):

    • None: Will not apply any mask on nodes.

    • "object": Will mask each node.

    • "common_attributes": Will mask each feature.

    • "attributes": Will mask each feature across all nodes.

  • edge_mask_type (MaskType or str, optional) – The type of mask to apply on edges. Has the sample possible values as node_mask_type. (default: None)

class ModelConfig(mode: Union[ModelMode, str], task_level: Union[ModelTaskLevel, str], return_type: Optional[Union[ModelReturnType, str]] = None)[source]

用于存储模型参数的配置类。

Parameters:
  • mode (ModelModestr) –

    模型的模式。可能的值为:

    • "binary_classification": 一个二分类模型。

    • "multiclass_classification": 一个多分类模型。

    • "regression": 一个回归模型。

  • task_level (ModelTaskLevelstr) –

    模型的任务级别。 可能的值为:

    • "node": 节点级别的预测模型。

    • "edge": 边级别的预测模型。

    • "graph": 图级别的预测模型。

  • return_type (ModelReturnTypestr, 可选) –

    模型的返回类型。可能的值为(默认:None):

    • "raw": 模型返回原始值。

    • "probs": 模型返回概率。

    • "log_probs": 模型返回对数概率。

class ThresholdConfig(threshold_type: Union[ThresholdType, str], value: Union[float, int])[source]

配置类用于存储和验证阈值参数。

Parameters:
  • threshold_type (ThresholdTypestr) –

    要应用的阈值类型。 可能的值为:

    • None: 不应用任何阈值。

    • "hard": 对每个掩码应用硬阈值。 掩码中值低于 value 的元素被设置为 0,其他元素被设置为 1

    • "topk": 对每个掩码应用软阈值。 每个掩码中前 obj:value 个元素被保留,其他元素被设置为 0

    • "topk_hard": 与 "topk" 相同,但所有被保留的元素的值被设置为 1

  • value (intfloat, 可选) – 用于阈值处理的值。 (默认值: None)

解释

class Explanation(x: Optional[Tensor] = None, edge_index: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, y: Optional[Union[Tensor, int, float]] = None, pos: Optional[Tensor] = None, time: Optional[Tensor] = None, **kwargs)[source]

基础类:Data, ExplanationMixin

保存了所有获得的同质图的解释。

解释对象是一个Data对象,并且可以保存节点属性和边属性。如果需要,它还可以保存原始图。

Parameters:
  • node_mask (Tensor, optional) – 节点级别的掩码,形状为 [num_nodes, 1], [1, num_features][num_nodes, num_features]。(默认值:None

  • edge_mask (Tensor, optional) – 边级别的掩码,形状为 [num_edges]。(默认值:None

  • **kwargs (optional) – Additional attributes.

validate(raise_on_error: bool = True) bool[source]

验证Explanation对象的正确性。

Return type:

bool 翻译后的内容: bool 在这个例子中,`bool` 是一个Python函数名称,根据翻译规则1,不需要翻译。因此,翻译后的内容保持不变。

get_explanation_subgraph() Explanation[source]

返回诱导子图,其中所有属性为零的节点和边都被屏蔽掉。

Return type:

Explanation

get_complement_subgraph() Explanation[source]

返回诱导子图,其中所有具有任何属性的节点和边都被屏蔽。

Return type:

Explanation

visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[List[str]] = None, top_k: Optional[int] = None)[source]

通过汇总所有节点的节点掩码,创建节点特征重要性的条形图。

Parameters:
  • path (str, optional) – 图表保存的路径。 如果设置为 None,将实时可视化图表。 (默认: None)

  • feat_labels (List[str], optional) – 特征的标签。 (默认 None)

  • top_k (int, 可选) – 要绘制的Top k特征。如果为None,则绘制所有特征。(默认值:None

visualize_graph(path: Optional[str] = None, backend: Optional[str] = None, node_labels: Optional[List[str]] = None) None[source]

可视化解释图,边的透明度对应边的重要性。

Parameters:
  • path (str, optional) – 图表保存的路径。 如果设置为 None,将实时可视化图表。 (默认: None)

  • backend (str, optional) – 用于可视化的图形绘制后端 ("graphviz", "networkx"). 如果设置为 None,将根据可用的系统包使用最合适的可视化后端。 (默认: None)

  • node_labels (list[str], optional) – 节点的标签/ID。 (default: None)

Return type:

None

class HeteroExplanation(_mapping: Optional[Dict[str, Any]] = None, **kwargs)[source]

基础类: HeteroData, ExplanationMixin

保存了异构图的所有获得的解释。

解释对象是一个HeteroData对象, 并且可以保存节点属性和边属性。 如果需要,它还可以保存原始图。

validate(raise_on_error: bool = True) bool[source]

Validates the correctness of the Explanation object.

Return type:

bool 翻译后的内容: bool 在这个例子中,`bool` 是一个Python函数名称,根据翻译规则1,不需要翻译。因此,翻译后的内容保持不变。

get_explanation_subgraph() HeteroExplanation[source]

返回诱导子图,其中所有属性为零的节点和边都被屏蔽掉。

Return type:

HeteroExplanation

get_complement_subgraph() HeteroExplanation[source]

返回诱导子图,其中所有具有任何属性的节点和边都被屏蔽。

Return type:

HeteroExplanation

visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[Dict[str, List[str]]] = None, top_k: Optional[int] = None)[source]

通过为每种节点类型汇总所有节点的节点掩码,创建节点特征重要性的条形图。

Parameters:
  • path (str, optional) – 保存绘图的路径。 如果设置为 None,将实时可视化绘图。 (默认: None)

  • feat_labels (Dict[NodeType, List[str]], optional) – 每个节点类型的特征标签。(默认 None

  • top_k (int, optional) – Top k features to plot. If None plots all features. (default: None)

解释器算法

ExplainerAlgorithm

用于实现解释器算法的抽象基类。

DummyExplainer

一个返回随机解释的虚拟解释器(用于测试目的)。

GNNExplainer

来自"GNNExplainer: Generating Explanations for Graph Neural Networks"论文的GNN-Explainer模型,用于识别在GNN预测中起关键作用的紧凑子图结构和节点特征。

CaptumExplainer

一个基于Captum的解释器,用于识别在GNN预测中起关键作用的紧凑子图结构和节点特征。

PGExplainer

来自"Parameterized Explainer for Graph Neural Network"论文的PGExplainer模型。

AttentionExplainer

一个解释器,使用基于注意力的GNN(例如GATConvGATv2Conv,或TransformerConv)生成的注意力系数作为边解释。

GraphMaskExplainer

来自"Interpreting Graph Neural Networks for NLP With Differentiable Edge Masking"论文的GraphMask-Explainer模型,用于识别在GNN预测中起关键作用的层次紧凑子图结构和节点特征。

解释指标

解释的质量可以通过多种不同的方法来评判。 PyG 支持以下开箱即用的指标:

groundtruth_metrics

比较并评估解释掩码与真实解释掩码。

fidelity

评估Explainer在给定Explanation时的保真度,如"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks"论文中所述。

characterization_score

返回组件特征评分,如"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks"论文中所述。

fidelity_curve_auc

返回如"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks"论文中所述的保真度曲线的AUC。

unfaithfulness

评估Explanation对底层GNN预测器的忠实度,如"评估图神经网络的可解释性"论文中所述。