Shortcuts

torch.utils.mobile_optimizer 的源代码

"""此模块包含用于移动模型优化和lint的实用方法。"""

import torch
from enum import Enum
from torch._C import _MobileOptimizerType as MobileOptimizerType
from typing import Optional, Set, List, AnyStr

class LintCode(Enum):
    """Codes identifying the kinds of lint this module can report.

    Members are used as the ``"name"`` entry of the lint dictionaries built by
    ``generate_mobile_module_lints``. Values are fixed integers starting at 1.
    """

    # NOTE(review): member meanings below are inferred from their names only.
    BUNDLED_INPUT = 1   # presumably: module lacks bundled inputs
    REQUIRES_GRAD = 2   # presumably: a parameter still requires grad
    DROPOUT = 3         # presumably: a dropout op is present
    BATCHNORM = 4       # presumably: a batch-norm op is present

def optimize_for_mobile(
        script_module: torch.jit.ScriptModule,
        optimization_blocklist: Optional[Set[MobileOptimizerType]] = None,
        preserved_methods: Optional[List[AnyStr]] = None,
        backend: str = 'CPU') -> torch.jit.RecursiveScriptModule:
    """
    Optimize a torch script module for mobile deployment.

    Args:
        script_module: An instance of torch script module with type of ScriptModule.
        optimization_blocklist: A set of MobileOptimizerType values. When no set is
            passed, all optimization passes run; otherwise only the passes not
            listed in optimization_blocklist run. Not forwarded for the 'Metal'
            backend, whose pass is called without a blocklist.
        preserved_methods: Names of methods to preserve when the freeze_module
            pass is invoked. Entries may be ``str`` or ``bytes``; ``bytes``
            entries are decoded as UTF-8.
        backend: Device type used to run the resulting model: 'CPU' (default),
            'Vulkan' or 'Metal'. Compared case-insensitively.

    Returns:
        A new optimized torch script module.

    Raises:
        TypeError: If ``script_module`` is not a ``ScriptModule``, or if
            ``backend`` is not one of the supported values.
        AttributeError: If a method named in ``preserved_methods`` does not
            exist on ``script_module``.
    """
    if not isinstance(script_module, torch.jit.ScriptModule):
        raise TypeError(
            f'Got {type(script_module)}, but ScriptModule is expected.')

    if optimization_blocklist is None:
        optimization_blocklist = set()

    if preserved_methods is None:
        preserved_methods = []

    # Normalize method names to ``str``. ``str(b'forward')`` yields the repr
    # "b'forward'" rather than "forward", so bytes must be decoded explicitly.
    # A new name is used because assigning back to preserved_methods would
    # trip mypy (List[AnyStr] vs List[str]). 'backslashreplace' keeps an
    # undecodable name on the AttributeError path below instead of raising
    # UnicodeDecodeError.
    preserved_methods_str: List[str] = [
        method.decode('utf-8', errors='backslashreplace')
        if isinstance(method, (bytes, bytearray)) else str(method)
        for method in preserved_methods
    ]

    bundled_inputs_attributes = _get_bundled_inputs_preserved_attributes(script_module, preserved_methods_str)
    # Only preserve the bundled-input attributes when the module has all of them.
    if all(hasattr(script_module, method) for method in bundled_inputs_attributes):
        # dict.fromkeys de-duplicates while keeping a deterministic order
        # (first occurrence wins), unlike a round trip through set().
        preserved_methods_str = list(dict.fromkeys(preserved_methods_str + bundled_inputs_attributes))

    non_exist_methods = [method for method in preserved_methods_str if not hasattr(script_module, method)]
    if non_exist_methods:
        raise AttributeError(
            f"The following methods to preserve do not exist in script_module: {', '.join(non_exist_methods)}")

    backend = backend.lower()
    if backend == 'cpu':
        optimized_cpp_module = torch._C._jit_pass_optimize_for_mobile(
            script_module._c, optimization_blocklist, preserved_methods_str)
    elif backend == 'vulkan':
        optimized_cpp_module = torch._C._jit_pass_vulkan_optimize_for_mobile(
            script_module._c, optimization_blocklist, preserved_methods_str)
    elif backend == 'metal':
        # The Metal pass is called without a blocklist, so
        # optimization_blocklist has no effect for this backend.
        optimized_cpp_module = torch._C._jit_pass_metal_optimize_for_mobile(
            script_module._c, preserved_methods_str)
    else:
        raise TypeError("Unknown backend, must be one of 'CPU', 'Vulkan' or 'Metal'")

    return torch.jit._recursive.wrap_cpp_module(optimized_cpp_module)
def generate_mobile_module_lints(script_module: torch.jit.ScriptModule): """ 为给定的torch脚本模块生成lint列表。 参数: script_module: 一个类型为ScriptModule的torch脚本模块实例。 返回: lint_map: 包含模块lint的列表字典 """ if not isinstance(script_module, torch.jit.ScriptModule): raise TypeError( f'Got {type(script_module)}, but ScriptModule is expected.') lint_list = [] if not hasattr(script_module, "_generate_bundled_inputs_for_forward"): lint_list.append({"name": LintCode.<span class