# Source code for torch.distributed.checkpoint.filesystem
import collections
import dataclasses
import io
import os
import pickle
import queue
import threading
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import (
Callable,
cast,
Dict,
Generator,
IO,
Iterable,
Iterator,
List,
Optional,
Tuple,
Union,
)
import torch
from torch import Tensor
from torch._utils import _get_available_device_type, _get_device_module
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.futures import Future
from .metadata import Metadata, MetadataIndex
from .planner import (
LoadItemType,
LoadPlan,
LoadPlanner,
ReadItem,
SavePlan,
SavePlanner,
WriteItem,
WriteItemType,
)
from .storage import StorageReader, StorageWriter, WriteResult
from .utils import _create_file_view
# Public API of this module: the names exposed by a star-import of it.
__all__ = ["FileSystemWriter", "FileSystemReader"]
@dataclasses.dataclass
class _StorageInfo:
    """Storage information recorded for a single entry."""

    # Path string locating the entry's storage. NOTE(review): the base this
    # path is relative to is not visible in this part of the file.
    relative_path: str
    # Integer offset of the entry -- presumably a byte position inside the
    # file named by ``relative_path``; confirm against the reader/writer code.
    offset: int
    # Integer length of the entry -- presumably in bytes; confirm as above.
    length: int
@dataclasses.dataclass
class _StoragePrefix:
    """Small record wrapping a single ``prefix`` string."""

    # NOTE(review): how the prefix is consumed is not shown in this part of
    # the file -- presumably it is prepended to generated storage names.
    prefix: str
# Default suffix string. NOTE(review): presumably the file-name extension used
# for checkpoint data files -- its use is outside this view; confirm at callers.
DEFAULT_SUFFIX = ".distcp"
class _TensorLoader(ABC):
    """Abstract interface for objects that turn registered items into tensors.

    Implementations collect items through :meth:`add`, are told to begin work
    through :meth:`start_loading`, and hand results back through
    :meth:`values`.
    """

    @abstractmethod
    def add(self, size: int, obj: object) -> None:
        """Register ``obj`` together with its associated ``size``."""

    @abstractmethod
    def start_loading(self) -> None:
        """Signal that loading of the registered items may begin."""

    @abstractmethod
    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        """Return an iterator of ``(tensor, obj)`` pairs."""
class _SerialCpuLoader(_TensorLoader):
    """Tensor loader that resolves items one at a time, lazily, in ``values()``.

    Nothing is resolved ahead of time: each registered object is passed to
    ``resolve_fun`` only when the ``values()`` generator reaches it.
    """

    def __init__(self, resolve_fun: Callable) -> None:
        """Store the resolver and start with an empty item list.

        Args:
            resolve_fun: Callable invoked as ``resolve_fun(obj)`` for every
                registered object; its result must support ``.detach()``.
        """
        self.resolve_fun = resolve_fun
        self.items: List[Tuple[int, object]] = []

    def add(self, size: int, obj: object) -> None:
        """Queue ``obj`` (with its ``size``) for later resolution."""
        self.items.append((size, obj))

    def start_loading(self) -> None:
        """Do nothing; all work for this loader happens inside ``values()``."""

    def values(self) -> Iterator[Tuple[torch.Tensor, object]]:
        """Yield ``(cpu_tensor, obj)`` for every queued item, in insertion order.

        Each tensor is detached and moved to CPU. If its backing storage holds
        more bytes than the tensor's own elements need (for example, a view
        into a larger buffer), the tensor is cloned so the yielded tensor owns
        a minimal storage.
        """
        for _, obj in self.items:
            tensor = self.resolve_fun(obj).detach().cpu()
            # Compare sizes in bytes through the untyped storage. The older
            # ``tensor.storage()`` accessor returns the deprecated TypedStorage
            # (warning on use), and its element-count ``size()`` rounds down,
            # which can hide trailing bytes in the underlying buffer.
            storage_nbytes = tensor.untyped_storage().nbytes()
            if storage_nbytes != tensor.numel() * tensor.element_size():
                tensor = tensor.clone()
            yield tensor, obj
class _OverlappingCpuLoader(_TensorLoader):
def __init__(
self,
resolve_fun: Callable,
stream: Optional[torch.Stream] = None,
inflight_threshhold: int = 1_000_000,
) -> None:
self.resolve_fun = resolve_fun
self.items: List[Tuple[int, object]] = []
self.inflight_threshhold = inflight_threshhold
self.in_flight_data = 0
self.current_items: collections.deque = collections.deque()
self.idx = <