torch.distributed.checkpoint.format_utils 的源代码
```python
import argparse
import os
from enum import Enum
from typing import cast, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.futures import Future


__all__ = [
    "dcp_to_torch_save",
    "torch_save_to_dcp",
    "BroadcastingTorchSaveReader",
    "DynamicMetaLoadPlanner",
]


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved
    metadata. Useful for loading in state_dict without first initializing a
    model, such as when converting a DCP checkpoint into a Torch save file.

    .. note:: `state_dict` must be an empty dictionary when used with this
        LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, it's recommended to only
        utilize this LoadPlanner on a single rank or process to avoid OOM.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        # This planner only makes sense when the caller passes an empty dict:
        # the state dict is reconstructed entirely from the checkpoint metadata.
        assert not state_dict

        # Rebuild the state dict from the metadata. Tensor entries become empty
        # tensors of the recorded size/dtype; everything else is kept verbatim.
        for k, v in metadata.state_dict_metadata.items():
            if isinstance(v, TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if k in metadata.planner_data:
                # planner_data holds the original (unflattened) key path for
                # entries that were flattened at save time.
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)


class BroadcastingTorchSaveReader(StorageReader):
    """
    StorageReader for reading a Torch Save file. This reader will read the
    entire checkpoint on the coordinator rank, and then broadcast and shard
    each tensor to all ranks.

    .. note:: Intended to be used with DynamicMetaLoadPlanner

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"mode": model}
    >>> dcp.load(
    >>>    sd,
    >>>    storage_reader=BroadcastingTorchSaveReader(),
    >>>    planner=DynamicMetaLoadPlanner(),
    >>>    checkpoint_id="path_to_model.pt"
    >>> )
    """

    def __init__(
        self,
        checkpoint_id: Optional[Union[str, os.PathLike]] = None,
        coordinator_rank: int = 0,
    ) -> None:
        # Path of the torch-save file to read; may also be supplied later via
        # the checkpoint_id argument of the load entry point.
        self.checkpoint_id = checkpoint_id
        # Rank that performs the actual read and broadcasts to the others.
        self.coordinator_rank = coordinator_rank

    def read_metadata(self) -> Metadata:
        """Extends the default StorageReader to support building the metadata file."""
        # Metadata is built in planner.set_up_planner, since we are not
        # actually reading metadata from disk.
        return Metadata(state_dict_metadata={})

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Reads torch save data on the coordinator rank, and broadcasts
        afterwards. This incurs a communication cost, but avoids having to
        load the entire checkpoint on each rank, hopefully preventing OOM
        issues.
        """
        # NOTE(review): the scraped source is truncated here, mid-statement
        # ("planner = <span class"). The remainder of this method is not
        # visible in this chunk and has deliberately NOT been guessed at —
        # restore the body from the upstream
        # torch.distributed.checkpoint.format_utils module before use.
        raise NotImplementedError(
            "read_data body lost in source truncation; restore from upstream "
            "torch.distributed.checkpoint.format_utils"
        )