# Source code for torch.distributed.elastic.multiprocessing.api
#!/usr/bin/env python3
# 版权所有 (c) Facebook, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在源代码树的根目录中的LICENSE文件中找到的BSD风格许可证进行许可。
import abc
import logging
import os
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import time
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from abc import ABC, abstractmethod
import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
from torch.distributed.elastic.multiprocessing.redirects import (
redirect_stderr,
redirect_stdout,
)
from torch.distributed.elastic.multiprocessing.subprocess_handler import SubprocessHandler, get_subprocess_handler
from torch.distributed.elastic.multiprocessing.tail_log import TailLog
# Host-platform flags, evaluated once when the module is imported.
IS_WINDOWS = (sys.platform == "win32")
IS_MACOS = (sys.platform == "darwin")

# Logger for this module, named after the module itself.
log = logging.getLogger(name=__name__)
# Public API of this module: the names exported by
# ``from torch.distributed.elastic.multiprocessing.api import *``.
__all__ = [
    "DefaultLogsSpecs",
    "SignalException",
    "Std",
    "to_map",
    "RunProcsResult",
    "PContext",
    "get_std_cm",
    "MultiprocessContext",
    "SubprocessContext",
]
class SignalException(Exception):
    """Raised by the termination handler in the torchelastic agent process
    when the process receives a termination signal.

    Attributes:
        sigval: the signal that was received.
    """

    def __init__(self, msg: str, sigval: signal.Signals) -> None:
        super().__init__(msg)
        self.sigval = sigval

    def __reduce__(self):
        # ``BaseException.__reduce__`` rebuilds the instance as
        # ``type(self)(*self.args)``. ``args`` only holds ``msg``, so the
        # required ``sigval`` argument would be missing and ``pickle.loads`` /
        # ``copy.copy`` would raise ``TypeError``. Replay both constructor
        # arguments explicitly and keep ``__dict__`` as state, like the default.
        return (self.__class__, (*self.args, self.sigval), self.__dict__)
def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
    """Signal handler that raises an exception in the main process.

    Intended to be installed for termination signals (e.g. SIGTERM, SIGINT).
    Instead of exiting, it raises ``SignalException`` so that user code can
    catch it and react.

    Python does not terminate the process after a signal handler returns, so
    the raised exception must not be silently swallowed; otherwise the process
    will never terminate.

    Args:
        signum: number of the received signal.
        frame: interrupted stack frame (not used).

    Raises:
        SignalException: always, with ``sigval`` set to the received signal.
    """
    received = signal.Signals(signum)
    message = f"Process {os.getpid()} got signal: {received}"
    raise SignalException(message, sigval=received)
def _get_kill_signal() -> signal.Signals:
    """Return the signal used to kill a process.

    ``SIGKILL`` on Unix, ``CTRL_C_EVENT`` on Windows.
    """
    if not IS_WINDOWS:
        return signal.SIGKILL
    return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
def _get_default_signal() -> signal.Signals:
    """Return the default signal used to terminate a process.

    ``SIGTERM`` on Unix, ``CTRL_C_EVENT`` on Windows.
    """
    return (
        signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821
        if IS_WINDOWS
        else signal.SIGTERM
    )
def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
    """Check that ``d`` has exactly one key per local rank ``0 .. nprocs - 1``.

    Args:
        d: mapping keyed by local rank.
        nprocs: number of local processes.
        what: label for the mapping; used as the prefix of the error message.

    Raises:
        RuntimeError: if the keys of ``d`` are not exactly
            ``{0, ..., nprocs - 1}`` (missing or extra ranks).
    """
    present = set(d.keys())
    wanted = set(range(nprocs))
    if present == wanted:
        return
    raise RuntimeError(
        f"{what}, local rank mapping mismatch,"
        f" expected: {wanted}, actual: {present}"
    )
# Per-local-rank syntax: comma-separated "<local_rank>:<value>" entries with no
# trailing comma, e.g. "0:1,1:2". The value digits 0-3 correspond to the Std
# flag values defined below (NONE=0, OUT=1, ERR=2, ALL=3). ``\d+`` (not ``\d``)
# so that local ranks >= 10 are accepted. Presumably consumed by
# ``Std.from_str`` below.
_MAPPING_REGEX = r"^(\d+:[0123],)*(\d+:[0123])$"
# A single value 0-3 (no rank prefix), e.g. "3".
_VALUE_REGEX = r"^[0123]$"
class Std(IntFlag):
NONE = 0
OUT = 1
ERR = 2
ALL = OUT | ERR
@classmethod
def from_str(cls, vm: str) -> Union["Std", Dict[int