Source code for `torch.distributed.checkpoint.state_dict`

```python
import contextlib
import functools
import gc
from dataclasses import asdict, dataclass, field
from itertools import chain
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Iterable,
    List,
    no_type_check,
    Optional,
    Set,
    Tuple,
    Union,
)

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import (
    _gather_state_dict,
    _offload_state_dict_to_cpu,
)
from torch.distributed._tensor import DTensor
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    _CHECKPOINT_PREFIX,
)
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    OptimStateDictConfig,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    StateDictConfig,
    StateDictType,
)
from torch.distributed.fsdp._common_utils import (
    _get_module_fsdp_state_if_fully_sharded_module,
    FSDP_WRAPPED_MODULE,
)
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel as DDP


# Attribute name FSDP gives the flattened parameter it installs on a wrapped
# module (presumably filtered out of FQNs elsewhere in this file — confirm
# against the usages below this chunk).
FLAT_PARAM = "_flat_param"
# Standard optimizer state_dict keys ("param_groups", "state", "params") and
# their dotted prefixes used when flattening/unflattening optimizer state.
PG = "param_groups"
PG_PREFIX = f"{PG}."
STATE = "state"
STATE_PREFIX = f"{STATE}."
PARAMS = "params"
# A set of fully-qualified names (FQNs) that refer to one logical parameter.
FQNS_T = Set[str]

# NOTE(review): appears to track state_dict callables that have been patched
# by this module — verify against code past this chunk.
_patched_state_dict: Set[Callable] = set()


# Type aliases for the values a (model or optimizer) state_dict may hold:
# leaf values, plus nested lists/tuples/dicts of them.
PrimitiveType = Union[DTensor, ShardedTensor, torch.Tensor, int, float, str]
ValueType = Union[
    PrimitiveType, List[PrimitiveType], Tuple[PrimitiveType], Dict[str, "ValueType"]
]
DictValueType = Dict[str, ValueType]
ListDictValueType = List[DictValueType]
OptimizerStateType = Dict[str, Union[DictValueType, ListDictValueType]]


@contextlib.contextmanager
def gc_context():
    """Suspend automatic garbage collection for the duration of the block.

    On exit a full collection is forced, and automatic collection is
    re-enabled only if it was enabled on entry.
    """
    previously_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        # TODO: add timing / statistics logging for this collection.
        gc.collect()
        if previously_enabled:
            gc.enable()


@dataclass
class StateDictOptions:
    """Configure how ``get_state_dict`` / ``set_state_dict`` behave.

    - ``full_state_dict``: if True, all tensors in the returned state_dict are
      gathered; no ``ShardedTensor`` or ``DTensor`` appears in the result.
    - ``cpu_offload``: offload all tensors to CPU.  To avoid CPU OOM, when
      ``full_state_dict`` is also True only rank0 receives the state_dict and
      every other rank gets an empty one.
    - ``ignore_frozen_params``: if True, the returned state_dict omits frozen
      parameters (those with ``requires_grad == False``).  Default: False.
    - ``keep_submodule_prefixes``: when ``submodules`` is not None, whether to
      keep the submodule prefix in state_dict keys.  E.g. for submodule
      ``module.pretrain`` and parameter FQN ``pretrain.layer1.weight``, the key
      is ``pretrain.layer1.weight`` when True and ``layer1.weight`` when False.
      With False, FQN collisions are possible, so ``submodules`` should then
      contain exactly one module.  Default: True.
    - ``strict``: the ``strict`` flag passed to ``model.load_state_dict()`` by
      ``set_state_dict``.  Default: True.
      (NOTE(review): the upstream docstring claimed the default is False, which
      contradicted the field default below; documented here to match the code.)
    """

    full_state_dict: bool = False
    cpu_offload: bool = False
    ignore_frozen_params: bool = False
    keep_submodule_prefixes: bool = True
    strict: bool = True
@dataclass class _StateDictInfo(StateDictOptions): fqn_param_mapping: Dict[ Union[str, torch.Tensor], Union[FQNS_T, torch.Tensor] ] = field(default_factory=dict) all_fqns: Set[str] = field(default_factory=set) submodule_prefixes: Set[str] = field(default_factory=set) handle_model: bool = True <
优云智算