Shortcuts

# Source code for torch.distributed.elastic.rendezvous.etcd_rendezvous_backend

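# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.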

import binascii
from base64 import b64decode, b64encode
from typing import Optional, Tuple, cast

import urllib3.exceptions  # type: ignore[import]
from etcd import Client as EtcdClient  # type: ignore[import]
from etcd import (
    EtcdAlreadyExist,
    EtcdCompareFailed,
    EtcdException,
    EtcdKeyNotFound,
    EtcdResult,
)
from torch.distributed import Store

from .api import RendezvousConnectionError, RendezvousParameters, RendezvousStateError
from .dynamic_rendezvous import RendezvousBackend, Token
from .etcd_store import EtcdStore
from .utils import parse_rendezvous_endpoint


class EtcdRendezvousBackend(RendezvousBackend):
    """An etcd-backed rendezvous backend.

    The rendezvous state is read from an etcd key derived from
    ``key_prefix`` and ``run_id``.

    Args:
        client:
            The ``etcd.Client`` instance used to communicate with etcd.
        run_id:
            The run id of the rendezvous; must be a non-empty string.
        key_prefix:
            Optional path under which the rendezvous state is stored in etcd.
            When given, the key is ``<key_prefix>/<run_id>``; otherwise the
            key is just ``<run_id>``.
        ttl:
            The TTL of the rendezvous state. If omitted, zero, or negative,
            it defaults to two hours (``_DEFAULT_TTL``).
    """

    _DEFAULT_TTL = 7200  # 2 hours

    _client: EtcdClient
    _key: str
    _ttl: int

    def __init__(
        self,
        client: EtcdClient,
        run_id: str,
        key_prefix: Optional[str] = None,
        ttl: Optional[int] = None,
    ) -> None:
        # The run id forms (the tail of) the etcd key, so it cannot be empty.
        if not run_id:
            raise ValueError("运行ID必须是一个非空字符串。")

        self._client = client

        # Only namespace the key under a prefix when a non-empty one is given.
        self._key = (key_prefix + "/" + run_id) if key_prefix else run_id

        # A missing, zero, or negative TTL falls back to the class default.
        self._ttl = ttl if (ttl and ttl > 0) else self._DEFAULT_TTL

    @property
    def name(self) -> str:
        """See base class."""
        return "etcd-v2"

    def get_state(self) -> Optional[Tuple[bytes, Token]]:
        """See base class.

        Returns:
            ``None`` if the key does not exist in etcd; otherwise whatever
            ``self._decode_state`` produces from the read result.

        Raises:
            RendezvousConnectionError: If reading the key raises an
                ``EtcdException`` (other than ``EtcdKeyNotFound``, which is
                handled first) or a ``urllib3.exceptions.TimeoutError``.
        """
        try:
            etcd_result = self._client.read(self._key)
        except EtcdKeyNotFound:
            # The key is absent, so there is no state to return.
            return None
        except (EtcdException, urllib3.exceptions.TimeoutError) as err:
            raise RendezvousConnectionError(
                "与etcd的连接失败。查看内部异常以获取详细信息。"
            ) from err
        else:
            # Decoding happens outside the guarded region, so its errors are
            # not translated into RendezvousConnectionError.
            return self._decode_state(etcd_result)
[docs] def set_state( self, state: bytes, token: Optional[Token] = None ) -> Optional[Tuple[bytes, Token, bool]]: """参见基类。""" base64_state = b64encode(state).decode() kwargs = {} def get_state(): result = self.get_state() if result is not None: tmp = *result, False # Python 3.6不支持在return语句中解包元组。 return tmp return None if token: try: token = int(token) except ValueError: return get_state() if token: kwargs["prevIndex"] = token else: kwargs["prevExist"] = False try: result = self._client</span
优云智算