• Docs >
  • Distributed Checkpoint - torch.distributed.checkpoint
Shortcuts

分布式检查点 - torch.distributed.checkpoint

分布式检查点(DCP)支持从多个rank并行加载和保存模型。 它处理加载时的resharding,从而实现在一个集群拓扑中保存,并在另一个集群拓扑中加载。

DCP 与 torch.savetorch.load 在几个重要方面有所不同:

  • 它为每个检查点生成多个文件,每个等级至少有一个。

  • 它在原地操作,意味着模型应首先分配其数据,DCP使用该存储空间。

加载和保存检查点的入口点如下:

torch.distributed.checkpoint.state_dict_saver.save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[源代码]

保存一个分布式模型,采用SPMD风格。

此函数与 torch.save() 不同,因为它处理 ShardedTensorDTensor,通过让每个rank仅保存其本地分片。

对于每个Stateful对象(同时具有state_dictload_state_dict), 保存将在序列化之前调用state_dict

警告

在不同版本的 PyTorch 中,保存的 state_dicts 不保证向后兼容性。

警告

如果使用process_group参数,请确保只有其rank调用save_state_dict,并且state_dict中的所有数据都属于它。

注意

当保存FSDP的ShardingStrategy.HYBRID_SHARD的检查点时,只有一个shard_group应该调用save_state_dict,并且需要传入相应的进程组。

注意

If no process group is available, this function assumes the intention is to save the

本地进程中的state_dict。

Parameters
  • state_dict (字典[str, 任意]) – 要保存的状态字典。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id的含义取决于存储方式。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是一个键。(默认值: None)

  • storage_writer (可选[StorageWriter]) – 用于执行写入的 StorageWriter 实例。如果未指定此参数,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则会引发异常。(默认值:None

  • 计划器 (可选[SavePlanner]) – SavePlanner 的实例。如果未指定,将使用默认计划器。(默认值:None

  • process_group (可选[ProcessGroup]) – 用于跨等级同步的进程组。 (默认值: None)

Returns

保存的检查点元数据对象。

Return type

元数据

示例

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> torch.distributed.checkpoint.save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )

注意

save_state_dict 使用集合来协调跨等级的写入。 对于基于 NCCL 的进程组,对象的内部张量表示必须在通信发生之前移动到 GPU 设备。 在这种情况下,使用的设备由 torch.cuda.current_device() 给出, 并且用户有责任确保通过 torch.cuda.set_device() 设置此设备,以便每个等级都有单独的 GPU。

torch.distributed.checkpoint.state_dict_saver.async_save(state_dict, *, checkpoint_id=None, storage_writer=None, planner=None, process_group=None)[源代码]

异步版本的 save_state_dict。此代码首先在 CPU 上取消暂存 state_dict,然后在单独的线程中调用 save

警告

此功能是实验性的,可能会发生变化。

Parameters
  • state_dict (字典[str, 任意]) – 要保存的状态字典。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id的含义取决于存储方式。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是一个键。(默认值: None)

  • storage_writer (可选[StorageWriter]) – 用于执行写入的 StorageWriter 实例。如果未指定此项,DCP 将根据 checkpoint_id 自动推断写入器。如果 checkpoint_id 也为 None,则会引发异常。(默认值:None

  • 计划器 (可选[SavePlanner]) – SavePlanner 的实例。如果未指定,将使用默认计划器。(默认值:None

  • process_group (可选[ProcessGroup]) – 用于跨rank同步的ProcessGroup。 (默认值: None)

Returns

一个持有从save方法返回的结果Metadata对象的未来。

Return type

未来

示例

>>> my_model = MyModule()
>>> state_dict = {"model": my_model}
>>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter("/checkpoint/1")
>>> checkpoint_future = torch.distributed.checkpoint.async_save(
>>>     state_dict=state_dict,
>>>     storage_writer=fs_storage_writer,
>>> )
>>>
>>> # ... 做一些工作 ...
>>>
>>> checkpoint_future.result()
torch.distributed.checkpoint.state_dict_saver.save_state_dict(state_dict, storage_writer, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[源代码]

此方法已弃用。请切换到‘save’。

Return type

元数据

torch.distributed.checkpoint.state_dict_loader.load(state_dict, *, checkpoint_id=None, storage_reader=None, planner=None, process_group=None)[源代码]

以SPMD风格加载分布式state_dict

每个等级将尝试读取最少量的数据,以满足请求的state_dict。当加载ShardedTensorDTensor实例时,每个等级仅读取其本地分片的数据。

对于每个Stateful对象(具有state_dictload_state_dict), 加载将在尝试反序列化之前首先调用state_dict,然后在反序列化完成后调用load_state_dict

警告

在调用此函数之前,state_dict 中的所有张量都必须分配到其目标设备上。

所有非张量数据都使用torch.load()加载,并在state_dict上就地修改。

警告

用户必须在根模块上调用load_state_dict以确保加载后处理和非张量数据正确传播。

Parameters
  • state_dict (字典[str, 任意]) – 要保存的状态字典。

  • checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id的含义取决于存储方式。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是一个键。(默认值: None)

  • storage_reader (可选[StorageReader]) – 用于执行读取的 StorageWriter 实例。如果未指定此参数,DCP 将根据 checkpoint_id 自动推断读取器。如果 checkpoint_id 也为 None,则会引发异常。(默认值:None

  • 计划器 (可选[LoadPlanner]) – LoadPlanner 的实例。如果未指定,将使用默认计划器。(默认值:None

  • process_group (可选[ProcessGroup]) – 用于跨rank同步的ProcessGroup。 (默认值: None)

Returns

无。

Return type

Examples
>>> my_model = MyModule()
>>> optimizer = Adagrad(my_model.parameters())
>>> model_state_dict = my_model.state_dict()
>>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1")
>>> torch.distributed.checkpoint.load_state_dict(
>>>     state_dict=model_state_dict,
>>>     storage_reader=fs_storage_reader,
>>> )
>>> # module.load_state_dict() 函数可能有自定义的步骤
>>> # 为了刷新 state_dict,必须调用它以
>>> # 确保正确的行为。
>>> my_model.load_state_dict(model_state_dict)

注意

load_state_dict 使用集合来协调跨进程的读取。 对于基于 NCCL 的进程组,对象的内部张量表示必须在通信发生之前移动到 GPU 设备上。 在这种情况下,使用的设备由 torch.cuda.current_device() 给出, 并且用户有责任确保通过 torch.cuda.set_device() 设置此设备,以便每个进程都有独立的 GPU。

torch.distributed.checkpoint.state_dict_loader.load_state_dict(state_dict, storage_reader, process_group=None, coordinator_rank=0, no_dist=False, planner=None)[源代码]

此方法已弃用。请切换到‘load’。

除了上述入口点外,有状态对象(如下所述)在保存/加载过程中提供了额外的定制功能 .. automodule:: torch.distributed.checkpoint.stateful

class torch.distributed.checkpoint.stateful.Stateful(*args, **kwargs)[源代码]

可检查点并恢复的对象的有状态协议。

load_state_dict(state_dict)[源代码]

从提供的 state_dict 恢复对象的状态。

Parameters

state_dict (字典[字符串, 任意]) – 要恢复的状态字典

state_dict()[源代码]

对象应将其状态字典表示作为字典返回。 此函数的输出将被检查点保存,并在稍后通过 load_state_dict() 恢复。

警告

由于恢复检查点的就地性质,此函数在调用 torch.distributed.checkpoint.load 期间也会被调用。

Returns

对象的状态字典

Return type

字典

这个示例展示了如何使用Pytorch分布式检查点来保存FSDP模型。

以下类型定义了在检查点期间使用的IO接口:

class torch.distributed.checkpoint.StorageReader[源代码]

load_state_dict 使用的接口,用于从存储中读取数据。

一个StorageReader实例在分布式检查点中同时充当协调器和跟随者。作为初始化的一部分,每个实例都会被告知其角色。

子类应预期以下由 load_state_dict 调用的顺序:

  1. (所有排名) 如果用户传递了有效的checkpoint_id,则设置checkpoint_id。

  2. (所有等级) read_metadata()

  3. (所有等级) set_up_storage_reader()

  4. (所有等级) prepare_local_plan()

  5. (协调器) prepare_global_plan()

  6. (所有等级) read_data()

abstract prepare_global_plan(plans)[源代码]

执行存储加载的集中规划。

此方法仅在协调器实例上调用。

虽然这种方法可以生成一个完全不同的计划,但首选的方式是将特定于存储的数据存储在 LoadPlan::storage_data 中。

Parameters

计划 (列表[加载计划]) – 每个等级对应一个加载计划实例的列表。

Returns

存储全局规划后的转换LoadPlan列表

Return type

列表[加载计划]

abstract prepare_local_plan(plan)[源代码]

执行特定存储的本地规划。

虽然这种方法可以生成一个完全不同的计划,但推荐的方式是将特定于存储的数据存储在 LoadPlan::storage_data 中。

Parameters

计划 (LoadPlan) – 使用的LoadPlan中的本地计划。

Returns

存储本地规划后的转换 LoadPlan

Return type

加载计划

abstract read_data(plan, planner)[源代码]

使用 plannerplan 中读取所有项目以解析数据。

子类应调用 LoadPlanner::load_bytes 将 BytesIO 对象反序列化到正确的位置。

子类应调用 LoadPlanner::resolve_tensor 以获取应加载数据的张量。

存储层的责任是正确调度任何所需的跨设备复制。

Parameters
  • 计划 (LoadPlan) – 要执行的本地计划

  • planner (LoadPlanner) – 用于解析项目的计划对象。

Returns

一个在所有读取完成后完成的未来。

Return type

未来[无]

abstract read_metadata()[源代码]

读取检查点元数据。

Returns

与正在加载的检查点关联的元数据对象。

Return type

元数据

abstract reset(checkpoint_id=None)[源代码]

调用以指示即将发生全新的检查点读取。如果用户为此检查点读取设置了checkpoint_id,则可能会存在checkpoint_id。checkpoint_id的含义取决于存储方式。它可以是文件夹/文件的路径或键值存储的键。

Parameters

checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id的含义取决于存储方式。它可以是文件夹或文件的路径。如果存储更像是一个键值存储,它也可以是一个键。(默认值: None)

abstract set_up_storage_reader(metadata, is_coordinator)[源代码]

初始化此实例。

Parameters
  • 元数据 (Metadata) – 要使用的元数据模式。

  • is_coordinator (bool) – 此实例是否负责协调检查点。

abstract classmethod validate_checkpoint_id(checkpoint_id)[源代码]

检查给定的 checkpoint_id 是否受存储支持。这使我们能够启用自动存储选择。

Return type

bool

class torch.distributed.checkpoint.StorageWriter[源代码]

save_state_dict 用于写入存储的接口。

一个StorageWriter实例在分布式检查点中同时充当协调器和跟随者。作为初始化的一部分,每个实例都会被告知其角色。

子类应预期以下调用序列。

  1. (所有排名) 如果用户传递了有效的checkpoint_id,则设置checkpoint_id。

  2. (所有等级) set_up_storage_writer()

  3. (所有等级) prepare_local_plan()

  4. (协调器) prepare_global_plan()

  5. (所有排名) write_data()

  6. (协调器) finish()

abstract finish(metadata, results)[源代码]

写入元数据并将当前检查点标记为成功。

用于序列化元数据的实际格式/模式是实现细节。唯一的要求是它能够恢复到相同的对象图。

Parameters
  • 元数据 (元数据) – 新检查点的元数据

  • 结果 (列表[列表[WriteResult]]) – 来自所有等级的WriteResults列表。

Returns

Return type

abstract prepare_global_plan(plans)[源代码]

执行存储的集中规划。

此方法仅在协调器实例上调用。

虽然这种方法可以生成一个完全不同的计划,但首选的方式是将存储特定数据存储在 SavePlan::storage_data 中。

Parameters

计划 (列表[保存计划]) – 每个等级对应一个保存计划实例的列表。

Returns

存储全局规划后转换的SavePlan列表

Return type

列表[保存计划]

abstract prepare_local_plan(plan)[源代码]

执行特定存储的本地规划。

虽然这种方法可以生成一个完全不同的计划,但推荐的方式是将存储特定数据存储在 SavePlan::storage_data 中。

Parameters

计划 (SavePlan) – 来自正在使用的 SavePlanner 的本地计划。

Returns

存储本地规划后的转换 SavePlan

Return type

保存计划

abstract reset(checkpoint_id=None)[源代码]

调用以指示即将发生全新的检查点写入。如果用户为此检查点写入设置了checkpoint_id,则可能会存在checkpoint_id。checkpoint_id的含义取决于存储方式。它可以是文件夹/文件的路径或键值存储的键。

Parameters

checkpoint_id (Union[str, os.PathLike, None]) – 此检查点实例的ID。checkpoint_id的含义取决于存储方式。它可以是文件夹或文件的路径。如果存储是键值存储,它也可以是一个键。(默认值: None)

abstract set_up_storage_writer(is_coordinator)[源代码]

初始化此实例。

Parameters

is_coordinator (bool) – 此实例是否负责协调检查点。

abstract classmethod validate_checkpoint_id(checkpoint_id)[源代码]

检查给定的 checkpoint_id 是否受存储支持。这使我们能够启用自动存储选择。

Return type

bool

abstract write_data(plan, planner)[源代码]

使用planner解析数据,写出plan中的所有项目。

子类应在计划中的每个项目上调用 SavePlanner::resolve_data 以获取对基础对象的访问权限以进行写入。

子类应延迟调用resolve_data,因为它可以分配内存。 对于张量,请做出以下假设:

  • 它们可能位于任何设备上,包括与WriteItem::tensor_data上的设备不匹配的设备。

  • 它们可能是视图或不连续的。只需要保存投影。

Parameters
  • 计划 (保存计划) – 要执行的保存计划。

  • planner (SavePlanner) – 用于解析项目到数据的计划器对象。

Returns

一个完成时返回 WriteResult 列表的未来

Return type

未来[列表[写入结果]]

以下类型定义了在检查点期间使用的规划器接口:

class torch.distributed.checkpoint.LoadPlanner[源代码]

定义load_state_dict使用的协议的抽象类,用于规划加载过程。

LoadPlanner 是有状态的对象,可用于自定义整个加载过程。

LoadPlanner 充当 state_dict 的访问代理,因此对其进行的任何转换都将对整个过程可见。

规划器子类在调用load_state_dict期间可以预期以下调用序列:

  1. set_up_planner - called on all ranks.

    表示加载检查点的开始。

  2. create_local_plan - called on all ranks.

    处理 state_dict 并生成一个 LoadPlan,该计划将被发送以进行全局规划。

  3. create_global_plan - called on the coordinator rank only.

    从所有等级获取LoadPlan并做出任何全局决策。

  4. load_bytes - called multiple times on each rank

    这会在 state_dict 中的每个非张量值上调用一次。

  5. resolve_tensor and commit_tensor - called multiple times on each rank

    它们成对地为 state_dict 中的每个 Tensor 值调用。

建议用户扩展DefaultLoadPlanner,而不是直接实现此接口,因为大多数更改都可以通过修改单个方法来实现。

有两种常见的扩展模式:

重写 state_dict。这是扩展加载过程的最简单方法,因为它不需要理解 LoadPlan 的工作原理。我们需要保留对原始 state_dict 的引用,因为加载是就地进行的,所以我们需要能够就地执行它

>>> class RenamePlanner(DefaultLoadPlanner):
>>>     def set_up_planner(self, state_dict, metadata, is_coordinator):
>>>         self.original_state_dict = state_dict
>>>         state_dict = {"foo_" + k: v for k, v in state_dict.items()}
>>>
>>>         if self.flatten_sharded_tensors:
>>>             state_dict = _flatten_sharded_tensors(state_dict)
>>>
>>>         if self.flatten_state_dict:
>>>             state_dict, self.mappings = flatten_state_dict(state_dict)
>>>
>>>         self.state_dict = state_dict
>>>         self.metadata = metadata
>>>         self.is_coordinator = is_coordinator
>>>
>>>     def load_bytes(self, read_item, value):
>>>         # 移除 "foo_" 前缀
>>>         self.original_state_dict[read_item.dest_index.fqn[4:]] = torch.load(value)

修改 resolve_tensor 和 commit_tensor 以处理加载时的转换。

>>> class MetaModelMaterialize(DefaultSavePlanner):
>>>     def resolve_tensor(self, read_item):
>>>         tensor = super().resolve_tensor(read_item)
>>>         return torch.empty_like(tensor, device="cpu")
>>>
>>>     def commit_tensor(self, read_item, tensor):
>>>         self.state_dict[read_item.dest_index.fqn] = tensor
abstract commit_tensor(read_item, tensor)[源代码]

当 StorageReader 完成将数据加载到 tensor 后调用一次。

提供的张量是调用 resolve_tensor 返回的同一个张量。 仅当此 LoadPlanner 需要在将其复制回 state_dict 中的张量之前对 tensor 进行后处理时,才需要此方法。

张量的内容将遵循其设备同步模型。

abstract create_global_plan(global_plan)[源代码]

计算全局负载计划并返回每个等级的计划。

. 注意:这仅在协调器秩上调用

Return type

列表[加载计划]

abstract create_local_plan()[源代码]

基于 state_dict 和 set_up_planner 提供的元数据创建一个 LoadPlan。

. 注意:这是在每个等级上调用的。

Return type

加载计划

abstract finish_plan(central_plan)[源代码]

接受协调员的计划并返回最终的LoadPlan。

Return type

加载计划

abstract load_bytes(read_item, value)[源代码]

加载由 read_itemvalue 描述的项目。

此方法预计会就地修改底层 state_dict。

由用于生成正在加载的检查点的 SavePlanner 定义 value 的内容。

abstract resolve_tensor(read_item)[源代码]

返回由 read_item 描述的张量,以供 StorageReader 加载 read_item

张量应与底层state_dict中的一个张量别名,因为StorageReader将替换其内容。 如果由于任何原因无法实现,规划器可以使用commit_tensor方法将数据复制回state_dict中的张量。

Return type

张量

abstract set_up_planner(state_dict, metadata, is_coordinator)[源代码]

初始化此实例以将数据加载到 state_dict 中。

. 注意:这是在每个等级上调用的。

class torch.distributed.checkpoint.LoadPlan(items: List[torch.distributed.checkpoint.planner.ReadItem], storage_data: Any = None, planner_data: Any = None)[源代码]
class torch.distributed.checkpoint.ReadItem(type: torch.distributed.checkpoint.planner.LoadItemType, dest_index: torch.distributed.checkpoint.metadata.MetadataIndex, dest_offsets: torch.Size, storage_index: torch.distributed.checkpoint.metadata.MetadataIndex, storage_offsets: torch.Size, lengths: torch.Size)[源代码]
class torch.distributed.checkpoint.SavePlanner[源代码]

定义用于规划保存过程的协议的抽象类。

SavePlanners 是有状态的对象,可用于自定义整个保存过程。

SavePlanner 充当 state_dict 的访问代理,因此对其进行的任何转换都将对整个过程可见。

计划器子类在save_state_dict期间可以预期以下调用序列:

  1. set_up_planner - called on all ranks.

    信号表示检查点保存的开始。

  2. create_local_plan - called on all ranks.

    处理 state_dict 并生成一个 SavePlan,该计划将被发送以进行全局规划。

  3. create_global_plan - called on the coordinator rank only.

    从所有等级获取SavePlan并做出任何全局决策。

  4. finish_plan - called on all ranks.

    这使得每个等级都有机会调整到全球规划决策。

  5. resolve_data - called multiple times on each rank

    在存储层的state_dict中查找要写入的值。

建议用户扩展DefaultSavePlanner,而不是直接扩展此接口,因为大多数更改都可以通过更改单个方法来实现。

有3种常见的扩展模式:

重写 state_dict。这是扩展保存过程的最简单方法,因为它不需要理解 SavePlan 的工作原理:

>>> class RenamePlanner(DefaultSavePlanner):
>>>     def set_up_planner(self, state_dict, is_coordinator):
>>>         # 将所有键前缀为 `foo_`
>>>         super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, is_coordinator)

修改本地计划和查找同时进行。这在需要精细控制数据持久化方式时非常有用

>>> class FP16Planner(DefaultSavePlanner):
>>>     def create_local_plan(self):
>>>         plan = super().create_local_plan()
>>>         for p in plan:
>>>             if p.tensor_data is not None:
>>>                 p.tensor_data.properties.dtype = torch.float16
>>>         return plan
>>>
>>>     def resolve_data(self, write_item):
>>>         item = super().resolve_data(write_item)
>>>         return item if write_item.type == WriteItemType.BYTE_IO else item.to(torch.float16)

使用全局规划步骤来做出中央决策,这些决策不能由每个等级单独做出

>>> from itertools import islice
>>> from dataclasses import replace
>>> class DDPLoadBalancingPlanner(DefaultSavePlanner):
>>>     # 这使用了默认的本地计划行为,即所有非分片写入都在rank 0中
>>>     # 此示例不处理ShardedTensors
>>>     def create_global_plan(self, all_plans):
>>>         def chunk(it, size):
>>>             it = iter(it)
>>>         return list(iter(lambda: tuple(islice(it, size)), ()))
>>>         all_plans = [
>>>             replace(plan, items=items) for plan, items in
>>>                 zip(all_plans, chunk(all_plans[0].items, len(all_plans)))
>>>         ]
>>>         return super().create_global_plan(all_plans)

最后,一些规划器需要在检查点中保存额外的元数据,这是通过让每个rank在其本地计划中贡献其数据项,并由全局规划器将它们聚合来实现的:

>>> class SaveExtraDataPlanner(DefaultSavePlanner):
>>>     def create_local_plan(self) -> SavePlan:
>>>         plan = super().create_local_plan()
>>>         return replace(plan, planner_data="per-rank-data")
>>>
>>>     def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]:
>>>         global_plan, metadata = super().create_global_plan(all_plans)
>>>         merged_data = [p.planner_data for p in global_plan]
>>>         metadata = replace(metadata, planner_data=merged_data)
>>>         return global_plan, metadata
abstract create_global_plan(all_plans)[源代码]

计算全局检查点计划并返回每个等级的本地计划。

这仅在协调器等级上调用。

Return type

元组[列表[保存计划], 元数据]

abstract create_local_plan()[源代码]

计算当前等级的保存计划。

这将汇总并传递给 create_global_plan。 计划器特定的数据可以通过 SavePlan::planner_data 传递。

这会在所有等级上调用。

Return type

保存计划

abstract finish_plan(new_plan)[源代码]

合并由create_local_plan创建的计划和create_global_plan的结果。

这将在所有等级上调用。

Return type

保存计划

abstract resolve_data(write_item)[源代码]

转换并准备从 state_dict 中存储的 write_item,确保幂等性和线程安全。

state_dict中查找与write_item关联的对象,并在存储层使用它之前应用任何转换(例如序列化)。

在每个rank上被多次调用,每个WriteItem在最终的SavePlan中至少被调用一次。

此方法应具有幂等性和线程安全性。StorageWriter 实现可以自由地根据需要频繁调用它。

任何分配内存的转换都应在调用此方法时惰性执行,以减少检查点所需的峰值内存。

当返回张量时,它们可以位于任何设备或格式上,也可以是视图。存储层的责任是确定如何保存它们。

Return type

联合[张量, 字节流]

abstract set_up_planner(state_dict, is_coordinator)[源代码]

初始化此规划器以保存 state_dict

实现时应保存这些值,因为它们在保存过程中不会被提供。

这会在所有等级上调用。

class torch.distributed.checkpoint.SavePlan(items: List[torch.distributed.checkpoint.planner.WriteItem], storage_data: Any = None, planner_data: Any = None)[源代码]
class torch.distributed.checkpoint.planner.WriteItem(index, type, tensor_data=None)[源代码]

数据类,用于保存需要写入存储的信息。

tensor_storage_size()[源代码]

计算底层张量的存储大小,如果不是张量写入,则返回None。

Returns

可选的[int]存储大小,以字节为单位,如果有底层张量。

Return type

可选[整数]

我们提供了一个基于文件系统的存储层:

class torch.distributed.checkpoint.filesystem.FileSystemReader(path)[源代码]
class torch.distributed.checkpoint.filesystem.FileSystemWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000)[源代码]

使用文件IO的基本StorageWriter实现。

此实现做出了以下假设和简化:

  • 检查点路径是一个空目录或不存在的目录。

  • 文件创建是原子的

检查点由每个写请求的一个文件加上一个包含序列化元数据的.metadata文件组成。

此外,我们提供了以下用于处理Fsspec存储的抽象。

class torch.distributed.checkpoint.fsspec.FsspecReader(path)[源代码]
class torch.distributed.checkpoint.fsspec.FsspecWriter(path, single_file_per_rank=True, sync_files=True, thread_count=1, per_thread_copy_ahead=10000000)[源代码]

使用FFspec的基本StorageWriter实现。

此实现做出了以下假设和简化:

  • 检查点路径是一个空目录或不存在的目录。

  • 文件创建是原子的

检查点由每个写请求的一个文件加上一个包含序列化元数据的.metadata文件组成。

我们提供了LoadPlannerSavePlanner的默认实现,可以处理所有torch.distributed构造,如FSDP、DDP、ShardedTensor和DistributedTensor。

class torch.distributed.checkpoint.DefaultSavePlanner(flatten_state_dict=True, flatten_sharded_tensors=True, dedup_replicated_tensors=None)[源代码]
lookup_object(index)[源代码]

从规划器接口扩展,使其易于扩展默认规划器。

Return type

任意

transform_object(write_item, object)[源代码]

从规划器接口扩展,使其易于扩展默认规划器。

class torch.distributed.checkpoint.DefaultLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True)[源代码]

DefaultLoadPlanner 在 LoadPlanner 的基础上增加了多个功能。

特别是它添加了以下内容:

flatten_state_dict: 处理包含嵌套字典的state_dict flatten_sharded_tensors: 用于2D并行模式下的FSDP

lookup_tensor(index)[源代码]

从规划器接口扩展,使其易于扩展默认规划器。

Return type

张量

transform_tensor(read_item, tensor)[源代码]

从规划器接口扩展,使其易于扩展默认规划器。

由于遗留的设计决策,FSDPDDP 的状态字典可能具有不同的键或完全限定名称(例如,layer1.weight),即使原始的未并行化模型是相同的。此外,FSDP 提供了各种类型的模型状态字典,例如完整和分片状态字典。此外,优化器状态字典使用参数ID而不是完全限定名称来标识参数,在使用并行化时可能会导致问题(例如,流水线并行)。

为了应对这些挑战,我们为用户提供了一系列API,以便轻松管理state_dicts。get_model_state_dict返回一个模型状态字典,其键与未并行化的模型状态字典返回的键一致。同样,get_optimizer_state_dict提供了优化器状态字典,其键在所有应用的并行化中保持一致。为了实现这种一致性,get_optimizer_state_dict将参数ID转换为与未并行化的模型状态字典中找到的完全限定名称相同的名称。

请注意,这些API返回的结果可以直接与torch.distributed.checkpoint.save()torch.distributed.checkpoint.load()方法一起使用,而无需进行任何额外的转换。

请注意,此功能是实验性的,API签名可能会在未来发生变化。

torch.distributed.checkpoint.state_dict.get_state_dict(model, optimizers, *, submodules=None, options=None)[源代码]

返回模型的 state_dict 和优化器的 state_dict。

get_state_dict 可以处理任何由 PyTorch 并行化的模块,包括 FSDP/fully_shard、DDP/replicate、tensor_parallel/parallelize_module 以及这些并行化的任意组合。get_state_dict 的主要功能是:1.) 返回一个可以与不同数量的训练器和/或不同并行化方式重新分片的模型和优化器 state_dict。2.) 隐藏特定于并行化的 state_dict API。用户不需要调用这些 API。3.) 对结果 state_dict 进行健全性检查。

结果状态字典的键是规范的全限定名称(Fully Qualified Names,FQNs)。规范的FQN指的是基于参数在nn.Module层次结构中的位置的FQN。更具体地说,参数的规范FQN是通过调用module.named_parameters()module.named_buffers()返回的FQN,当模块未通过任何并行性进行分布时。由于优化器内部使用参数ID来表示参数,因此在调用此API时,将会有从参数ID到规范FQNs的转换。

get_state_dict 也可以处理未并行化的模块。在这种情况下,get_state_dict 只执行一个功能 —— 将优化器参数ID转换为规范的FQNs。

示例

>>> import torch
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> from torch.nn.parallel import DistributedDataParallel as DDP
>>> from torch.distributed.checkpoint.state_dict import get_state_dict
>>> fsdp_model = FSDP(copy.deepcopy(model))
>>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_model = DDP(copy.deepcopy(model))
>>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
>>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(fsdp_model, fsdp_optim)
>>> # 如果我们简单地调用 ddp_model.state_dict() 和 fsdp_model.state_dict(),
>>> # 断言将会失败。
>>> assert ddp_state_dict == fsdp_state_dict
>>> assert ddp_optim_state == fsdp_optim_state_dict
Parameters
  • 模型 (nn.Module) – 模型的nn.Module。

  • 优化器 (联合[, 优化器, 可迭代[优化器]]) – 用于优化模型的优化器。

  • 子模块 (可选[集合[模块]]) – 可选[集合[nn.模块]]: 仅返回属于子模块的模型参数。

  • 选项 (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情请参见StateDictOptions

Returns

Tuple 包含模型 state_dict 和优化器 state_dict。

Return type

元组[字典[字符串, ValueType], OptimizerStateType]

torch.distributed.checkpoint.state_dict.get_model_state_dict(model, *, submodules=None, options=None)[源代码]

返回 model 的模型状态字典。

查看 get_state_dict 的详细用法。

Parameters
  • 模型 (nn.Module) – 模型的nn.Module。

  • 子模块 (可选[集合[模块]]) – 可选[集合[nn.模块]]: 仅返回属于子模块的模型参数。

  • 选项 (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情请参见StateDictOptions

Returns

模型state_dict的状态字典。

Return type

字典[字符串, ValueType]

torch.distributed.checkpoint.state_dict.get_optimizer_state_dict(model, optimizers, *, submodules=None, options=None)[源代码]

返回优化器的组合 state_dict。

查看 get_state_dict 的详细用法。

Parameters
  • 模型 (nn.Module) – 模型的nn.Module。

  • 优化器 (联合[, 优化器, 可迭代[优化器]]) – 用于优化模型的优化器。

  • 子模块 (可选[集合[模块]]) – 可选[集合[nn.模块]]:仅返回属于子模块的模型参数。

  • 选项 (StateDictOptions) – 控制如何返回模型状态字典和优化器状态字典的选项。详情请参见StateDictOptions

Returns

优化器的 state_dict。

Return type

优化器状态类型

torch.distributed.checkpoint.state_dict.set_state_dict(model, optimizers, *, model_state_dict, optim_state_dict, options=None)[源代码]

加载模型 state_dict 和优化器的 state_dict。

get_state_dict 相对应的函数,用于将 state_dict 设置到模型和优化器中。给定的 model_state_dictoptim_state_dict 不必由 get_state_dict 返回,但必须满足以下要求:1) 所有 FQNs 都是 get_state_dict 中定义的规范 FQNs,2) 如果一个张量被分片,它必须是 ShardedTensor 或 DTensor,3) 优化器的 state_dict 不能包含参数 ID;键应该是规范的 FQNs。

Parameters
  • 模型 (nn.Module) – 模型的nn.Module。

  • 优化器 (联合[优化器, 可迭代[优化器]]) – 用于优化模型的优化器。

  • model_state_dict (字典[字符串, ValueType]) – (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]): 要加载的模型状态字典。如果model_state_dict的键是nn.Module,则该键是model的子模块,并且值应该是该子模块的状态字典。在加载状态字典时,子模块的前缀将被附加到状态字典中。

  • optim_state_dict (优化器状态类型) – 优化器状态类型: 要加载的优化器状态字典。

  • 选项 (StateDictOptions) – 用于控制如何加载模型状态字典和优化器状态字典的选项。详情请参见StateDictOptions

Returns

  • missing_keys 是一个包含模型 state_dict 中缺失键的字符串列表。

  • unexpected_keys 是一个包含模型 state_dict 中意外键的字符串列表。

Return type

NamedTuple 包含 missing_keysunexpected_keys 字段

torch.distributed.checkpoint.state_dict.set_model_state_dict(model, model_state_dict, *, options=None)[源代码]

加载模型状态字典。

用于将 state_dict 设置到模型的 get_model_state_dict 的对应方法。详见 set_state_dict 的使用方法。

Parameters
  • 模型 (nn.Module) – 模型的nn.Module。

  • model_state_dict (字典[字符串, ValueType]) – (字典[字符串, ValueType]): 要加载的模型状态字典。如果model_state_dict的键是nn.Module,则该键是model的子模块,值应该是该子模块的状态字典。在加载状态字典时,子模块的前缀将被附加到状态字典中。

  • 选项 (StateDictOptions) – 用于控制如何加载模型状态字典和优化器状态字典的选项。详情请参见StateDictOptions

Returns

  • missing_keys 是一个包含缺失键的字符串列表

  • unexpected_keys 是一个包含意外键的字符串列表

Return type

NamedTuple 包含 missing_keysunexpected_keys 字段

torch.distributed.checkpoint.state_dict.set_optimizer_state_dict(model, optimizers, *, optim_state_dict, options=None)[源代码]

加载优化器的 state_dict。

用于将 state_dict 设置到优化器的 get_optimizer_state_dict 的对应方法。详见 set_state_dict 的使用方法。

Parameters
  • 模型 (nn.Module) – 模型的nn.Module。

  • 优化器 (联合[优化器, 可迭代[优化器]]) – 用于优化模型的优化器。

  • optim_state_dict (优化器状态类型) – 优化器状态类型: 要加载的优化器状态字典。

  • 选项 (StateDictOptions) – 用于控制如何加载模型状态字典和优化器状态字典的选项。详情请参见StateDictOptions

Returns

Return type

class torch.distributed.checkpoint.state_dict.StateDictOptions(full_state_dict=False, cpu_offload=False, ignore_frozen_params=False, keep_submodule_prefixes=True, strict=True)[源代码]

这个数据类指定了 get_state_dict/set_state_dict 的工作方式。

  • full_state_dict: 如果设置为True,返回的state_dict中的所有张量都将被收集。返回的state_dict中不会有ShardedTensor和DTensor。

  • cpu_offload: 将所有张量卸载到CPU。为了防止CPU内存不足,如果full_state_dict也为真,那么只有rank0会获取state_dict,所有其他rank将获取空的state_dict。

  • ignore_frozen_params: 如果值为True,返回的state_dict将不包含任何冻结的参数 – requires_grad 为 False。默认值为 False。

  • keep_submodule_prefixes:当 submodules 不为 None 时,此选项指示是否保留 state_dict 键中的子模块前缀。 例如,如果子模块是 module.pretrain,并且参数的完整 FQN 是 pretrain.layer1.weight。当此选项为 True 时,返回的 state_dict 中参数的键将是 pretrain.layer1.weight。如果选项为 False,则键将是 layer1.weight。 请注意,如果 keep_submodule_prefixes 为 False,可能会出现冲突的 FQNs,因此 submodules 中应只有一个子模块。

  • strict:当调用 set_state_dict 时,strict 选项用于 model.load_state_dict()。 默认值为 False。

对于习惯于在torch.save格式中使用和共享模型的用户,提供了以下方法,这些方法提供了在不同格式之间进行离线转换的实用工具。

torch.distributed.checkpoint.format_utils.dcp_to_torch_save(dcp_checkpoint_dir, torch_save_path)[源代码]

给定包含DCP检查点的目录,此函数将把它转换为Torch保存文件。

Parameters

警告

为了避免内存不足(OOM),建议仅在单个rank上运行此函数。

torch.distributed.checkpoint.format_utils.torch_save_to_dcp(torch_save_path, dcp_checkpoint_dir)[源代码]

给定一个torch保存文件的位置,将其转换为DCP检查点。

Parameters

警告

为了避免内存不足(OOM),建议仅在单个rank上运行此函数。

以下类也可以用于从torch.save格式进行在线加载和模型重新分片。

class torch.distributed.checkpoint.format_utils.BroadcastingTorchSaveReader(checkpoint_id=None, coordinator_rank=0)[源代码]

用于读取Torch保存文件的StorageReader。该读取器将在协调器秩上读取整个检查点,然后将每个张量广播并分片到所有秩。

. 注意:旨在与DynamicMetaLoadPlanner一起使用

警告

当前实现仅支持加载张量。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
prepare_global_plan(global_plan)[源代码]

StorageReader 方法的实现

Return type

列表[加载计划]

prepare_local_plan(plan)[源代码]

StorageReader 方法的实现

Return type

加载计划

read_data(plan, planner)[源代码]

在协调器等级上读取torch保存的数据,并在之后进行广播 这会产生通信成本,但避免了在每个等级上加载整个检查点,希望可以防止OOM问题

Return type

未来[无]

read_metadata()[源代码]

扩展默认的 StorageReader 以支持构建元数据文件

Return type

元数据

reset(checkpoint_id=None)[源代码]

StorageReader 方法的实现

set_up_storage_reader(metadata, is_coordinator)[源代码]

StorageReader 方法的实现

classmethod validate_checkpoint_id(checkpoint_id)[源代码]

StorageReader 方法的实现

Return type

bool

class torch.distributed.checkpoint.format_utils.DynamicMetaLoadPlanner(flatten_state_dict=True, flatten_sharded_tensors=True)[源代码]

扩展了DefaultLoadPlanner,它基于传入的状态字典创建一个新的Metadata对象,避免了从磁盘读取元数据的需要。这在读取没有元数据文件的格式(如Torch保存文件)时非常有用。

. 注意:旨在与BroadcastingTorchSaveReader一起使用

警告

当前实现仅支持加载张量。

>>> sd = {"mode": model}
>>> dcp.load(
>>>    sd,
>>>    storage_reader=BroadcastingTorchSaveReader(),
>>>    planner=DynamicMetaLoadPlanner(),
>>>    checkpoint_id="path_to_model.pt"
>>> )
set_up_planner(state_dict, metadata, is_coordinator)[源代码]

规划器的设置,通过从状态字典创建元数据对象来扩展默认行为