torch.jit.save¶
- torch.jit.save(m, f, _extra_files=None)[源代码]¶
保存此模块的离线版本,以便在单独的进程中使用。
保存的模块序列化了此模块的所有方法、子模块、参数和属性。可以使用
torch::jit::load(filename)加载到C++ API中,或者使用torch.jit.load加载到Python API中。要能够保存一个模块,它不能调用任何原生 Python 函数。这意味着所有子模块也必须是
ScriptModule的子类。危险
所有模块,无论其设备如何,在加载期间总是被加载到CPU上。这与
torch.load()的语义不同,并且可能在将来发生变化。- Parameters
m – 一个
ScriptModule用于保存。f – 一个类文件对象(必须实现 write 和 flush 方法)或包含文件名的字符串。
_extra_files – 从文件名到内容的映射,这些内容将作为f的一部分存储。
注意
torch.jit.save 尝试在不同版本之间保留某些操作符的行为。例如,在 PyTorch 1.5 中,两个整数张量的除法执行的是地板除法,如果包含该代码的模块在 PyTorch 1.5 中保存并在 PyTorch 1.6 中加载,其除法行为将被保留。然而,在 PyTorch 1.6 中保存的同一模块在 PyTorch 1.5 中加载将会失败,因为 1.6 中的除法行为发生了变化,而 1.5 不知道如何复制 1.6 的行为。
示例: .. testcode:
import torch import io class MyModule(torch.nn.Module): def forward(self, x): return x + 10 m = torch.jit.script(MyModule()) # 保存到文件 torch.jit.save(m, 'scriptmodule.pt') # 这一行与前一行等效 m.save("scriptmodule.pt") # 保存到 io.BytesIO 缓冲区 buffer = io.BytesIO() torch.jit.save(m, buffer) # 保存时附带额外文件 extra_files = {'foo.txt': b'bar'} torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)