# Shortcuts

# Source code for torch.distributed.elastic.multiprocessing.api

#!/usr/bin/env python3

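# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.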

import abc
import logging
import os
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import time
from contextlib import nullcontext
from dataclasses import dataclass, field
from enum import IntFlag
from multiprocessing import synchronize
from types import FrameType
from typing import Any, Callable, Dict, Optional, Set, Tuple, Union
from abc import ABC, abstractmethod

import torch.multiprocessing as mp
from torch.distributed.elastic.multiprocessing.errors import ProcessFailure, record
from torch.distributed.elastic.multiprocessing.redirects import (
    redirect_stderr,
    redirect_stdout,
)

from torch.distributed.elastic.multiprocessing.subprocess_handler import SubprocessHandler, get_subprocess_handler
from torch.distributed.elastic.multiprocessing.tail_log import TailLog

# Platform flags derived from ``sys.platform``; evaluated once at import time.
IS_WINDOWS = sys.platform == "win32"
IS_MACOS = sys.platform == "darwin"


# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Public API of this module: the names exported by a star-import.
__all__ = [
    "DefaultLogsSpecs",
    "SignalException",
    "Std",
    "to_map",
    "RunProcsResult",
    "PContext",
    "get_std_cm",
    "MultiprocessContext",
    "SubprocessContext",
]

class SignalException(Exception):
    """Error raised when the process receives a death signal.

    The termination handler of the torchelastic agent process raises this
    exception. The signal that triggered it is available as ``sigval``.

    Args:
        msg: Human-readable description, used as the exception message.
        sigval: The signal that was received.
    """

    def __init__(self, msg: str, sigval: signal.Signals) -> None:
        # Record the signal first, then let ``Exception`` store the message.
        self.sigval = sigval
        super().__init__(msg)


def _terminate_process_handler(signum: int, frame: Optional[FrameType]) -> None:
    """Signal handler that turns a death signal into a ``SignalException``.

    Intended to be invoked when the process receives a termination signal
    (SIGTERM, SIGINT). Python does not terminate the process once a signal
    handler returns, so the exception raised here must be handled by user
    code and must not be silently swallowed; otherwise the process will never
    terminate.

    Args:
        signum: Number of the received signal.
        frame: Stack frame that was executing when the signal arrived (unused).

    Raises:
        SignalException: Always; carries the received signal as ``sigval``.
    """
    received = signal.Signals(signum)
    message = f"Process {os.getpid()} got signal: {received}"
    raise SignalException(message, sigval=received)


def _get_kill_signal() -> signal.Signals:
    """Return the kill signal: SIGKILL on Unix, CTRL_C_EVENT on Windows."""
    if not IS_WINDOWS:
        return signal.SIGKILL
    # ``CTRL_C_EVENT`` only exists in the ``signal`` module on Windows.
    return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821


def _get_default_signal() -> signal.Signals:
    """Return the default termination signal: SIGTERM on Unix, CTRL_C_EVENT on Windows."""
    if not IS_WINDOWS:
        return signal.SIGTERM
    # ``CTRL_C_EVENT`` only exists in the ``signal`` module on Windows.
    return signal.CTRL_C_EVENT  # type: ignore[attr-defined] # noqa: F821


def _validate_full_rank(d: Dict[int, Any], nprocs: int, what: str):
    """Check that ``d`` has exactly one entry per rank in ``0..nprocs-1``.

    Args:
        d: Mapping keyed by rank.
        nprocs: Number of ranks; the required keys are ``range(nprocs)``.
        what: Label for the mapping being checked, used as the prefix of the
            error message.

    Raises:
        RuntimeError: If the key set of ``d`` differs from ``range(nprocs)``.
    """
    present = set(d.keys())
    required = set(range(nprocs))

    # Nothing to report when every rank is covered and no extra key exists.
    if present == required:
        return

    raise RuntimeError(
        f"{what}, local rank mapping mismatch,"
        f" expected: {required}, actual: {present}"
    )


# Anchored patterns for the two accepted string forms of a ``Std`` setting:
#   * mapping: comma-separated ``<key>:<value>`` pairs, e.g. "0:1,1:2". The key
#     is one or more digits, so two-digit keys such as "10:3" are accepted
#     (a bare ``\d`` would only allow keys 0-9).
#   * single value: exactly one of 0, 1, 2, 3 (the ``Std`` flag values).
# NOTE(review): presumably consumed by ``Std.from_str`` with the key being a
# local rank; confirm its parsing handles multi-digit keys.
_MAPPING_REGEX = r"^(\d+:[0123],)*(\d+:[0123])$"
_VALUE_REGEX = r"^[0123]$"


class Std(IntFlag):
    NONE = 0
    OUT = 1
    ERR = 2
    ALL = OUT | ERR

    @classmethod
    def from_str(cls, vm: str) -> Union["Std", Dict[int