torch.distributed.elastic.timer.local_timer 的源代码
# 版权所有 (c) Facebook, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在源代码树的根目录中的LICENSE文件中找到的BSD风格许可证进行许可。
import logging
import multiprocessing as mp
import os
import signal
import time
from queue import Empty
from typing import Any, Dict, List, Set, Tuple
from .api import RequestQueue, TimerClient, TimerRequest, TimerServer
# Names exported by ``from ... import *``; this is the module's public API.
__all__ = [
    "LocalTimerClient",
    "MultiprocessingRequestQueue",
    "LocalTimerServer",
]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class LocalTimerClient(TimerClient):
    """
    Client side of ``LocalTimerServer``.

    This client is meant to be used with a ``LocalTimerServer`` running on
    the same host. It identifies each worker process by its pid
    (``os.getpid()``). This is especially convenient on hosts with multiple
    GPU devices, where one subprocess (trainer) is spawned per GPU.
    """

    def __init__(self, mp_queue: mp.Queue):
        """
        Args:
            mp_queue: queue shared with the server. Timer requests are sent
                to the server by calling ``put`` on it.
        """
        super().__init__()
        self._mp_queue = mp_queue

    def acquire(self, scope_id: str, expiration_time: float) -> None:
        """Ask the server to arm a timer for ``scope_id`` in this process.

        Args:
            scope_id: identifier of the timed scope. Together with this
                process's pid, it identifies the timer.
            expiration_time: expiration time, forwarded unchanged in the
                request. Presumably an absolute timestamp; confirm against
                the ``TimerServer`` implementation.
        """
        request = TimerRequest(os.getpid(), scope_id, expiration_time)
        self._mp_queue.put(request)

    def release(self, scope_id: str) -> None:
        """Ask the server to drop the timer for ``scope_id`` in this process.

        Args:
            scope_id: identifier previously passed to ``acquire``.
        """
        # A negative expiration time is the "release" sentinel. The server
        # pops the (pid, scope_id) timer when it receives one.
        request = TimerRequest(os.getpid(), scope_id, -1)
        self._mp_queue.put(request)
class MultiprocessingRequestQueue(RequestQueue):
    """
    A ``RequestQueue`` backed by a python ``multiprocessing.Queue``.
    """

    def __init__(self, mp_queue: mp.Queue):
        """
        Args:
            mp_queue: the underlying queue that clients ``put`` requests on.
        """
        super().__init__()
        self._mp_queue = mp_queue

    def size(self) -> int:
        """Return the number of pending requests as reported by ``qsize()``.

        NOTE: per the Python stdlib docs, ``multiprocessing.Queue.qsize()``
        is only approximate. It may also raise ``NotImplementedError`` on
        platforms where ``sem_getvalue()`` is unavailable (e.g. macOS).
        """
        return self._mp_queue.qsize()

    def get(self, size, timeout: float) -> List[TimerRequest]:
        """Collect up to ``size`` requests within a total wait of ``timeout``.

        The first blocking ``get`` waits up to ``timeout`` seconds. Each
        later one waits only for whatever remains of that budget. Collection
        stops early when the queue stays empty for the remaining wait, or
        when the budget is exhausted.

        Args:
            size: maximum number of requests to return.
            timeout: total wait budget, in seconds.

        Returns:
            The requests received, in the order they were dequeued. The list
            may be empty.
        """
        requests: List[TimerRequest] = []
        # Track the budget against a monotonic clock. Wall-clock time
        # (``time.time()``) can jump due to NTP or manual changes, which
        # would stretch or shrink the remaining wait.
        deadline = time.monotonic() + timeout
        wait = timeout
        for _ in range(size):
            try:
                request = self._mp_queue.get(block=True, timeout=wait)
            except Empty:
                break
            requests.append(request)
            wait = deadline - time.monotonic()
            if wait <= 0:
                break
        return requests
[docs]class LocalTimerServer(TimerServer):
"""
与 ``LocalTimerClient`` 一起工作的服务器。客户端应该是运行此服务器的父进程的子进程。作业中的每台主机都应在本地启动自己的计时器服务器,每个服务器实例管理本地工作进程的计时器(运行在同一主机上的进程)。
"""
def __init__(
    self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
):
    """Create a server that consumes timer requests arriving on ``mp_queue``.

    Args:
        mp_queue: queue shared with ``LocalTimerClient`` instances.
        max_interval: forwarded unchanged to ``TimerServer.__init__``.
        daemon: forwarded unchanged to ``TimerServer.__init__``.
    """
    # Adapt the raw multiprocessing queue to the RequestQueue interface
    # that the TimerServer base class reads from.
    request_queue = MultiprocessingRequestQueue(mp_queue)
    super().__init__(request_queue, max_interval, daemon)
    # Registered timers, keyed by (worker pid, scope_id).
    self._timers: Dict[Tuple[Any, str], TimerRequest] = {}
def register_timers(self, timer_requests: List[TimerRequest]) -> None:
for request in timer_requests:
pid = request.worker_id
scope_id = request.scope_id
expiration_time = request.expiration_time
# 负的过期时间是释放调用的代理
if expiration_time < 0:
self._timers.pop((pid, scope_id), None)
else:
self._timers[(pid<span