torch.distributed.elastic.timer.file_based_local_timer 的源代码
# 版权所有 (c) Meta Platforms, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在此源代码树的根目录中的LICENSE文件中找到的BSD风格许可证进行许可。
import io
import json
import logging
import os
import select
import signal
import sys
import threading
import time
from typing import Callable, Dict, List, Optional, Set, Tuple
from torch.distributed.elastic.timer.api import TimerClient, TimerRequest
__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"]
log = logging.getLogger(__name__)
class FileTimerRequest(TimerRequest):
"""
表示倒计时计时器获取和释放的数据对象,
用于在 ``FileTimerClient`` 和 ``FileTimerServer`` 之间使用。
负 ``expiration_time`` 应解释为“释放”请求。
``signal`` 是从服务器进程中收割工作进程的信号。
"""
__slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"]
def __init__(self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0) -> None:
self.version = 1
self.worker_pid = worker_pid
self.scope_id = scope_id
self.expiration_time = expiration_time
self.signal = signal
def __eq__(self, other) -> bool:
if isinstance(other, FileTimerRequest):
return (
self.version == other.version
and self.worker_pid == other.worker_pid
and self.scope_id == other.scope_id
and self.expiration_time == other.expiration_time
and self.signal == other.signal
)
return False
def to_json(self) -> str:
return json.dumps(
{
"version": self.version,
"pid": self.worker_pid,
"scope_id": self.scope_id,
"expiration_time": self.expiration_time,
"signal": self.signal
},
)
[docs]class FileTimerClient(TimerClient):
"""
``FileTimerServer`` 的客户端。此客户端旨在与 ``FileTimerServer`` 在同一主机上运行,并使用 pid 唯一标识工作进程。
此客户端使用命名管道将计时器请求发送到 ``FileTimerServer``。此客户端是生产者,而 ``FileTimerServer`` 是消费者。多个客户端可以与同一个 ``FileTimerServer`` 一起工作。
参数:
file_path: str, FIFO 特殊文件的路径。``FileTimerServer`` 必须通过调用 os.mkfifo() 创建它。
signal: signal, 用于终止进程的信号。使用负信号或零信号不会终止进程。
"""
def __init__(self, file_path: str, signal=(signal.SIGKILL if sys.platform != "win32" else
signal.CTRL_C_EVENT)) -> None: # type: ignore[attr-defined]
super().__init__()
self._file_path = file_path
self.signal = signal
def _open_non_blocking(self) -> Optional[io.TextIOWrapper]:
try:
fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK)
return os.fdopen(fd, "wt")
except Exception:
return None
def _send_request(self, request: FileTimerRequest) -> None:
# 服务器可能已经崩溃或可能尚未启动。
# 在这种情况下,以阻塞模式调用 open() 会阻塞客户端。
# 为了避免这种问题,以非阻塞模式打开它,如果服务器不存在,将引发 OSError。
file = self._open_non_blocking()
if file is None:
raise BrokenPipeError("无法发送 FileTimerRequest,因为 FileTimerServer 不可用。")
with file:
json_request = request.to_json()
# 写入不超过 select.PIPE_