torch.distributed.elastic.multiprocessing.errors 的源代码
#!/usr/bin/env python3
# 版权所有 (c) Facebook, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在源代码树的根目录中的LICENSE文件中找到的BSD风格许可证进行许可。
"""
分布式PyTorch作业中的每个主机都运行一个TorchElastic代理,
以及多个工作进程(作为TorchElastic代理的子进程)。
由于工作进程是由用户提供的(您的PyTorch脚本/作业),TorchElastic
有一种方法可以通过代理将训练器上的错误传播到调度器,
最终通知最终用户作业的状态并应用任何重试策略。
TorchElastic将错误分为3类:
+----------------+----------------+--------------------------------------------------------------+
| 类别 | 子类别 | 描述 |
+================+================+==============================================================+
| 用户错误 | 输入错误 | 对TorchElastic API的无效输入(例如,最小节点数 > 最大节点数) |
| +----------------+--------------------------------------------------------------+
| | 工作进程失败 | 工作子进程上的任何失败 |
+----------------+----------------+--------------------------------------------------------------+
| 平台错误 | n/a | 由代理引起的失败 |
+----------------+----------------+--------------------------------------------------------------+
| 基础设施错误 | n/a | 代理和工作进程领域之外的失败 |
| | | (例如,主机故障) |
+----------------+----------------+--------------------------------------------------------------+
除了“工作进程失败”之外的所有错误都从代理进程中规范地引发,
或隐式或显式地使代理进程崩溃。因此,标准的语言(python)提供的异常处理策略适用。
工作进程失败是特殊的,因为异常/失败源自与代理不同的进程,
因此需要在进程间传播错误(例如,代理不能简单地``try-catch``在工作进程上引发的异常)。
TorchElastic代理使用 :func:`torch.distributed.elastic.multiprocessing.start_processes`
启动工作进程,该进程内置了基于文件的进程间错误传播。
任何用 :func:`record` 装饰的函数或二进制入口点
将把未捕获的异常(带有跟踪信息)写入由环境变量 ``TORCHELASTIC_ERROR_FILE`` 指定的文件。
父进程(例如代理)在启动的每个子进程上设置此环境变量,
然后聚合所有子进程的错误文件,并传播具有**最小**时间戳的文件(例如,**第一个**错误)。
"""
import json
import os
import signal
import socket
import time
import warnings
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from string import Template
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from torch.distributed.elastic.utils.logging import get_logger
from .error_handler import ErrorHandler # noqa: F401
from .handlers import get_error_handler # noqa: F401
__all__ = ["ProcessFailure", "ChildFailedError", "record", "ErrorHandler", "get_error_handler"]
log = get_logger(__name__)
JSON = Dict
_EMPTY_ERROR_DATA = {"message": ""}
_NOT_AVAILABLE = ""
T = TypeVar("T")
[docs]@dataclass
class ProcessFailure:
"""
表示失败进程的结果。当工作进程失败时,它可能会将失败的根本原因记录到文件中。
尝试从提供的 ``error_file`` 中读取失败时间戳,
如果 ``error_file`` 不存在,则时间戳为当前时间戳(自纪元以来的秒数)。
``message`` 字段是失败的简要说明。如果
错误文件存在,则消息从错误文件中获取。
否则,将根据失败签名生成一个消息。
.. 注意:: 假设 ``error_file`` 由
``torch.distributed.elastic.multiprocessing.errors.error_handler.ErrorHandler`` 写入。
否则,行为未定义。
"""
local_rank: int
pid: int
exitcode: int
error_file: str
error_file_data: JSON = field(init=False)
message: str = field(init=False)
timestamp: int = field(init=False)
def __post_init__(self):
self.error_file_data = _EMPTY_ERROR_DATA
if os.path.isfile(self.error_file):
try:
with open(self.error_file) as fp:
self.error_file_data = json.load(fp)
log.debug(
"用户进程失败,错误数据: %s", json.dumps(self.error_file_data, indent=2)
)
self.message, self.timestamp = self._get_error_data(
self.error_file_data
)
except Exception:
log.exception("解析回复文件失败: %s", self.error_file)
raise
else:
self._set_no_reply_file()
# 如果没有消息,则生成一个信息性消息
if not self.message:
# 信号通常不会生成错误文件消息
if self.exitcode < 0:
self.message = (
f"信号 {-self.exitcode} ({self.signal_name()})"
f" 由PID {self.pid} 接收"
)
else:
self.</span