Shortcuts

torch.distributed.checkpoint.stateful 的源代码

from typing import Any, Dict, runtime_checkable, TypeVar

from typing_extensions import Protocol


__all__ = ["Stateful", "StatefulT"]


[docs]@runtime_checkable class Stateful(Protocol): """ 用于可以被检查点和恢复的对象的状态协议。 """
[docs] def state_dict(self) -> Dict[str, Any]: """ 对象应返回其状态字典表示为字典。 此函数的输出将被检查点化,并在`load_state_dict()`中恢复。 .. 警告:: 由于恢复检查点的就地性质,此函数在`torch.distributed.checkpoint.load`期间也会被调用。 返回: Dict: 对象的状态字典 """ ...
[docs] def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """ 从提供的状态字典中恢复对象的状态。 参数: state_dict: 要从中恢复的状态字典 """ ...
StatefulT = TypeVar("StatefulT", bound=Stateful)