Shortcuts

torch.load

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)[源代码]

从文件加载使用 torch.save() 保存的对象。

torch.load() 使用 Python 的解封设施,但对张量底层的存储进行了特殊处理。它们首先在 CPU 上反序列化,然后移动到它们保存时的设备上。如果这失败了(例如因为运行时系统没有某些设备),则会引发异常。然而,可以使用 map_location 参数将存储动态重新映射到一组替代设备。

如果 map_location 是一个可调用对象,它将为每个序列化的存储调用一次,并带有两个参数:存储和位置。存储参数将是存储的初始反序列化,位于CPU上。每个序列化的存储都有一个与之关联的位置标签,该标签标识了它保存时的设备,这个标签是传递给 map_location 的第二个参数。内置的位置标签是 'cpu' 用于CPU张量和 'cuda:device_id'(例如 'cuda:2')用于CUDA张量。map_location 应返回 None 或一个存储。如果 map_location 返回一个存储,它将被用作最终反序列化的对象,已经移动到正确的设备。否则,torch.load() 将回退到默认行为,就像没有指定 map_location 一样。

如果 map_location 是一个 torch.device 对象或包含设备标签的字符串,它指示所有张量应加载到的位置。

否则,如果 map_location 是一个字典,它将用于将文件中出现的位置标签(键)重新映射到指定存储位置的标签(值)。

用户扩展可以通过torch.serialization.register_package()注册自己的位置标签和标记及反序列化方法。

Parameters
  • f (联合[字符串, 路径类, 二进制IO, IO[字节]]) – 一个类文件对象(必须实现 read(), readline(), tell(), 和 seek()), 或者一个包含文件名的字符串或 os.PathLike 对象

  • map_location (可选[联合[可调用[[张量, 字符串], 张量], 设备, 字符串, 字典[字符串, 字符串]]]) – 一个函数,torch.device,字符串或一个字典,指定如何重新映射存储位置

  • pickle_module (可选[任意]) – 用于解封元数据和对象的模块(必须与用于序列化文件的pickle_module匹配)

  • weights_only (bool) – 指示解封器是否应仅限于加载张量、原始类型和字典

  • mmap (可选[布尔值]) – 指示文件是否应被映射而不是将所有存储加载到内存中。 通常,文件中的张量存储首先会从磁盘移动到CPU内存,之后它们会被移动到保存时标记的位置,或者由map_location指定。如果最终位置是CPU,则第二步是空操作。当mmap标志被设置时,第一步中不会将张量存储从磁盘复制到CPU内存,而是将f映射。

  • pickle_load_args (Any) – (仅限 Python 3) 传递给 pickle_module.load()pickle_module.Unpickler() 的可选关键字参数,例如, errors=...

Return type

任意

警告

torch.load() 除非 weights_only 参数设置为 True, 否则会隐式使用 pickle 模块,该模块已知是不安全的。 可以构造恶意的 pickle 数据,在解封时执行任意代码。 切勿在不安全模式下加载可能来自不受信任源的数据,或可能已被篡改的数据。只加载您信任的数据

注意

当你在包含GPU张量的文件上调用torch.load()时,这些张量默认会被加载到GPU上。你可以调用torch.load(.., map_location='cpu') 然后调用load_state_dict()来避免在加载模型检查点时GPU内存激增。

注意

默认情况下,我们将字节字符串解码为utf-8。这是为了避免在Python 3中加载由Python 2保存的文件时出现常见的错误情况UnicodeDecodeError: 'ascii' codec can't decode byte 0x...。如果此默认设置不正确,您可以使用额外的encoding关键字参数来指定如何加载这些对象,例如,encoding='latin1'使用latin1编码将它们解码为字符串,而encoding='bytes'则将它们保持为字节数组,这些字节数组可以在以后使用byte_array.decode(...)进行解码。

示例

>>> torch.load('tensors.pt', weights_only=True)
# 将所有张量加载到CPU上
>>> torch.load('tensors.pt', map_location=torch.device('cpu'), weights_only=True)
# 使用函数将所有张量加载到CPU上
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage, weights_only=True)
# 将所有张量加载到GPU 1上
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1), weights_only=True)
# 将张量从GPU 1映射到GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1': 'cuda:0'}, weights_only=True)
# 从io.BytesIO对象加载张量
# 从缓冲区加载,设置weights_only=False,警告:这可能不安全
>>> with open('tensor.pt', 'rb') as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer, weights_only=False)
# 使用'ascii'编码解封装加载模块
# 从模块加载,设置weights_only=False,警告:这可能不安全
>>> torch.load('module.pt', encoding='ascii', weights_only=False)
优云智算