torch.distributed.elastic.timer.api 的源代码
# 版权所有 (c) Facebook, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在源代码树的根目录中的LICENSE文件中找到的BSD风格许可证进行许可。
import abc
import logging
import threading
import time
from contextlib import contextmanager
from inspect import getframeinfo, stack
from typing import Any, Dict, List, Optional, Set
__all__ = ['TimerRequest', 'TimerClient', 'RequestQueue', 'TimerServer', 'configure', 'expires']
log = logging.getLogger(__name__)
[docs]class TimerRequest:
"""
表示倒计时计时器获取和释放的数据对象,用于在``TimerClient``和``TimerServer``之间使用。
负的``expiration_time``应被解释为“释放”请求。
.. 注意:: ``worker_id``的类型是实现特定的。
它是TimerServer和TimerClient实现用于唯一标识工作者的任何内容。
"""
__slots__ = ["worker_id", "scope_id", "expiration_time"]
def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
self.worker_id = worker_id
self.scope_id = scope_id
self.expiration_time = expiration_time
def __eq__(self, other):
if isinstance(other, TimerRequest):
return (
self.worker_id == other.worker_id
and self.scope_id == other.scope_id
and self.expiration_time == other.expiration_time
)
return False
[docs]class TimerClient(abc.ABC):
"""
客户端库,通过与TimerServer通信来获取和释放倒计时计时器。
"""
[docs] @abc.abstractmethod
def acquire(self, scope_id: str, expiration_time: float) -> None:
"""
为持有此客户端对象的工作者获取一个计时器,给定scope_id和expiration_time。通常向TimerServer注册计时器。
"""
pass
[docs] @abc.abstractmethod
def release(self, scope_id: str):
"""
释放此客户端代表的工作者的``scope_id``的计时器。调用此方法后,该范围内的倒计时计时器将不再有效。
"""
pass
class RequestQueue(abc.ABC):
"""
持有计时器获取/释放请求的消费者队列
"""
@abc.abstractmethod
def size(self) -> int:
"""
返回调用此方法时队列的大小。请注意,当调用``get``时,队列的大小可能已经增加。队列的大小在调用``get``方法之前不应减少。也就是说,以下断言应成立:
size = q.size()
res = q.get(size, timeout=0)
assert size == len(res)
-- 或者 --
size = q.size()
res = q.get(size * 2, timeout=1)
assert size <= len(res) <= size * 2
"""
pass
@abc.abstractmethod
def get(self, size: int, timeout: float) -> List[TimerRequest]:
"""
以阻塞方式获取最多``size``个计时器请求(不超过``timeout``秒)。
"""
pass
[docs]class TimerServer(abc.ABC):
"""
监控活动计时器并在适当的时间使它们过期的实体。此服务器负责收割已过期计时器的工作者。
"""
def __init__(
self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
):
"""
:param request_queue: 消费者``RequestQueue``
:param max_interval: 等待request_queue中项目的最长时间(以秒为单位)
:param daemon: 是否将看门狗线程作为守护进程运行
"""
super().__init__()
self._request_queue = request_queue
self._max_interval = max_interval
self._daemon = daemon
self._watchdog_thread: Optional[threading.Thread] = None
self._stop_signaled = False
[docs] @abc.abstractmethod
def register_timers(self, timer_requests: List<span class="