Source code for torch.distributed.checkpoint.default_planner
```python
# Copyright (c) Meta Platforms, Inc. and affiliates

import dataclasses
import io
import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import Any, cast, Dict, List, Optional, Tuple, Union

import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import (
    FLATTEN_MAPPING,
    flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
    WriteItemType,
)
from torch.distributed.checkpoint.planner_helpers import (
    _create_default_metadata_only_plan,
    _create_read_items,
    _create_write_items,
    _init_state_dict,
)
from torch.distributed.checkpoint.utils import find_state_dict_object

logger: logging.Logger = logging.getLogger(__name__)

__all__ = [
    "DefaultSavePlanner",
    "DefaultLoadPlanner",
    "create_default_local_load_plan",
    "create_default_global_load_plan",
    "create_default_local_save_plan",
    "create_default_global_save_plan",
]


# TODO: Update docstrings for default_planner.py


class DefaultSavePlanner(SavePlanner):
    mappings: FLATTEN_MAPPING

    def __init__(
        self,
        flatten_state_dict: bool = True,
        flatten_sharded_tensors: bool = True,
        dedup_replicated_tensors: Optional[bool] = None,
    ) -> None:
        self.flatten_state_dict = flatten_state_dict
        self.flatten_sharded_tensors = flatten_sharded_tensors
        self.mappings = {}

        if dedup_replicated_tensors is not None:
            logger.warning(
                "DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
                "deprecated, and no longer has any effect. Please remove this "
                "argument from your call."
            )

    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
        if self.flatten_state_dict:
            state_dict, self.mappings = flatten_state_dict(state_dict)
        if self.flatten_sharded_tensors:
            state_dict = _flatten_sharded_tensors(state_dict)
        self.state_dict = state_dict
        self.is_coordinator = is_coordinator

    def create_local_plan(self) -> SavePlan:
        plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
        if self.flatten_state_dict:
            plan = dataclasses.replace(plan, planner_data=self.mappings)
        self.plan = plan

        return self.plan

    def create_global_plan(
        self, all_plans: List[SavePlan]
    ) -> Tuple[List[SavePlan], Metadata]:
        all_plans = dedup_save_plans(all_plans)

        global_plan, metadata = create_default_global_save_plan(all_plans)

        if self.flatten_state_dict:
            # | does not work for Python 3.8 or earlier.
            # merged_mappings = reduce(
            #     lambda x, y: x | y, (p.planner_data for p in global_plan)
            # )
            planner_data_dict = [p.planner_data for p in global_plan]
            merged_mappings = dict(ChainMap(*planner_data_dict))
            metadata = dataclasses.replace(metadata, planner_data=merged_mappings)

        self.global_plan = global_plan
        self.metadata = metadata

        return self.global_plan, self.metadata
```
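
For context, this is how `DefaultSavePlanner` is typically plugged into a checkpoint save. A minimal sketch, assuming PyTorch 2.2+ where `torch.distributed.checkpoint.save` and `FileSystemWriter` are part of the public API; the model and output path are illustrative:

```python
# Minimal usage sketch (assumes PyTorch >= 2.2; model and path are illustrative).
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner

model = torch.nn.Linear(8, 4)
state_dict = {"model": model.state_dict()}

dcp.save(
    state_dict,
    storage_writer=dcp.FileSystemWriter("/tmp/ckpt"),
    # flatten_state_dict=True (the default) stores nested keys under dotted
    # names like "model.weight" and records the reverse mapping in the
    # checkpoint metadata, so the load planner can restore the nesting.
    planner=DefaultSavePlanner(flatten_state_dict=True),
)
```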
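
The `ChainMap` merge in `create_global_plan` is worth a note: each rank ships its flatten mapping to the coordinator via `SavePlan.planner_data`, and the commented-out `reduce` over the `|` operator was avoided because dict union requires Python 3.9+. A standalone illustration of the same merge, using hypothetical per-rank mappings:

```python
from collections import ChainMap

# Hypothetical planner_data from two ranks; in a ChainMap the first
# occurrence of a key wins, matching create_global_plan's merge.
rank0 = {"model.weight": ("model", "weight")}
rank1 = {"model.bias": ("model", "bias")}

merged = dict(ChainMap(rank0, rank1))
# {"model.weight": ("model", "weight"), "model.bias": ("model", "bias")}
```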