Shortcuts

torch.distributed.checkpoint.planner 的源代码

```html
import abc
import io
from dataclasses import dataclass
from enum import auto, Enum
from functools import reduce
from typing import Any, List, Optional, Tuple, Union

import torch

from .metadata import (
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    TensorProperties,
)


__all__ = [
    "WriteItemType",
    "LoadItemType",
    "TensorWriteData",
    "WriteItem",
    "ReadItem",
    "SavePlan",
    "LoadPlan",
    "SavePlanner",
    "LoadPlanner",
]


class WriteItemType(Enum):
    TENSOR = auto()
    SHARD = auto()
    BYTE_IO = auto()


class LoadItemType(Enum):
    TENSOR = auto()
    BYTE_IO = auto()


@dataclass(frozen=True)
class TensorWriteData:
    chunk: ChunkStorageMetadata
    properties: TensorProperties
    size: torch.Size


[docs]@dataclass(frozen=True) class WriteItem: """数据类,包含需要写入存储的信息。""" index: MetadataIndex type: WriteItemType # 如果是张量写入,则存在值 tensor_data: Optional[TensorWriteData] = None
[docs] def tensor_storage_size(self) -> Optional[int]: """ 计算底层张量的存储大小,如果不是张量写入则返回None。 返回: Optional[int] 底层张量的存储大小(以字节为单位)。 """ if self.tensor_data is None: return None numels = reduce(lambda x, y: x * y, self.tensor_data.size, 1) dtype_size = torch._utils._element_size(self.tensor_data.properties.dtype) return numels * dtype_size
[docs]@dataclass(frozen=True) class ReadItem: # 读取项 type: LoadItemType # state_dict中的索引 dest_index: MetadataIndex # 目标张量的偏移量 dest_offsets: torch.Size # 检查点中的索引 storage_index: MetadataIndex # 检查点数据的偏移量 storage_offsets: torch.Size # 要复制的超立方体的大小 lengths: torch.Size
[docs]@dataclass(frozen=True) class SavePlan: items: List[WriteItem] storage_data: Any = None planner_data: Any = None
[docs]@dataclass class LoadPlan: items: List[ReadItem] storage_data: Any = None planner_data: Any = None
[docs]class SavePlanner(abc.ABC): """ 定义save_state_dict使用的协议的抽象类。 SavePlanners是有状态的对象,可用于自定义整个保存过程。 SavePlanner充当state_dict的访问代理,因此对其进行的任何转换 将对整个过程可见。 在save_state_dict期间,计划者子类可以预期以下调用序列: 1) set_up_planner - 在所有rank上调用。 信号检查点保存的开始。 2) create_local_plan - 在所有rank上调用。 处理state_dict并生成将发送进行全局规划的`SavePlan`。 3) create_global_plan - 仅在协调者rank上调用。 获取所有rank的SavePlan并做出任何全局决策。 4) finish_plan - 在所有rank上调用。 这为每个rank提供了调整全局规划决策的机会。 5) resolve_data - 在每个rank上多次调用 在`state_dict`中查找值以供存储层写入。 建议用户扩展DefaultSavePlanner而不是直接扩展此接口,因为 大多数更改可以通过单个方法的更改来表达。 有3种常见的扩展模式: 重写state_dict。这是扩展保存过程的最简单方法,因为它 不需要理解SavePlan的工作原理: >>> # xdoctest: +SKIP("undefined vars") >>> class RenamePlanner(DefaultSavePlanner): >>> def set_up_planner(self, state_dict, is_coordinator): >>> # 将所有键前缀为`foo_` >>> super().set_up_planner({"foo_" + k: v for k, v in state_dict.items()}, is_coordinator) 同时修改本地计划和查找。当需要精细控制数据的持久化时很有用 >>> # xdoctest: +SKIP("undefined vars") <span