torch.distributed.elastic.agent.server.api 的源代码
# mypy: ignore-errors
# 版权所有 (c) Facebook, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在此源代码树根目录下的 LICENSE 文件中找到的 BSD 风格许可证进行许可。
import abc
import functools
import json
import os
import signal
import socket
import time
import traceback
import warnings
from contextlib import closing
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch.distributed.elastic.rendezvous as rdzv
import torch.distributed.elastic.utils.store as store_util
from torch.distributed.elastic.rendezvous import RendezvousGracefulExitError
from torch.distributed import Store
from torch.distributed.elastic.events import Event, EventSource, record
from torch.distributed.elastic.metrics import prof, put_metric
from torch.distributed.elastic.multiprocessing import (
ProcessFailure,
SignalException,
)
from torch.distributed.elastic.utils.logging import get_logger
__all__ = [
"WorkerSpec",
"Worker",
"WorkerState",
"WorkerGroup",
"RunResult",
"ElasticAgent",
"SimpleElasticAgent",
]
_TERMINAL_STATE_SYNC_ID = "torchelastic/agent/terminal_state"
DEFAULT_ROLE = "default"
log = get_logger(__name__)
[docs]@dataclass
class WorkerSpec:
"""关于特定类型工作者的蓝图信息。
对于给定的角色,只能存在一个工作者规范。
工作者规范预计在所有节点(机器)上都是同质的,
即每个节点为特定规范运行相同数量的工作者。
参数:
role: 用户定义的具有此规范的工作者的角色
local_world_size: 要运行的本地工作者数量
fn: (已弃用,请改用 entrypoint)
entrypoint: 工作者函数或命令
args: 传递给 ``entrypoint`` 的参数
rdzv_handler: 处理此组工作者的 rdzv
max_restarts: 工作者的最大重试次数
monitor_interval: 每 ``n`` 秒监控工作者状态
master_port: 在 rank 0 上运行 c10d store 的固定端口
如果未指定,则选择一个随机空闲端口
master_addr: 在 rank 0 上运行 c10d store 的固定 master_addr
如果未指定,则选择代理 rank 0 的主机名
redirects: 将标准流重定向到文件,
通过传递映射选择性地重定向特定本地 rank
tee: 将指定的标准流(s) 同时输出到控制台和文件,
通过传递映射选择性地为特定本地 rank 进行 tee,
优先于 ``redirects`` 设置。
"""
role: str
local_world_size: int
rdzv_handler: rdzv.RendezvousHandler
fn: Optional[Callable] = None
# TODO @kiuk - 将 entrypoint 设为必填字段
entrypoint: Union[Callable, str, None] = None
args: Tuple = ()
max_restarts: int = 3
monitor_interval: float = 30.0
master_port: Optional[int] = None
master_addr: Optional[str] = None
local_addr: Optional[str] = None
def __post_init__(self):
assert self.local_world_size > 0
assert self.monitor_interval > 0
if self.fn:
warnings.warn(
"WorkerSpec.fn 将被弃用,"
"请改用 WorkerSpec.entrypoint",
category=DeprecationWarning,
)
self.entrypoint = self.fn
assert self.entrypoint
[docs] def get_entrypoint_name(self):
"""获取入口点名称。
如果 entrypoint 是一个函数(例如 ``Callable``),返回其 ``__qualname__``
否则,如果 entrypoint 是一个二进制文件(例如 ``str``),返回二进制文件名。
"""
if isinstance(self.entrypoint, str):
return os.path.basename(self.entrypoint)
else:
assert self.entrypoint is not None
return self.entrypoint.__qualname__
[docs]class Worker:
"""一个工作者实例。
与 ``WorkerSpec`` 表示工作者的规范不同。``Worker`` 是从 ``WorkerSpec`` 创建的。
``Worker`` 与 ``WorkerSpec`` 的关系类似于对象与类的关系。
``id`` 的解释由 ``ElasticAgent`` 的具体实现决定。对于本地代理,它可能是工作者的 ``pid (int)``,
对于远程代理,它可能是 ``host:port (string)``。
参数:
id (Any): 唯一标识一个工作者(由代理解释)
local_rank (int): 工作者的本地排名
global_rank (int): 工作者的全局排名
role_rank (int): 具有相同角色的所有工作者中的排名
world_size (int): 工作者总数