Shortcuts

torch.distributed.checkpoint.fsspec 的源代码

# Mypy 不会尝试推断任何已安装的第三方库的类型。
# mypy: ignore-errors

import io
import os
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional, Union

import fsspec
from fsspec import AbstractFileSystem
from fsspec.core import url_to_fs

from torch.distributed.checkpoint.filesystem import (
    FileSystemBase,
    FileSystemReader,
    FileSystemWriter,
)

__all__ = [
    "FsspecWriter",
    "FsspecReader",
]


class FileSystem(FileSystemBase):
    def __init__(self) -> None:
        self.fs: Optional[AbstractFileSystem] = None

    @contextmanager
    def create_stream(
        self, path: Union[str, os.PathLike], mode: str
    ) -> Generator[io.IOBase, None, None]:
        assert self.fs is not None
        with self.fs.transaction:
            with fsspec.open(str(path), mode) as stream:
                yield stream

    def concat_path(
        self, path: Union[str, os.PathLike], suffix: str
    ) -> Union[str, os.PathLike]:
        return os.path.join(path, suffix)

    def init_path(self, path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
        self.fs, _ = url_to_fs(path)
        return path

    def rename(
        self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]
    ) -> None:
        self.fs.rename(path, new_path)

    def mkdir(self, path: [str, os.PathLike]) -> None:
        self.fs.makedirs(path, exist_ok=True)

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        if isinstance(checkpoint_id, Path):
            return False

        try:
            url_to_fs(checkpoint_id)
        except ValueError as e:
            return False

        return True


[docs]class FsspecWriter(FileSystemWriter): """ 使用 FFspec 的基本 StorageWriter 实现。 此实现做了以下假设和简化: * 检查点路径是一个空目录或不存在的目录。 * 文件创建是原子的 检查点由每个写请求的一个文件加上一个包含序列化元数据的 `.metadata` 文件组成。 """ def __init__( self, path: Union[str, os.PathLike], single_file_per_rank: bool = True, sync_files: bool = True, thread_count: int = 1, per_thread_copy_ahead: int = 10_000_000, ) -> None: """ 初始化指向 `path` 的写入器。 参数: path: 检查点将被写入的目录。 single_file_per_rank: 每个 rank 生成一个文件,而不是每个张量/blob 生成一个文件。默认为 True。 sync_files : 强制文件同步到永久存储。默认为 True。 thread_count: 用于写入的 IO 线程数。默认为 1。 per_thread_copy_ahead: 在保存之前从 GPU 复制多少字节。默认 10Mb。 注意:如果禁用 sync_files,则在发生故障时无法保证检查点的完整性。 """ super().__init__( path, single_file_per_rank, sync_files, <span class="n