Shortcuts

torch.distributed.elastic.timer.local_timer 的源代码

# 版权所有 (c) Facebook, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在源代码树的根目录中的LICENSE文件中找到的BSD风格许可证进行许可。
import logging
import multiprocessing as mp
import os
import signal
import time
from queue import Empty
from typing import Any, Dict, List, Set, Tuple

from .api import RequestQueue, TimerClient, TimerRequest, TimerServer

__all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer']

log = logging.getLogger(__name__)

[docs]class LocalTimerClient(TimerClient): """ ``LocalTimerServer`` 的客户端。此客户端旨在与运行在同一主机上的 ``LocalTimerServer`` 一起使用,并使用 pid 来唯一标识一个工作进程。这在每台主机上为每个 GPU 生成一个子进程(训练器)的情况下特别有用,尤其是在具有多个 GPU 设备的主机上。 """ def __init__(self, mp_queue): super().__init__() self._mp_queue = mp_queue def acquire(self, scope_id, expiration_time): pid = os.getpid() acquire_request = TimerRequest(pid, scope_id, expiration_time) self._mp_queue.put(acquire_request) def release(self, scope_id): pid = os.getpid() release_request = TimerRequest(pid, scope_id, -1) self._mp_queue.put(release_request)
class MultiprocessingRequestQueue(RequestQueue): """ 由 python ``multiprocessing.Queue`` 支持的 ``RequestQueue`` """ def __init__(self, mp_queue: mp.Queue): super().__init__() self._mp_queue = mp_queue def size(self) -> int: return self._mp_queue.qsize() def get(self, size, timeout: float) -> List[TimerRequest]: requests = [] wait = timeout for _ in range(0, size): start = time.time() try: r = self._mp_queue.get(block=True, timeout=wait) except Empty: break requests.append(r) wait = wait - (time.time() - start) if wait <= 0: break return requests
[docs]class LocalTimerServer(TimerServer): """ 与 ``LocalTimerClient`` 一起工作的服务器。客户端应该是运行此服务器的父进程的子进程。作业中的每台主机都应在本地启动自己的计时器服务器,每个服务器实例管理本地工作进程的计时器(运行在同一主机上的进程)。 """ def __init__( self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True ): super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon) self._timers: Dict[Tuple[Any, str], TimerRequest] = {} def register_timers(self, timer_requests: List[TimerRequest]) -> None: for request in timer_requests: pid = request.worker_id scope_id = request.scope_id expiration_time = request.expiration_time # 负的过期时间是释放调用的代理 if expiration_time < 0: self._timers.pop((pid, scope_id), None) else: self._timers[(pid<span