torch.jit._serialization 的源代码
"""序列化。
本模块包含用于序列化TorchScript模块的功能,特别是:
* torch.jit.save
* torch.jit.load
这不打算直接导入;请使用`torch.jit`中暴露的功能。
"""
import os
import torch
from torch.jit._recursive import wrap_cpp_module
from torch.serialization import validate_cuda_device
[docs]def save(m, f, _extra_files=None):
r"""
保存此模块的离线版本,以便在单独的进程中使用。
保存的模块序列化此模块的所有方法、子模块、参数和属性。它可以使用``torch::jit::load(filename)``加载到C++ API中,或者使用:func:`torch.jit.load `加载到Python API中。
要能够保存模块,它不能调用任何本机Python函数。这意味着所有子模块也必须是:class:`ScriptModule`的子类。
.. 危险::
所有模块,无论其设备如何,在加载期间始终加载到CPU上。这与:func:`torch.load`的语义不同,并且可能会在将来更改。
参数:
m: 要保存的:class:`ScriptModule`。
f: 一个类文件对象(必须实现write和flush)或包含文件名的字符串。
_extra_files: 从文件名到内容的映射,这些内容将与`f`一起存储。
.. 注意::
torch.jit.save尝试跨版本保留某些操作符的行为。例如,在PyTorch 1.5中,两个整数张量的除法执行地板除法,如果包含该代码的模块在PyTorch 1.5中保存并在PyTorch 1.6中加载,其除法行为将被保留。在PyTorch 1.6中保存的相同模块在PyTorch 1.5中加载将失败,因为1.6中的除法行为发生了变化,而1.5不知道如何复制1.6的行为。
示例:
.. 测试代码::
import torch
import io
class MyModule(torch.nn.Module):
def forward(self, x):
return x + 10
m = torch.jit.script(MyModule())
# 保存到文件
torch.jit.save(m, 'scriptmodule.pt')
# 这行代码等同于上一行
m.save("scriptmodule.pt")
# 保存到io.BytesIO缓冲区
buffer = io.BytesIO()
torch.jit.save(m, buffer)
# 保存带有额外文件
extra_files = {'foo.txt': b'bar'}
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
"""
if _extra_files is None:
_extra_files = {}
if isinstance(f, (str, os.PathLike)):
m.save(f, _extra_files=_extra_files)
else:
ret = m.save_to_buffer(_extra_files=_extra_files)
f.write(ret)
[docs]def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
r"""
加载之前使用:func:`torch.jit.save `保存的:class:`ScriptModule`或:class:`ScriptFunction`。
所有之前保存的模块,无论其设备如何,首先加载到CPU上,然后移动到它们保存时的设备。如果这失败(例如因为运行时系统没有某些设备),则会引发异常。
参数:
f: 一个类文件对象(必须实现read、readline、tell和seek),或包含文件名的字符串
map_location (string or torch.device): 一个简化版本的``map_location``,用于动态重新映射存储到一组替代设备。
_extra_files (filename到content的字典): 在映射中给出的额外文件名将被加载,并且它们的内容将存储在提供的映射中。
_restore_shapes (bool): 是否在加载时使用存储的输入重新跟踪模块
返回:
一个:class:`ScriptModule`对象。
示例:
.. 测试代码::
import torch
import io
torch.jit.load('scriptmodule.pt')
# 从io.BytesIO对象加载ScriptModule
with open('scriptmodule.pt', 'rb') as f:
buffer = io.BytesIO(f.read())
# 将所有张量加载到原始设备
torch.jit.load(buffer)
# 将所有张量加载到CPU,使用设备
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))
# 将所有张量加载到CPU,使用字符串
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')
# 加载带有额外文件。
extra_files = {'foo.txt': ''} # 值将被数据替换
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])
.. 测试输出::
:hide:
...
.. 测试清理::
import os
os.remove("scriptmodule.pt")
"""
if isinstance(f, (str, os.PathLike)):
if not os.path.exists(f): # type: ignore[type-var]
raise ValueError(f"提供的文件名 {f} 不存在") # type: ignore[str-bytes-safe]
if os.path.isdir(f):
raise ValueError(f"提供的文件名 {f} 是一个目录") # type: ignore[str-bytes-safe]
map_location = validate_map_location(map_location)
if _extra_files is None:
_extra_files = {}
cu = torch._C.CompilationUnit()
if isinstance(f, (str, os.PathLike)):
cpp_module = torch._C.import_ir_module(cu, os.fspath(f), map_location, _extra_files, _restore_shapes) # type: ignore[call-arg]
else:
cpp_module = torch._C.import_ir_module_from_buffer(
cu, f.read(), map_location=