• Docs >
  • Serialization semantics
Shortcuts

序列化语义

本笔记描述了如何在Python中保存和加载PyTorch张量和模块状态,以及如何序列化Python模块以便在C++中加载。

保存和加载张量

torch.save()torch.load() 让你可以轻松地保存和加载张量:

>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])

按照惯例,PyTorch 文件通常使用 ‘.pt’ 或 ‘.pth’ 扩展名。

torch.save()torch.load() 默认使用 Python 的 pickle,因此您也可以将多个张量作为元组、列表和字典等 Python 对象的一部分进行保存:

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

如果数据结构是可序列化的,包含 PyTorch 张量的自定义数据结构也可以保存。

保存和加载张量保留视图

保存张量会保留它们的视图关系:

>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

在幕后,这些张量共享相同的“存储”。更多关于视图和存储的信息,请参见 张量视图

当PyTorch保存张量时,它会分别保存它们的存储对象和张量元数据。这是一个可能会在将来发生变化的实现细节,但它通常节省空间,并让PyTorch能够轻松重建加载的张量之间的视图关系。例如,在上面的代码片段中,只有单个存储被写入到‘tensors.pt’中。

然而,在某些情况下,保存当前的存储对象可能是不必要的,并且会创建过大而难以处理的文件。在下面的代码片段中,一个比保存的张量要大得多的存储被写入到文件中:

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999

不仅仅是将small张量中的五个值保存到‘small.pt’中,它与large共享存储中的999个值也被保存和加载了。

当保存元素少于其存储对象的张量时,可以通过首先克隆张量来减小保存文件的大小。克隆张量会产生一个新张量,该张量具有一个新的存储对象,仅包含张量中的值:

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # 保存 small 的克隆
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5

由于克隆的张量彼此独立,因此它们没有任何原始张量之间的视图关系。如果在保存小于其存储对象的张量时,文件大小和视图关系都很重要,那么在保存之前必须小心构建新的张量,以最小化其存储对象的大小,但仍保持所需的视图关系。

保存和加载 torch.nn.Modules

另请参阅: 教程: 保存和加载模块

在 PyTorch 中,模块的状态通常使用“状态字典”进行序列化。 模块的状态字典包含其所有参数和持久缓冲区:

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

出于兼容性原因,建议不要直接保存模块,而是仅保存其状态字典。Python模块甚至有一个函数,load_state_dict(),用于从状态字典恢复其状态:

>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<所有键匹配成功>

请注意,状态字典首先通过torch.load()从其文件中加载,然后通过load_state_dict()恢复状态。

即使是自定义模块和包含其他模块的模块也有状态字典,并且可以使用这种模式:

# 一个包含两个线性层的模块
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])

>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>

用于torch.save的序列化文件格式

自 PyTorch 1.6.0 起,torch.save 默认返回一个未压缩的 ZIP64 存档,除非用户设置 _use_new_zipfile_serialization=False

在这个存档中,文件的顺序如下

checkpoint.pth
├── data.pkl
├── byteorder  # 在 PyTorch 2.1.0 中添加
├── data/
│   ├── 0
│   ├── 1
│   ├── 2
│   └── …
└── version
The entries are as follows:
  • data.pkl 是序列化传递给 torch.save 的对象的结果, 不包括它包含的 torch.Storage 对象

  • byteorder 包含一个字符串,表示保存时的 sys.byteorder(“little” 或 “big”)

  • data/ 包含对象中的所有存储,其中每个存储都是一个单独的文件

  • version 包含在保存时的一个版本号,可以在加载时使用

保存时,PyTorch 会确保每个文件的本地文件头被填充到 64 字节的倍数偏移量,以确保每个文件的偏移量是 64 字节对齐的。

注意

某些设备(如XLA)上的张量被序列化为pickled的numpy数组。因此,它们的存储不会被序列化。在这些情况下,检查点中可能不存在data/

序列化 torch.nn.Modules 并在 C++ 中加载

另请参阅: 教程: 在C++中加载TorchScript模型

ScriptModules 可以被序列化为一个 TorchScript 程序,并使用 torch.jit.load() 加载。 这种序列化编码了所有模块的方法、子模块、参数和属性,并且它允许在 C++ 中加载序列化的程序(即不需要 Python)。

关于torch.jit.save()torch.save()之间的区别可能并不立即清晰。torch.save()使用pickle保存Python对象。这对于原型设计、研究和训练特别有用。另一方面,torch.jit.save()将ScriptModules序列化为可以在Python或C++中加载的格式。这在保存和加载C++模块时非常有用,或者在用C++运行在Python中训练的模块时,这是部署PyTorch模型时的常见做法。

在Python中编写、序列化和加载模块:

>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
递归脚本模块( original_name=MyModule
                      (l0): 递归脚本模块(original_name=Linear)
                      (l1): 递归脚本模块(original_name=Linear) )

跟踪的模块也可以使用torch.jit.save()保存,但需要注意的是,只有跟踪的代码路径会被序列化。以下示例演示了这一点:

# 一个包含控制流的模块
>>> class ControlFlowModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)

        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)

上述模块有一个if语句,该语句不是由跟踪的输入触发的,因此不是跟踪模块的一部分,也不会与它一起序列化。然而,脚本化模块包含该if语句,并且会与它一起序列化。有关脚本化和跟踪的更多信息,请参阅TorchScript文档

最后,在C++中加载模块:

>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');

请参阅PyTorch C++ API 文档 以了解如何在 C++ 中使用 PyTorch 模块的详细信息。

在不同PyTorch版本之间保存和加载ScriptModules

PyTorch 团队建议使用相同版本的 PyTorch 保存和加载模块。较旧版本的 PyTorch 可能不支持较新的模块,而较新版本可能已经删除了或修改了较旧的行为。这些更改在 PyTorch 的 发布说明 中有明确描述,依赖于已更改功能的模块可能需要更新以继续正常工作。在有限的情况下,如下所述,PyTorch 将保留序列化 ScriptModules 的历史行为,因此它们不需要更新。

torch.div 执行整数除法

在 PyTorch 1.5 及更早版本中,torch.div() 在给定两个整数输入时会执行地板除法:

# PyTorch 1.5(及更早版本)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)

在 PyTorch 1.7 中,torch.div() 将始终对其输入执行真正的除法,就像 Python 3 中的除法一样:

# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)

torch.div() 的行为在序列化的 ScriptModules 中得以保留。 也就是说,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 在加载到较新版本的 PyTorch 时,即使给定两个整数输入,torch.div() 仍将执行地板除法。然而,使用 torch.div() 并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 无法在早期版本的 PyTorch 中加载,因为这些早期版本不理解新的行为。

torch.full 总是推断浮点数据类型

在 PyTorch 1.5 及更早版本中,torch.full() 总是返回一个浮点张量,无论给定的填充值是什么:

# PyTorch 1.5 及更早版本
>>> torch.full((3,), 1)  # 注意整数填充值...
tensor([1., 1., 1.])     # ...但返回的是浮点数张量!

在 PyTorch 1.7 中,torch.full() 会根据填充值推断返回张量的 dtype:

# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])

>>> torch.full((3,), True)
tensor([True, True, True])

>>> torch.full((3,), 1.)
tensor([1., 1., 1.])

>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])

torch.full() 的行为在序列化的 ScriptModules 中保持不变。也就是说,使用 PyTorch 1.6 之前版本序列化的 ScriptModules 将继续看到 torch.full() 默认返回浮点张量,即使给定了布尔值或整数填充值。使用 torch.full() 并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 无法在早期版本的 PyTorch 中加载,因为这些早期版本不理解新的行为。

实用函数

以下实用函数与序列化相关:

torch.serialization.register_package(priority, tagger, deserializer)[源代码]

注册具有关联优先级的标记和反序列化存储对象的可调用对象。 标记在保存时将设备与存储对象关联,而反序列化在加载时将存储对象移动到适当的设备。taggerdeserializer 按照它们的 priority 顺序运行,直到一个标记器/反序列化器返回一个不是 None 的值。

要在全局注册表中覆盖某个设备的反序列化行为,可以注册一个优先级高于现有标签器的标签器。

此函数还可用于为新设备注册标签器和反序列化器。

Parameters
Returns

示例

>>> def ipu_tag(obj):
>>>     if obj.device.type == 'ipu':
>>>         return 'ipu'
>>> def ipu_deserialize(obj, location):
>>>     if location.startswith('ipu'):
>>>         ipu = getattr(torch, "ipu", None)
>>>         assert ipu is not None, "IPU设备模块未加载"
>>>         assert torch.ipu.is_available(), "ipu不可用"
>>>         return obj.ipu(location)
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
torch.serialization.get_default_load_endianness()[源代码]

获取加载文件的备用字节顺序

如果保存的检查点中不存在字节顺序标记,则使用此字节顺序作为回退。默认情况下,它是“本机”字节顺序。

Returns

可选[加载字节序]

Return type

默认加载字节序

torch.serialization.set_default_load_endianness(endianness)[源代码]

设置加载文件时的备用字节顺序

如果保存的检查点中不存在字节顺序标记,则使用此字节顺序作为回退。默认情况下,它是“本机”字节顺序。

Parameters

字节序 – 新的备用字节顺序

优云智算