torch_geometric.nn.models.to_captum_input

class to_captum_input(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], mask_type: Union[str, MaskLevelType], *args)[source]

基础:

给定 xedge_indexmask_type,将其转换为可在 Captum 归因方法中使用的格式。返回 attribute 函数所需的 inputsadditional_forward_args。有关示例用法,请参见 to_captum_model()

Parameters:
  • x (torch.TensorDict[NodeType, torch.Tensor]) – 节点特征。 对于异构图,这是一个字典,保存了每种节点类型的节点特征。

  • edge_index (torch.TensorDict[EdgeType, torch.Tensor]) – 边的索引。对于异构图,这是一个字典,保存每种边类型的edge index

  • mask_type (str) – 表示使用Captum解释器创建的掩码类型。有效的输入是 "edge", "node", 和 "node_and_edge"

  • *args – 模型解释的额外前向参数,这些参数将被添加到 additional_forward_args 中。

Return type:

Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]