torch_geometric.nn.models.to_captum_input
- class to_captum_input(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], mask_type: Union[str, MaskLevelType], *args)[source]
基础:
给定
x、edge_index和mask_type,将其转换为可在 Captum 归因方法中使用的格式。返回 Captum 的attribute函数所需的inputs和additional_forward_args。有关示例用法,请参见to_captum_model()。- Parameters:
x (torch.Tensor 或 Dict[NodeType, torch.Tensor]) – 节点特征。 对于异构图,这是一个字典,保存了每种节点类型的节点特征。
edge_index (torch.Tensor 或 Dict[EdgeType, torch.Tensor]) – 边的索引。对于异构图,这是一个字典,保存每种边类型的
edge index。mask_type (str) – 表示使用Captum解释器创建的掩码类型。有效的输入是
"edge","node", 和"node_and_edge"。*args – 模型解释的额外前向参数,这些参数将被添加到
additional_forward_args中。
- Return type: