torch_geometric.nn.models.to_captum_model
- class to_captum_model(model: Module, mask_type: Union[str, MaskLevelType] = MaskLevelType.edge, output_idx: Optional[int] = None, metadata: Optional[Tuple[List[str], List[Tuple[str, str, str]]]] = None)
Converts a model into a model that can be used with Captum attribution methods.
Sample code for homogeneous graphs:
    from captum.attr import IntegratedGradients

    from torch_geometric.data import Data
    from torch_geometric.nn import GCN
    from torch_geometric.nn import to_captum_model, to_captum_input

    data = Data(x=(...), edge_index=(...))
    model = GCN(...)
    ...  # Train the model.

    # Explain predictions for node `10`:
    mask_type = "edge"
    output_idx = 10
    captum_model = to_captum_model(model, mask_type, output_idx)
    inputs, additional_forward_args = to_captum_input(data.x, data.edge_index,
                                                      mask_type)

    ig = IntegratedGradients(captum_model)
    ig_attr = ig.attribute(inputs=inputs,
                           target=int(data.y[output_idx]),
                           additional_forward_args=additional_forward_args,
                           internal_batch_size=1)
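The following minimal pure-Python sketch shows what the wrapped model and `ig.attribute(...)` compute in the edge-mask case. It uses neither PyTorch nor Captum. The toy model, the graph, and the helper names (`toy_model`, `forward_with_edge_mask`, `integrated_gradients`) are hypothetical.

In the sketch, the edge mask is the attributed input. The node features and `edge_index` are passed through, as `additional_forward_args` are, and `output_idx` selects the explained node's output. Integrated Gradients is approximated with a midpoint Riemann sum and central-difference gradients. Captum does not compute it this way (it uses autograd and its own integration scheme).

```python
# Pure-Python sketch: Integrated Gradients over an edge mask.
# All names and the toy model are hypothetical; this is not PyG or Captum code.


def toy_model(x, edge_index, edge_weight):
    """One weighted sum-aggregation step, h[i] = x[i] + sum_j w_ji * x[j], then squared."""
    src, dst = edge_index
    h = list(x)
    for e, (j, i) in enumerate(zip(src, dst)):
        h[i] += edge_weight[e] * x[j]
    # Squaring makes the output non-linear in the mask.
    return [v * v for v in h]


def forward_with_edge_mask(edge_mask, x, edge_index, output_idx):
    """The mask comes first (attributed input); x and edge_index are passed
    through unchanged. Returns only the output of node `output_idx`."""
    return toy_model(x, edge_index, edge_mask)[output_idx]


def integrated_gradients(f, inputs, baseline, steps=50, eps=1e-4):
    """Midpoint Riemann approximation of IG with central-difference gradients."""
    n = len(inputs)
    avg_grad = [0.0] * n
    for k in range(steps):
        alpha = (k + 0.5) / steps
        point = [b + alpha * (v - b) for v, b in zip(inputs, baseline)]
        for d in range(n):
            up = list(point)
            up[d] += eps
            down = list(point)
            down[d] -= eps
            avg_grad[d] += (f(up) - f(down)) / (2 * eps) / steps
    # Scale the averaged gradient by (input - baseline), per dimension.
    return [g * (v - b) for g, v, b in zip(avg_grad, inputs, baseline)]


x = [1.0, 2.0, 3.0]                  # one scalar feature per node
edge_index = ([0, 1, 2], [2, 2, 0])  # edges 0->2, 1->2, 2->0
output_idx = 2                       # explain node 2
edge_mask = [1.0, 1.0, 1.0]          # all edges fully present


def f(mask):
    return forward_with_edge_mask(mask, x, edge_index, output_idx)


attr = integrated_gradients(f, edge_mask, baseline=[0.0, 0.0, 0.0])
print([round(a, 3) for a in attr])  # → [9.0, 18.0, 0.0]
```

Edge 2->0 gets zero attribution because it does not feed node 2. The attributions also add up to the output difference between the full mask and the all-zero baseline (36 - 9 = 27), which is the completeness property of Integrated Gradients.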
Sample code for heterogeneous graphs:
    from captum.attr import IntegratedGradients

    from torch_geometric.data import HeteroData
    from torch_geometric.nn import HeteroConv
    from torch_geometric.nn import (captum_output_to_dicts, to_captum_model,
                                    to_captum_input)

    data = HeteroData(...)
    model = HeteroConv(...)
    ...  # Train the model.

    # Explain predictions for node `10`:
    mask_type = "edge"
    metadata = data.metadata()
    output_idx = 10
    captum_model = to_captum_model(model, mask_type, output_idx, metadata)
    inputs, additional_forward_args = to_captum_input(data.x_dict,
                                                      data.edge_index_dict,
                                                      mask_type)

    ig = IntegratedGradients(captum_model)
    # `y`: labels of the explained node type (not defined in this snippet).
    ig_attr = ig.attribute(inputs=inputs,
                           target=int(y[output_idx]),
                           additional_forward_args=additional_forward_args,
                           internal_batch_size=1)
    # Returns (node attributions, edge attributions); only edges are masked here.
    _, edge_attr_dict = captum_output_to_dicts(ig_attr, mask_type, metadata)
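In the heterogeneous case, Captum returns one attribution per input, and each input corresponds to one node or edge type. The final step of the example regroups those attributions by type. The sketch below is a hypothetical pure-Python version of that regrouping (`attributions_to_dicts` is not a PyG function). It assumes that attributions arrive as one list per type, in `metadata` order, with node types first when `mask_type` is `"node_and_edge"`.

```python
# Hypothetical sketch of regrouping per-type attributions by metadata.
# Assumption: one attribution list per type, in metadata order, node types first.


def attributions_to_dicts(attrs, mask_type, metadata):
    node_types, edge_types = metadata
    x_attr_dict, edge_attr_dict = None, None
    if mask_type == "node":
        assert len(attrs) == len(node_types)
        x_attr_dict = dict(zip(node_types, attrs))
    elif mask_type == "edge":
        assert len(attrs) == len(edge_types)
        edge_attr_dict = dict(zip(edge_types, attrs))
    elif mask_type == "node_and_edge":
        assert len(attrs) == len(node_types) + len(edge_types)
        x_attr_dict = dict(zip(node_types, attrs[:len(node_types)]))
        edge_attr_dict = dict(zip(edge_types, attrs[len(node_types):]))
    else:
        raise ValueError(f"invalid mask_type: {mask_type!r}")
    return x_attr_dict, edge_attr_dict


# Toy metadata in the (node_types, edge_types) shape shown in the signature.
metadata = (
    ["paper", "author"],
    [("author", "writes", "paper"), ("paper", "cites", "paper")],
)
edge_attrs = ([0.5, 0.1], [0.3, 0.0, 0.2])  # one list per edge type
x_attr_dict, edge_attr_dict = attributions_to_dicts(edge_attrs, "edge", metadata)
print(edge_attr_dict[("paper", "cites", "paper")])  # → [0.3, 0.0, 0.2]
```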
Note
For an example of using Captum attribution methods within PyG, see examples/explain/captum_explainer.py.
- Parameters:
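model (torch.nn.Module) – The model to be explained.
mask_type (str, optional) – The type of mask to be created with a Captum explainer. Valid inputs are "edge", "node", and "node_and_edge". (default: "edge")
output_idx (int, optional) – Index of the output element (node or link index) to be explained. When output_idx is set, the forward function returns only the model's output for the element at that index. (default: None)
metadata (Metadata, optional) – The metadata of the heterogeneous graph. Only required when explaining a HeteroData object. (default: None)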
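The three `mask_type` values differ in which tensors Captum attributes. The sketch below is a hypothetical pure-Python version (`split_inputs` is not a PyG function) of one plausible arrangement for homogeneous graphs. It is consistent with `to_captum_input` returning `inputs` and `additional_forward_args` in the examples above:

- A node mask attributes the node features.
- An edge mask attributes an all-ones edge mask.
- `"node_and_edge"` attributes both.

Everything else is passed through unchanged. Treat the exact layout as an assumption, not as PyG's documented behavior.

```python
# Hypothetical sketch: how mask_type could split the model arguments into
# attributed `inputs` and pass-through `additional_forward_args`.


def split_inputs(x, edge_index, mask_type):
    num_edges = len(edge_index[0])
    edge_mask = [1.0] * num_edges  # all-ones mask: every edge fully present
    if mask_type == "node":
        return (x,), (edge_index,)
    if mask_type == "edge":
        return (edge_mask,), (x, edge_index)
    if mask_type == "node_and_edge":
        return (x, edge_mask), (edge_index,)
    raise ValueError(f"invalid mask_type: {mask_type!r}")


x = [[0.5], [1.5], [2.5]]            # 3 nodes, 1 feature each
edge_index = ([0, 1, 2], [2, 2, 0])  # 3 edges
for mask_type in ("node", "edge", "node_and_edge"):
    inputs, extra = split_inputs(x, edge_index, mask_type)
    print(mask_type, len(inputs), len(extra))
# node 1 1
# edge 1 2
# node_and_edge 2 1
```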
- Return type:
Union[CaptumModel, CaptumHeteroModel]