Shortcuts

torch.multiprocessing.spawn 的源代码

```html
import logging
import multiprocessing
import multiprocessing.connection
import os
import pickle
import signal
import sys
import tempfile
import time
import warnings
from typing import Optional

from . import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]

log = logging.getLogger(__name__)


class ProcessException(Exception):
    __slots__ = ["error_index", "error_pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        self.msg = msg
        self.error_index = error_index
        self.pid = pid

    def __reduce__(self):
        return type(self), (self.msg, self.error_index, self.pid)


class ProcessRaisedException(ProcessException):
    """当进程因代码引发的异常而失败时引发的异常。"""

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
    ):
        super().__init__(msg, error_index, error_pid)


class ProcessExitedException(ProcessException):
    """当进程因信号或以特定代码退出而失败时引发的异常。"""

    __slots__ = ["exit_code"]

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
        exit_code: int,
        signal_name: Optional[str] = None,
    ):
        super().__init__(msg, error_index, error_pid)
        self.exit_code = exit_code
        self.signal_name = signal_name

    def __reduce__(self):
        return (
            type(self),
            (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
        )


def _wrap(fn, i, args, error_file):
    # prctl(2) 是一个特定于 Linux 的系统调用。
    # 在其他系统上,以下函数调用无效。
    # 这是为了确保非守护进程的子进程可以在父进程终止之前终止。
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        pass  # SIGINT; 被父进程杀死,什么都不做
    except Exception:
        # 将异常传播到父进程,保留原始的回溯信息
        import traceback

        with open(error_file, "wb") as fh:
            pickle.dump(traceback.format_exc(), fh)
        sys.exit(1)


class ProcessContext:
    def __init__(self, processes, error_files):
        self.error_files = error_files
        self.processes = processes
        self.sentinels = {
            process.sentinel: index for index, process in enumerate(processes)
        }

    def pids(self):
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""在 spawn 上下文中加入一个或多个进程。

        尝试在此 spawn 上下文中加入一个或多个进程。
        如果其中一个以非零退出状态退出,此函数将杀死剩余的进程并引发一个异常,说明第一个进程退出的原因。

        如果所有进程都成功加入,则返回 ``True``,
        如果还有更多进程需要加入,则返回 ``False``。

        参数:
            timeout (float): 在放弃等待之前等待的时间。
        """
        # 确保即使我们已经完成,此函数也可以被调用。
        if len(self.sentinels) == 0:
            return True

        # 等待任何进程失败或所有进程成功。
优云智算