Shortcuts

获取图形节点名称

torchvision.models.feature_extraction.get_graph_node_names(model: Module, tracer_kwargs: Optional[Dict[str, Any]] = None, suppress_diff_warning: bool = False, concrete_args: Optional[Dict[str, Any]] = None) Tuple[List[str], List[str]][source]

开发工具,用于按执行顺序返回节点名称。请参阅create_feature_extractor()下的节点名称说明。这对于查看哪些节点名称可用于特征提取非常有用。有两个原因导致节点名称不能直接从模型的代码中轻松读取:

  1. 并非所有子模块都被追踪。来自torch.nn的模块都属于这一类。

  2. 表示重复应用相同操作或叶子模块的节点会获得一个_{counter}后缀。

模型被追踪两次:一次在训练模式下,一次在评估模式下。两组节点名称都被返回。

有关此处使用的节点命名约定的更多详细信息,请参阅相关子标题中的文档

Parameters:
  • model (nn.Module) – 我们想要打印节点名称的模型

  • tracer_kwargs (dict, 可选) – 一个包含关键字参数的字典,用于 NodePathTracer(它们最终会传递给 torch.fx.Tracer)。 默认情况下,它将设置为包装并使所有torchvision操作成为叶节点: {“autowrap_modules”: (math, torchvision.ops,),”leaf_modules”: _get_leaf_modules_for_ops(),} 警告:如果用户提供了tracer_kwargs,上述默认参数将附加到用户提供的字典中。

  • suppress_diff_warning (bool, optional) – 是否在训练和评估版本的图之间存在差异时抑制警告。默认为 False。

  • concrete_args (Optional[Dict[str, any]]) – 不应被视为代理的具体参数。根据Pytorch文档,此参数的API可能无法保证。

Returns:

从训练模式下的模型跟踪中获取的节点名称列表,以及从评估模式下的模型跟踪中获取的另一个列表。

Return type:

tuple(list, list)

示例:

>>> model = torchvision.models.resnet18()
>>> train_nodes, eval_nodes = get_graph_node_names(model)