torch_geometric.nn.models.captum_output_to_dicts

class captum_output_to_dicts(captum_attrs: Tuple[Tensor, ...], mask_type: Union[str, MaskLevelType], metadata: Tuple[List[str], List[Tuple[str, str, str]]])[source]

基础:

Captum归因方法的输出转换为两个字典,分别包含节点和边的归因张量。此函数用于解释HeteroData对象时使用。 有关示例用法,请参见to_captum_model()

Parameters:
  • captum_attrs (tuple[torch.Tensor]) – 归因方法的输出。

  • mask_type (str) –

    表示使用Captum解释器创建的掩码类型。有效的输入是 "edge", "node", 和 "node_and_edge":

    1. "edge": captum_attrs 仅包含边属性。返回的元组没有节点属性,并且边属性字典的键是边类型,值是形状为 [num_edges] 的边掩码张量。

    2. "node": captum_attrs 仅包含节点属性。返回的元组有一个节点属性字典,键是节点类型,值是形状为 [num_nodes, num_features] 的节点掩码张量,并且没有边属性。

    3. "node_and_edge": captum_attrs 包含节点和

      边属性。

  • metadata (Metadata) – 异构图的元数据。

Return type:

Tuple[Optional[Dict[str, Tensor]], Optional[Dict[Tuple[str, str, str], Tensor]]]