torch.utils.mobile_optimizer 的源代码
"""此模块包含用于移动模型优化和lint的实用方法。"""
import torch
from enum import Enum
from torch._C import _MobileOptimizerType as MobileOptimizerType
from typing import Optional, Set, List, AnyStr
class LintCode(Enum):
    """Categories of lint findings; used as the ``"name"`` entry of each lint
    produced by ``generate_mobile_module_lints``."""

    BUNDLED_INPUT = 1
    REQUIRES_GRAD = 2
    DROPOUT = 3
    BATCHNORM = 4
def optimize_for_mobile(
        script_module: torch.jit.ScriptModule,
        optimization_blocklist: Optional[Set[MobileOptimizerType]] = None,
        preserved_methods: Optional[List[AnyStr]] = None,
        backend: str = 'CPU') -> torch.jit.RecursiveScriptModule:
    """
    Optimize a torch script module for mobile deployment.

    Args:
        script_module: An instance of torch script module with type of ScriptModule.
        optimization_blocklist: A set with type of MobileOptimizerType. When no set is
            passed, the optimizer runs all optimization passes; otherwise it runs only
            the passes not included in optimization_blocklist.
        preserved_methods: A list of methods that need to be preserved when the
            freeze_module pass is invoked. Entries may be ``str`` or ``bytes``;
            ``bytes`` entries are decoded as UTF-8.
        backend: Device type to use for running the resulting model ('CPU' (default),
            'Vulkan' or 'Metal'). Matching is case-insensitive.

    Returns:
        A new optimized torch script module.

    Raises:
        TypeError: If ``script_module`` is not a ``ScriptModule``, or if ``backend``
            is not one of the supported backends.
        AttributeError: If any method to preserve does not exist on ``script_module``.
    """
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            f'Got {type(script_module)}, but ScriptModule is expected.')

    if optimization_blocklist is None:
        optimization_blocklist = set()

    if preserved_methods is None:
        preserved_methods = []

    # Normalize method names to ``str``. ``str(b'forward')`` would produce the repr
    # "b'forward'", which never names a real method, so bytes must be decoded.
    # A new name is used because assigning back to ``preserved_methods`` would
    # trigger a mypy error (List[AnyStr] vs List[str]).
    preserved_methods_str: List[str] = [
        method.decode('utf-8') if isinstance(method, bytes) else str(method)
        for method in preserved_methods
    ]

    bundled_inputs_attributes = _get_bundled_inputs_preserved_attributes(script_module, preserved_methods_str)
    if all(hasattr(script_module, method) for method in bundled_inputs_attributes):
        # Merge and de-duplicate while keeping a stable order; a plain ``set`` would
        # make the order handed to the optimizer depend on string hash randomization.
        preserved_methods_str = list(dict.fromkeys(preserved_methods_str + bundled_inputs_attributes))

    non_exist_methods = [
        method for method in preserved_methods_str if not hasattr(script_module, method)
    ]
    if non_exist_methods:
        raise AttributeError(
            f"The following methods to preserve do not exist in script_module: {', '.join(non_exist_methods)}")

    backend = backend.lower()
    if backend == 'cpu':
        optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(
            script_module._c,
            optimization_blocklist,
            preserved_methods_str)
    elif backend == 'vulkan':
        optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(
            script_module._c,
            optimization_blocklist,
            preserved_methods_str)
    elif backend == 'metal':
        # The Metal pass takes no blocklist argument.
        optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(script_module._c, preserved_methods_str)
    else:
        # TypeError (rather than ValueError) is kept for compatibility with existing callers.
        raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'")

    return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)
def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
"""
为给定的torch脚本模块生成lint列表。
参数:
script_module: 一个类型为ScriptModule的torch脚本模块实例。
返回:
lint_map: 包含模块lint的列表字典
"""
if not isinstance(script_module, torch.jit.ScriptModule):
raise TypeError(
f'Got {type(script_module)}, but ScriptModule is expected.')
lint_list = []
if not hasattr(script_module, "_generate_bundled_inputs_for_forward"):
lint_list.append({"name": LintCode.<span class