torch_geometric.explain.algorithm.CaptumExplainer
- class CaptumExplainer(attribution_method: Union[str, Any], **kwargs)
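Bases: ExplainerAlgorithm
A Captum-based explainer for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.
This explainer algorithm uses Captum to compute attributions.
Currently, the following attribution methods are supported: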
captum.attr.IntegratedGradients
captum.attr.Saliency
captum.attr.InputXGradient
captum.attr.Deconvolution
captum.attr.ShapleyValueSampling
captum.attr.GuidedBackprop
- Parameters:
attribution_method (Attribution or str) – The Captum attribution method to use. Can be a string or a captum.attr method.
**kwargs – Additional arguments for the Captum attribution method.
- forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs) → Union[Explanation, HeteroExplanation]
Computes the explanation.
- Parameters:
model (torch.nn.Module) – The model to explain.
x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.
edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.
target (torch.Tensor) – The target of the model.
index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)
**kwargs (optional) – Additional keyword arguments passed to model.
- Return type: