Shortcuts

torch.distributed.checkpoint.default_planner 的源代码

```python
# 版权所有 (c) Meta Platforms, Inc. 及其附属公司

import dataclasses
import io
import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import (
    FLATTEN_MAPPING,
    flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.planner_helpers import (
    _create_default_metadata_only_plan,
    _create_read_items,
    _create_write_items,
    _init_state_dict,
)
from torch.distributed.checkpoint.utils import find_state_dict_object

# Module-level logger, named after this module through ``__name__``.
logger: logging.Logger = logging.getLogger(__name__)


# Names exported by a wildcard import of this module: the two default planner
# classes and the four ``create_default_*`` plan-construction functions.
__all__ = [
    "DefaultSavePlanner",
    "DefaultLoadPlanner",
    "create_default_local_load_plan",
    "create_default_global_load_plan",
    "create_default_local_save_plan",
    "create_default_global_save_plan",
]


# TODO: 更新 default_planner.py 的文档字符串
[docs]class DefaultSavePlanner(SavePlanner): mappings: FLATTEN_MAPPING def __init__( self, flatten_state_dict: bool = True, flatten_sharded_tensors: bool = True, dedup_replicated_tensors: Optional[bool] = None, ) -> None: self.flatten_state_dict = flatten_state_dict self.flatten_sharded_tensors = flatten_sharded_tensors self.mappings = {} if dedup_replicated_tensors is not None: logger.warning( "DefaultSavePlanner 的 `dedup_replicated_tensors` 参数正在被弃用,不再有任何效果。请从调用中删除此参数。" ) def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None: if self.flatten_state_dict: state_dict, self.mappings = flatten_state_dict(state_dict) if self.flatten_sharded_tensors: state_dict = _flatten_sharded_tensors(state_dict) self.state_dict = state_dict self.is_coordinator = is_coordinator def create_local_plan(self) -> SavePlan: plan = create_default_local_save_plan(self.state_dict, self.is_coordinator) if self.flatten_state_dict: plan = dataclasses.replace(plan, planner_data=self.mappings) self.plan = plan return self.plan def create_global_plan( self, all_plans: List[SavePlan] ) -> Tuple[List[SavePlan], Metadata]: all_plans = dedup_save_plans(all_plans) global_plan, metadata = create_default_global_save_plan(all_plans) if self.flatten_state_dict: # | 不适用于 Python 3.8 或更早版本。 # merged_mappings = reduce( # lambda x, y: x | y, (p.planner_data for p in global_plan) # ) planner_data_dict = [p.planner_data for p in global_plan] merged_mappings = dict(ChainMap(*planner_data_dict)) <