Shortcuts

torch.jit.save

torch.jit.save(m, f, _extra_files=None)[源代码]

保存此模块的离线版本,以便在单独的进程中使用。

保存的模块序列化了此模块的所有方法、子模块、参数和属性。可以使用 torch::jit::load(filename) 加载到C++ API中,或者使用 torch.jit.load 加载到Python API中。

要能够保存一个模块,它不能调用任何原生 Python 函数。这意味着所有子模块也必须是 ScriptModule 的子类。

危险

所有模块,无论其设备如何,在加载期间总是被加载到CPU上。这与torch.load()的语义不同,并且可能在将来发生变化。

Parameters
  • m – 一个 ScriptModule 用于保存。

  • f – 一个类文件对象(必须实现 write 和 flush 方法)或包含文件名的字符串。

  • _extra_files – 从文件名到内容的映射,这些内容将作为f的一部分存储。

注意

torch.jit.save 尝试在不同版本之间保留某些操作符的行为。例如,在 PyTorch 1.5 中,两个整数张量的除法执行的是地板除法,如果包含该代码的模块在 PyTorch 1.5 中保存并在 PyTorch 1.6 中加载,其除法行为将被保留。然而,在 PyTorch 1.6 中保存的同一模块在 PyTorch 1.5 中加载将会失败,因为 1.6 中的除法行为发生了变化,而 1.5 不知道如何复制 1.6 的行为。

示例: .. testcode:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

m = torch.jit.script(MyModule())

# 保存到文件
torch.jit.save(m, 'scriptmodule.pt')
# 这一行与前一行等效
m.save("scriptmodule.pt")

# 保存到 io.BytesIO 缓冲区
buffer = io.BytesIO()
torch.jit.save(m, buffer)

# 保存时附带额外文件
extra_files = {'foo.txt': b'bar'}
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
优云智算