Shortcuts

torch.jit._serialization 的源代码

"""序列化。

本模块包含用于序列化TorchScript模块的功能,特别是:
    * torch.jit.save
    * torch.jit.load

这不打算直接导入;请使用`torch.jit`中暴露的功能。
"""
import os

import torch
from torch.jit._recursive import wrap_cpp_module
from torch.serialization import validate_cuda_device


[docs]def save(m, f, _extra_files=None): r""" 保存此模块的离线版本,以便在单独的进程中使用。 保存的模块序列化此模块的所有方法、子模块、参数和属性。它可以使用``torch::jit::load(filename)``加载到C++ API中,或者使用:func:`torch.jit.load `加载到Python API中。 要能够保存模块,它不能调用任何本机Python函数。这意味着所有子模块也必须是:class:`ScriptModule`的子类。 .. 危险:: 所有模块,无论其设备如何,在加载期间始终加载到CPU上。这与:func:`torch.load`的语义不同,并且可能会在将来更改。 参数: m: 要保存的:class:`ScriptModule`。 f: 一个类文件对象(必须实现write和flush)或包含文件名的字符串。 _extra_files: 从文件名到内容的映射,这些内容将与`f`一起存储。 .. 注意:: torch.jit.save尝试跨版本保留某些操作符的行为。例如,在PyTorch 1.5中,两个整数张量的除法执行地板除法,如果包含该代码的模块在PyTorch 1.5中保存并在PyTorch 1.6中加载,其除法行为将被保留。在PyTorch 1.6中保存的相同模块在PyTorch 1.5中加载将失败,因为1.6中的除法行为发生了变化,而1.5不知道如何复制1.6的行为。 示例: .. 测试代码:: import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 m = torch.jit.script(MyModule()) # 保存到文件 torch.jit.save(m, 'scriptmodule.pt') # 这行代码等同于上一行 m.save("scriptmodule.pt") # 保存到io.BytesIO缓冲区 buffer = io.BytesIO() torch.jit.save(m, buffer) # 保存带有额外文件 extra_files = {'foo.txt': b'bar'} torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files) """ if _extra_files is None: _extra_files = {} if isinstance(f, (str, os.PathLike)): m.save(f, _extra_files=_extra_files) else: ret = m.save_to_buffer(_extra_files=_extra_files) f.write(ret)
[docs]def load(f, map_location=None, _extra_files=None, _restore_shapes=False): r""" 加载之前使用:func:`torch.jit.save `保存的:class:`ScriptModule`或:class:`ScriptFunction`。 所有之前保存的模块,无论其设备如何,首先加载到CPU上,然后移动到它们保存时的设备。如果这失败(例如因为运行时系统没有某些设备),则会引发异常。 参数: f: 一个类文件对象(必须实现read、readline、tell和seek),或包含文件名的字符串 map_location (string or torch.device): 一个简化版本的``map_location``,用于动态重新映射存储到一组替代设备。 _extra_files (filename到content的字典): 在映射中给出的额外文件名将被加载,并且它们的内容将存储在提供的映射中。 _restore_shapes (bool): 是否在加载时使用存储的输入重新跟踪模块 返回: 一个:class:`ScriptModule`对象。 示例: .. 测试代码:: import torch import io torch.jit.load('scriptmodule.pt') # 从io.BytesIO对象加载ScriptModule with open('scriptmodule.pt', 'rb') as f: buffer = io.BytesIO(f.read()) # 将所有张量加载到原始设备 torch.jit.load(buffer) # 将所有张量加载到CPU,使用设备 buffer.seek(0) torch.jit.load(buffer, map_location=torch.device('cpu')) # 将所有张量加载到CPU,使用字符串 buffer.seek(0) torch.jit.load(buffer, map_location='cpu') # 加载带有额外文件。 extra_files = {'foo.txt': ''} # 值将被数据替换 torch.jit.load('scriptmodule.pt', _extra_files=extra_files) print(extra_files['foo.txt']) .. 测试输出:: :hide: ... .. 测试清理:: import os os.remove("scriptmodule.pt") """ if isinstance(f, (str, os.PathLike)): if not os.path.exists(f): # type: ignore[type-var] raise ValueError(f"提供的文件名 {f} 不存在") # type: ignore[str-bytes-safe] if os.path.isdir(f): raise ValueError(f"提供的文件名 {f} 是一个目录") # type: ignore[str-bytes-safe] map_location = validate_map_location(map_location) if _extra_files is None: _extra_files = {} cu = torch._C.CompilationUnit() if isinstance(f, (str, os.PathLike)): cpp_module = torch._C.import_ir_module(cu, os.fspath(f), map_location, _extra_files, _restore_shapes) # type: ignore[call-arg] else: cpp_module = torch._C.import_ir_module_from_buffer( cu, f.read(), map_location=