Source code for ``torch.distributed.checkpoint.state_dict``
```python
# NOTE(review): this file is a scraped documentation page; the code below has
# been reformatted from a single extracted line — tokens are unchanged.
import contextlib
import functools
import gc
from dataclasses import asdict, dataclass, field
from itertools import chain
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import (
    _gather_state_dict,
    _offload_state_dict_to_cpu,
)
from torch.distributed._tensor import DTensor
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    OptimStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state_if_fully_sharded_module,
    FSDP_WRAPPED_MODULE,
)
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel as DDP

# Key under which FSDP stores its flattened parameter in a module state_dict.
FLAT_PARAM = "_flat_param"
# Optimizer state_dict top-level keys and their dotted prefixes, used when
# flattening/unflattening optimizer state.
PG = "param_groups"
PG_PREFIX = f"{PG}."
STATE = "state"
STATE_PREFIX = f"{STATE}."
# Key inside each optimizer param-group holding the list of parameters.
PARAMS = "params"
# A set of fully-qualified names (FQNs) that identify one parameter.
FQNS_T = Set[str]

# state_dict callables that have been patched by this module
# (tracked in a set so membership can be checked).
_patched_state_dict: Set[Callable] = set()

# Recursive value types that may appear in a (possibly optimizer) state_dict.
PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]
ValueType = Union[
    PrimitiveType, List[PrimitiveType], Tuple[PrimitiveType], Dict[str, "ValueType"]
]
DictValueType = Dict[str, ValueType]
ListDictValueType = List[DictValueType]
OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]]


@contextlib.contextmanager
def gc_context():
    """Disable automatic garbage collection for the duration of the block.

    On exit one explicit ``gc.collect()`` is always run, and automatic
    collection is re-enabled only if it was enabled on entry.
    """
    is_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        # TODO: add logging/timing details for the garbage collection.
        gc.collect()
        if is_enabled:
            gc.enable()


@dataclass
class StateDictOptions:
    """Specify how ``get_state_dict``/``set_state_dict`` should work.

    - ``full_state_dict``: if set to True, all tensors in the returned
      state_dict will be gathered; no ``ShardedTensor`` or ``DTensor`` will
      appear in the result.
    - ``cpu_offload``: offload all tensors to CPU. To prevent CPU OOM, when
      ``full_state_dict`` is also True only rank 0 gets the state_dict and
      all other ranks get an empty one.
    - ``ignore_frozen_params``: if True, the returned state_dict will not
      contain any frozen parameters (parameters with
      ``requires_grad == False``). Default: False.
    - ``keep_submodule_prefixes``: when ``submodules`` is not None, whether to
      keep the submodule prefix in the state_dict keys. E.g. if the submodule
      is ``module.pretrain`` and the parameter FQN is
      ``pretrain.layer1.weight``, the key is ``pretrain.layer1.weight`` when
      this option is True and ``layer1.weight`` when it is False. Note that
      False may produce conflicting FQNs, so ``submodules`` should then
      contain only one submodule.
    - ``strict``: the ``strict`` option passed to ``model.load_state_dict()``
      by ``set_state_dict``. Default: True.
      (NOTE(review): the upstream doc text claimed the default is False, but
      the field below defaults to True — doc corrected to match the code.)
    """

    full_state_dict: bool = False
    cpu_offload: bool = False
    ignore_frozen_params: bool = False
    keep_submodule_prefixes: bool = True
    strict: bool = True


@dataclass
class _StateDictInfo(StateDictOptions):
    # NOTE(review): this dataclass appears truncated in the extracted source
    # (the scrape ends mid-definition); fields are reproduced as seen.
    # Maps parameter FQN -> tensor and tensor -> set of FQNs
    # (presumably bidirectional; confirm against the full source).
    fqn_param_mapping: Dict[
        Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor]
    ] = field(default_factory=dict)
    # All parameter FQNs known for the model.
    all_fqns: Set[str] = field(default_factory=set)
    # Prefixes of the selected submodules (used with ``submodules``).
    submodule_prefixes: Set[str] = field(default_factory=set)
    # Whether the model state_dict should be processed.
    handle_model: bool = True