Shortcuts

torch.distributed.elastic.timer.api 的源代码

# 版权所有 (c) Facebook, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在源代码树的根目录中的LICENSE文件中找到的BSD风格许可证进行许可。
import abc
import logging
import threading
import time
from contextlib import contextmanager
from inspect import getframeinfo, stack
from typing import Any, Dict, List, Optional, Set

__all__ = ['TimerRequest', 'TimerClient', 'RequestQueue', 'TimerServer', 'configure', 'expires']

log = logging.getLogger(__name__)

[docs]class TimerRequest: """ 表示倒计时计时器获取和释放的数据对象,用于在``TimerClient``和``TimerServer``之间使用。 负的``expiration_time``应被解释为“释放”请求。 .. 注意:: ``worker_id``的类型是实现特定的。 它是TimerServer和TimerClient实现用于唯一标识工作者的任何内容。 """ __slots__ = ["worker_id", "scope_id", "expiration_time"] def __init__(self, worker_id: Any, scope_id: str, expiration_time: float): self.worker_id = worker_id self.scope_id = scope_id self.expiration_time = expiration_time def __eq__(self, other): if isinstance(other, TimerRequest): return ( self.worker_id == other.worker_id and self.scope_id == other.scope_id and self.expiration_time == other.expiration_time ) return False
[docs]class TimerClient(abc.ABC): """ 客户端库,通过与TimerServer通信来获取和释放倒计时计时器。 """
[docs] @abc.abstractmethod def acquire(self, scope_id: str, expiration_time: float) -> None: """ 为持有此客户端对象的工作者获取一个计时器,给定scope_id和expiration_time。通常向TimerServer注册计时器。 """ pass
[docs] @abc.abstractmethod def release(self, scope_id: str): """ 释放此客户端代表的工作者的``scope_id``的计时器。调用此方法后,该范围内的倒计时计时器将不再有效。 """ pass
class RequestQueue(abc.ABC): """ 持有计时器获取/释放请求的消费者队列 """ @abc.abstractmethod def size(self) -> int: """ 返回调用此方法时队列的大小。请注意,当调用``get``时,队列的大小可能已经增加。队列的大小在调用``get``方法之前不应减少。也就是说,以下断言应成立: size = q.size() res = q.get(size, timeout=0) assert size == len(res) -- 或者 -- size = q.size() res = q.get(size * 2, timeout=1) assert size <= len(res) <= size * 2 """ pass @abc.abstractmethod def get(self, size: int, timeout: float) -> List[TimerRequest]: """ 以阻塞方式获取最多``size``个计时器请求(不超过``timeout``秒)。 """ pass
[docs]class TimerServer(abc.ABC): """ 监控活动计时器并在适当的时间使它们过期的实体。此服务器负责收割已过期计时器的工作者。 """ def __init__( self, request_queue: RequestQueue, max_interval: float, daemon: bool = True ): """ :param request_queue: 消费者``RequestQueue`` :param max_interval: 等待request_queue中项目的最长时间(以秒为单位) :param daemon: 是否将看门狗线程作为守护进程运行 """ super().__init__() self._request_queue = request_queue self._max_interval = max_interval self._daemon = daemon self._watchdog_thread: Optional[threading.Thread] = None self._stop_signaled = False
[docs] @abc.abstractmethod def register_timers(self, timer_requests: List<span class="