torch.distributed.checkpoint.stateful 的源代码
from typing import Any, Dict, runtime_checkable, TypeVar
from typing_extensions import Protocol
__all__ = ["Stateful", "StatefulT"]
[docs]@runtime_checkable
class Stateful(Protocol):
"""
用于可以被检查点和恢复的对象的状态协议。
"""
[docs] def state_dict(self) -> Dict[str, Any]:
"""
对象应返回其状态字典表示为字典。
此函数的输出将被检查点化,并在`load_state_dict()`中恢复。
.. 警告::
由于恢复检查点的就地性质,此函数在`torch.distributed.checkpoint.load`期间也会被调用。
返回:
Dict: 对象的状态字典
"""
...
[docs] def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
从提供的状态字典中恢复对象的状态。
参数:
state_dict: 要从中恢复的状态字典
"""
...
StatefulT = TypeVar("StatefulT", bound=Stateful)