Shortcuts

torch.distributed.checkpoint.storage 的源代码

import abc
import os
from dataclasses import dataclass
from typing import Any, List, Union

from torch.futures import Future

from .metadata import Metadata, MetadataIndex
from .planner import LoadPlan, LoadPlanner, SavePlan, SavePlanner

__all__ = ["WriteResult", "StorageWriter", "StorageReader"]


@dataclass(frozen=True)
class WriteResult:
    index: MetadataIndex

    size_in_bytes: int
    storage_data: Any


[docs]class StorageWriter(abc.ABC): """ 由 ``save_state_dict`` 用于写入存储的接口。 一个 StorageWriter 实例在分布式检查点中既充当协调者又充当跟随者。作为初始化的一部分,每个实例都会被告知其角色。 子类应期望以下调用序列。 0) (所有 rank) 如果用户传递了有效的 checkpoint_id,则设置 checkpoint_id。 1) (所有 rank) set_up_storage_writer() 2) (所有 rank) prepare_local_plan() 3) (协调者) prepare_global_plan() 4) (所有 rank) write_data() 5) (协调者) finish() """
[docs] @abc.abstractmethod def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None: """ 调用以指示即将发生全新的检查点写入。 如果用户为此检查点写入设置了 checkpoint_id,则可能会存在 checkpoint_id。checkpoint_id 的含义取决于存储。它可以是文件夹/文件的路径或键值存储的键。 参数: checkpoint_id (Union[str, os.PathLike, None]): 此检查点实例的 ID。checkpoint_id 的含义取决于存储。它可以是文件夹或文件的路径。 如果存储更像是一个键值存储,它也可以是一个键。 (默认值: ``None``) """ ...
[docs] @abc.abstractmethod def set_up_storage_writer(self, is_coordinator: bool) -> None: """ 初始化此实例。 参数: is_coordinator (bool): 此实例是否负责协调检查点。 """ pass
[docs] @abc.abstractmethod def prepare_local_plan(self, plan: SavePlan) -> SavePlan: """ 执行存储特定的本地规划。 虽然此方法可以生成一个完全不同的计划,但推荐的方法是将存储特定数据存储在 SavePlan::storage_data 中。 参数: plan (SavePlan): 正在使用的 ``SavePlanner`` 的本地计划。 返回: 存储本地规划后的转换 ``SavePlan`` """ pass
[docs] @abc.abstractmethod def prepare_global_plan(self, plans: List[SavePlan]) -> List[SavePlan]: """ 执行存储的集中规划。 此方法仅在协调者实例上调用。 虽然此方法可以生成一个完全不同的计划,但推荐的方法是将存储特定数据存储在 SavePlan::storage_data 中。 参数: plans: 一个 ``SavePlan`` 实例列表,每个 rank 一个。 返回: 存储全局规划后的转换 ``SavePlan`` 列表 """ pass
[docs] @abc.abstractmethod def write_data( self, plan: SavePlan, planner: SavePlanner ) -> Future[List[WriteResult]]: """ 使用 ``planner`` 从 ``plan`` 中写入所有项目以解析数据。 子类应调用 ``SavePlanner::resolve_data`` 从计划中的每个项目获取对要写入的底层对象的访问权限。 子类应懒惰地调用 `resolve_data`,因为它可能会分配内存。对于张量,请做出以下假设: - 它们可能位于任何设备上,包括与 ``WriteItem::tensor_data`` 不匹配的设备 - 它们可能是视图或不连续的。只需保存投影。 参数: plan (SavePlan): 要执行的保存计划。 planner (SavePlanner): 用于解析项目到数据的计划对象。 返回: 一个完成时返回 WriteResult 列表的 future """ pass
[docs] @abc.abstractmethod def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: """ 写入元数据并将当前检查点标记为成功。 用于序列化 `metadata` 的实际格式/模式是实现细节。唯一的要求是它可以在相同的对象图中恢复。 参数: metadata (Metadata): 新检查点的元数据 results: 所有 rank 的 WriteResults 列表。 返回: None """ pass
[docs] @classmethod @abc.abstractmethod def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: """ 检查给定的 checkpoint_id 是否受存储支持。这使我们能够启用自动存储选择。 """ ...
[docs]class StorageReader<