Shortcuts

torch.distributed.checkpoint.state_dict_loader 的源代码

import os
import warnings
from typing import Any, cast, Dict, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed.checkpoint.stateful import Stateful

from ._storage_utils import _storage_setup
from .default_planner import DefaultLoadPlanner
from .planner import LoadPlanner
from .storage import StorageReader
from .utils import _all_gather_keys, _api_bc_check, _DistWrapper, _profile

__all__ = ["load_state_dict", "load"]


[docs]def load_state_dict( state_dict: Dict[str, Any], storage_reader: StorageReader, process_group: Optional[dist.ProcessGroup] = None, coordinator_rank: int = 0, no_dist: bool = False, planner: Optional[LoadPlanner] = None, ) -> None: """此方法已弃用。请切换到 'load'。""" warnings.warn( "'load_state_dict' 已弃用,将在未来版本中移除。" "请改用 'load'。" ) storage_reader.reset() with _profile(): # TODO: 测试在这里返回 `load` 是否可行。 return _load_state_dict( state_dict, storage_reader, process_group, coordinator_rank, no_dist, planner, )
[docs]@_api_bc_check def load( state_dict: Dict[str, Any], *, checkpoint_id: Union[str, os.PathLike, None] = None, storage_reader: Optional[StorageReader] = None, planner: Optional[LoadPlanner] = None, process_group: Optional[dist.ProcessGroup] = None, ) -> None: """ 以 SPMD 风格加载分布式 ``state_dict``。 每个 rank 将尝试读取最少量的数据以满足请求的 `state_dict`。当加载 :class:`ShardedTensor` 或 :class:`DTensor` 实例时,每个 rank 只读取其本地分片的数据。 对于每个 ``Stateful`` 对象(同时具有 ``state_dict`` 和 ``load_state_dict``), load 将首先调用 ``state_dict`` 再尝试反序列化,然后在反序列化完成后调用 ``load_state_dict``。 .. 警告:: 所有 ``state_dict`` 中的张量必须在调用此函数之前分配到其目标设备上。 所有非张量数据使用 `torch.load()` 加载并就地修改 state_dict。 .. 警告:: 用户必须在根模块上调用 `load_state_dict` 以确保加载后处理和非张量数据正确传播。 .. 注意:: 如果未初始化进程组,此函数将假设意图是在本地进程中加载检查点。这在本地推理时很有用, 并且在使用常规张量(而不是 DTensor 或 ShardedTensor)时。 .. 注意:: 假设 rank 0 为协调者 rank。 参数: state_dict (Dict[str, Any]): 要保存的 state_dict。 checkpoint_id (Union[str, os.PathLike, None]): 此检查点实例的 ID。checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。 如果存储是键值存储,它也可以是键。(默认: ``None``) storage_reader (Optional[StorageReader]): 用于执行读取的 StorageWriter 实例。如果未指定,DCP 将根据 checkpoint_id 自动推断读取器。 如果 checkpoint_id 也为 None,将引发异常。(默认: ``None``) planner (Optional[LoadPlanner]): LoadPlanner 实例。如果未指定,将使用默认规划器。(默认: ``None``) process_group (Optional[ProcessGroup]): 用于跨 rank 同步的 ProcessGroup。(默认: ``None``) 返回: None. 示例 >>> # xdoctest: +SKIP >>> my_model = MyModule() >>> optimizer = Adagrad(my_model.parameters()) >>> model_state_dict = my_model.state_dict() >>> fs_storage_reader = torch.distributed.checkpoint.FileSystemReader("/checkpoint/1") >>> torch.distributed.checkpoint.load_state_dict( >>> state_dict=model_state_dict, >>> storage_reader=fs_storage_reader, >>> ) >>> # module.load_state_dict() 函数可能有自定义步骤 >>> # 以刷新 state_dict,必须调用它以 >>> # 确保正确行为。 >>> my_model.load_state_dict(model_state_dict) .. 注意:: load_state_dict 使用集合来协调跨 rank 的读取。 对于基于 NCCL 的进程组,对象的内部张量表示必须在通信发生之前移动到 GPU 设备。 在这种情况下,使用的设备由 ``torch.cuda.current_device()`` 给出, 并且用户有责任确保通过 ``torch.cuda.set_device()`` 设置此设备,以便每个 rank 都有单独的 GPU。 """ no_dist = not (dist.is_available() and dist.is_initialized()) if no_dist: warnings.warn( "torch.distributed 不可用或未初始化,假设意图是在单个进程中加载。" ) with _profile(): storage_reader = cast( StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True) ) if no_dist: keys = list(state_dict.keys()) else:
优云智算