
Source code for torch.distributed.checkpoint.format_utils

```python
import argparse
import os
from enum import Enum
from typing import cast, Dict, List, Optional, Union

import torch
import torch.distributed as dist
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.futures import Future


__all__ = [
    "dcp_to_torch_save",
    "torch_save_to_dcp",
    "BroadcastingTorchSaveReader",
    "DynamicMetaLoadPlanner",
]
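
# Hedged usage sketch (not part of the upstream module): the public conversion
# helpers exported above are typically run offline on a single process, assuming
# the (source_path, destination_path) call order documented for this module.
# The paths and the helper name below are placeholders.
def _example_format_conversion() -> None:
    # Convert a sharded DCP checkpoint directory into a single torch.save file.
    dcp_to_torch_save("checkpoints/dcp_dir", "checkpoints/model.pt")
    # Convert a torch.save file back into a sharded DCP checkpoint directory.
    torch_save_to_dcp("checkpoints/model.pt", "checkpoints/dcp_dir")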


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    DefaultLoadPlanner的扩展,它从保存的元数据中重建state_dict。
    在没有首先初始化模型的情况下加载state_dict时很有用,例如
    当将DCP检查点转换为Torch保存文件时。

    . 注意:使用此LoadPlanner时,`state_dict`必须是一个空字典

    .. 警告::
        由于整个state dict被初始化,建议仅在单个rank或进程上使用
        此LoadPlanner以避免OOM。

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        assert not state_dict

        # Rebuild the state dict from the metadata
        for k, v in metadata.state_dict_metadata.items():
            if isinstance(v, TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)
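
# Hedged usage sketch (not part of the upstream module): _EmptyStateDictLoadPlanner
# lets a state_dict be rebuilt purely from checkpoint metadata, without first
# instantiating the model. The checkpoint and output paths below are placeholders,
# and the helper name is illustrative only.
def _example_rebuild_state_dict_from_dcp() -> None:
    state_dict: STATE_DICT_TYPE = {}
    # Materialize the full state_dict on a single process (no_dist=True).
    _load_state_dict(
        state_dict,
        storage_reader=FileSystemReader("checkpoints/dcp_dir"),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    # Persist the rebuilt state_dict as a regular torch.save file.
    torch.save(state_dict, "checkpoints/model.pt")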


class BroadcastingTorchSaveReader(StorageReader):
    """
    StorageReader for reading a Torch Save file. This reader reads the entire
    checkpoint on the coordinator rank, and then broadcasts and shards each
    tensor to all ranks.

    .. note:: Intended to be used with DynamicMetaLoadPlanner

    .. warning::
        The current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"mode": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt",
    >>> )
    """

    def __init__(
        self,
        checkpoint_id: Optional[Union[str, os.PathLike]] = None,
        coordinator_rank: int = 0,
    ) -> None:
        self.checkpoint_id = checkpoint_id
        self.coordinator_rank = coordinator_rank
    def read_metadata(self) -> Metadata:
        """Extends the default StorageReader to support building the metadata file."""
        # The metadata is built in planner.set_up_planner, since we do not
        # actually read any metadata from disk.
        return Metadata(state_dict_metadata={})
    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Reads the torch save data on the coordinator rank and broadcasts it
        afterwards. This incurs a communication cost, but avoids having to load
        the entire checkpoint on each rank, hopefully preventing OOM issues.
        """
        planner = cast(DefaultLoadPlanner, planner)