torch.distributed.elastic.rendezvous.etcd_rendezvous_backend 的源代码
# 版权所有 (c) Facebook, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在此源树根目录下的LICENSE文件中找到的BSD风格许可证进行许可。
import binascii
from base64 import b64decode, b64encode
from typing import Optional, Tuple, cast
import urllib3.exceptions # type: ignore[import]
from etcd import Client as EtcdClient # type: ignore[import]
from etcd import (
EtcdAlreadyExist,
EtcdCompareFailed,
EtcdException,
EtcdKeyNotFound,
EtcdResult,
)
from torch.distributed import Store
from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint
[docs]class EtcdRendezvousBackend(RendezvousBackend):
"""表示一个基于etcd的rendezvous后端。
参数:
client:
用于与etcd通信的``etcd.Client``实例。
run_id:
rendezvous的运行ID。
key_prefix:
在etcd中存储rendezvous状态的路径。
ttl:
rendezvous状态的TTL。如果未指定,默认为两小时。
"""
# Fallback TTL used by the constructor when no positive TTL is supplied.
_DEFAULT_TTL = 7200  # 2 hours
# etcd client handed to the constructor; used for reads against etcd.
_client: EtcdClient
# Key holding the rendezvous state: "<key_prefix>/<run_id>", or just the
# run id when no key prefix is given.
_key: str
# TTL of the rendezvous state (the supplied value or _DEFAULT_TTL).
_ttl: int
def __init__(
self,
client: EtcdClient,
run_id: str,
key_prefix: Optional[str] = None,
ttl: Optional[int] = None,
) -> None:
if not run_id:
raise ValueError("运行ID必须是一个非空字符串。")
self._client = client
if key_prefix:
self._key = key_prefix + "/" + run_id
else:
self._key = run_id
if ttl and ttl > 0:
self._ttl = ttl
else:
self._ttl = self._DEFAULT_TTL
@property
def name(self) -> str:
    """Return the identifier of this backend (see base class)."""
    backend_name = "etcd-v2"
    return backend_name
[docs] def get_state(self) -> Optional[Tuple[bytes, Token]]:
"""参见基类。"""
try:
result = self._client.read(self._key)
except EtcdKeyNotFound:
return None
except (EtcdException, urllib3.exceptions.TimeoutError) as exc:
raise RendezvousConnectionError(
"与etcd的连接失败。查看内部异常以获取详细信息。"
) from exc
return self._decode_state(result)
[docs] def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
"""参见基类。"""
base64_state = b64encode(state).decode()
kwargs = {}
def get_state():
result = self.get_state()
if result is not None:
tmp = *result, False
# Python 3.6不支持在return语句中解包元组。
return tmp
return None
if token:
try:
token = int(token)
except ValueError:
return get_state()
if token:
kwargs["prevIndex"] = token
else:
kwargs["prevExist"] = False
try:
result = self._client</span