torch.distributed.elastic.rendezvous.c10d_rendezvous_backend 的源代码
# 版权所有 (c) Facebook, Inc. 及其附属公司。
# 保留所有权利。
#
# 本源代码根据在源代码树的根目录中的LICENSE文件中找到的BSD风格许可证进行许可。
import binascii
import logging
import os
import tempfile
from base64 import b64decode, b64encode
from datetime import timedelta
from typing import Any, Optional, Tuple, cast
from torch.distributed import FileStore, Store, TCPStore
from torch.distributed.elastic.events import (
NodeState,
construct_and_record_rdzv_event,
)
from .api import (
RendezvousConnectionError,
RendezvousError,
RendezvousParameters,
RendezvousStateError,
)
from .dynamic_rendezvous import RendezvousBackend, Token
from .utils import _matches_machine_hostname, parse_rendezvous_endpoint
log = logging.getLogger(__name__)
[docs]class C10dRendezvousBackend(RendezvousBackend):
"""表示一个基于C10d的rendezvous后端。
参数:
store:
用于与C10d存储通信的:py:class:`torch.distributed.Store`实例。
run_id:
rendezvous的运行ID。
"""
# 参见__init__方法中的解释。
_NULL_SENTINEL = "Y2FuaW1hZGFt"
_store: Store
_key: str
def __init__(self, store: Store, run_id: str) -> None:
if not run_id:
raise ValueError("运行ID必须是一个非空字符串。")
self._store = store
self._key = "torch.rendezvous." + run_id
# 存储的读操作会阻塞调用者,直到指定的键可用。这种行为使得存储难以用作常规的键值字典。
#
# 作为一种解决方法,我们最初将一个哨兵值设置为rendezvous状态。每当返回此值时,我们将其视为None。
self._call_store("compare_set", self._key, "", self._NULL_SENTINEL)
@property
def name(self) -> str:
"""参见基类。"""
return "c10d"
[docs] def get_state(self) -> Optional[Tuple[bytes, Token]]:
"""参见基类。"""
base64_state: bytes = self._call_store("get", self._key)
return self._decode_state(base64_state)
[docs] def set_state(
self, state: bytes, token: Optional[Token] = None
) -> Optional[Tuple[bytes, Token, bool]]:
"""参见基类。"""
base64_state_str: str = b64encode(state).decode()
if token:
# 如果我们确定令牌肯定无效,则使用快捷方式。
if not isinstance(token, bytes):
result = self.get_state()
if result is not None:
tmp = *result, False
# Python 3.6不支持在return语句中解包元组。
return tmp
return None
token = token.decode()
else:
token = self._NULL_SENTINEL
base64_state: bytes = self._call_store("compare_set", self._key, token, base64_state_str)
state_token_pair = self._decode_state(base64_state)
if state_token_pair is None:
return None
new_state, new_token = state_token_pair
# C10d Store的compare_set方法没有提供一种简单的方法来确定我们的写入尝试是否成功。作为一种暴力解决方案,我们对本地状态和远程状态进行逐位比较。
return new_state, new_token, new_state == state
def _call_store(self, store_op: str, *args, **kwargs) ->