torch_geometric.explain.algorithm.AttentionExplainer
- class AttentionExplainer(reduce: str = 'max')[source]
Bases:
ExplainerAlgorithm

An explainer that uses the attention coefficients produced by an attention-based GNN (e.g., GATConv, GATv2Conv, or TransformerConv) as edge explanation. Attention scores across layers and heads will be aggregated according to the reduce argument. (A usage sketch is given at the end of this section.)
- Parameters:
reduce (str, optional) – The method used to reduce the attention scores across layers and heads. (default: "max")
- forward(model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs) → Explanation[source]
Computes the explanation.
- Parameters:
model (torch.nn.Module) – The model to explain.
x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.
edge_index (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.
target (torch.Tensor) – The target of the model.
index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)
**kwargs (optional) – Additional keyword arguments passed to model.
- Return type:
Explanation
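A minimal usage sketch, assuming a small GAT model and a toy graph (the model, data, and configuration values below are illustrative assumptions, not part of this API): the algorithm is wrapped in an Explainer, and the returned Explanation carries an edge_mask with one aggregated attention score per edge.

import torch

from torch_geometric.explain import AttentionExplainer, Explainer
from torch_geometric.nn import GATConv


class GAT(torch.nn.Module):
    # Hypothetical two-layer GAT for node classification
    # (3 input features, 2 output classes).
    def __init__(self):
        super().__init__()
        self.conv1 = GATConv(3, 8, heads=4)
        self.conv2 = GATConv(8 * 4, 2, heads=1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


x = torch.randn(4, 3)  # Toy graph: 4 nodes with 3 features each.
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])

explainer = Explainer(
    model=GAT(),
    algorithm=AttentionExplainer(reduce='max'),
    explanation_type='model',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',
    ),
)

explanation = explainer(x, edge_index, index=0)  # Explain node 0.
print(explanation.edge_mask)  # One aggregated attention score per edge.

# Rank edges of the toy graph by their aggregated attention score.
order = explanation.edge_mask.argsort(descending=True)
print(edge_index[:, order])  # Edges from most to least attended.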