torch_geometric.nn.models.captum_output_to_dicts
- class captum_output_to_dicts(captum_attrs: Tuple[Tensor, ...], mask_type: Union[str, MaskLevelType], metadata: Tuple[List[str], List[Tuple[str, str, str]]])[source]
基础:
将Captum归因方法的输出转换为两个字典,分别包含节点和边的归因张量。此函数用于解释
HeteroData对象时使用。 有关示例用法,请参见to_captum_model()。- Parameters:
captum_attrs (tuple[torch.Tensor]) – 归因方法的输出。
mask_type (str) –
表示使用Captum解释器创建的掩码类型。有效的输入是
"edge","node", 和"node_and_edge":"edge":captum_attrs仅包含边属性。返回的元组没有节点属性,并且边属性字典的键是边类型,值是形状为[num_edges]的边掩码张量。"node":captum_attrs仅包含节点属性。返回的元组有一个节点属性字典,键是节点类型,值是形状为[num_nodes, num_features]的节点掩码张量,并且没有边属性。"node_and_edge":captum_attrs包含节点和边属性。
metadata (Metadata) – 异构图的元数据。
- Return type:
Tuple[Optional[Dict[str,Tensor]],Optional[Dict[Tuple[str,str,str],Tensor]]]