torch.distributed.rpc.functions 的源代码
import functools
[docs]def async_execution(fn):
r"""
用于指示函数返回值保证为 :class:`~torch.futures.Future` 对象的装饰器,并且该函数可以在 RPC 被调用方异步运行。更具体地说,被调用方提取由包装函数返回的 :class:`~torch.futures.Future` 并在该 :class:`~torch.futures.Future` 上安装后续处理步骤作为回调。安装的回调将在 :class:`~torch.futures.Future` 完成时读取其值,并将该值作为 RPC 响应发送回去。这也意味着返回的 :class:`~torch.futures.Future` 仅存在于被调用方,并且永远不会通过 RPC 发送。当包装函数 (``fn``) 的执行需要暂停和恢复时,例如由于包含 :meth:`~torch.distributed.rpc.rpc_async` 或等待其他信号,此装饰器非常有用。
.. 注意:: 要启用异步执行,应用程序必须将此装饰器返回的函数对象传递给 RPC API。如果 RPC 检测到此装饰器安装的属性,它将知道此函数返回一个 ``Future`` 对象并相应地处理它。然而,这并不意味着在定义函数时此装饰器必须是外层的。例如,当与 ``@staticmethod`` 或 ``@classmethod`` 结合使用时,``@rpc.functions.async_execution`` 需要是内部装饰器,以允许目标函数被识别为静态或类函数。此目标函数仍然可以异步执行,因为当访问时,静态或类方法保留了 ``@rpc.functions.async_execution`` 安装的属性。
示例::
返回的 :class:`~torch.futures.Future` 对象可以来自 :meth:`~torch.distributed.rpc.rpc_async`、:meth:`~torch.futures.Future.then` 或 :class:`~torch.futures.Future` 构造函数。下面的示例展示了直接使用由 :meth:`~torch.futures.Future.then` 返回的 :class:`~torch.futures.Future`。
>>> from torch.distributed import rpc
>>>
>>> # 省略设置和关闭 RPC
>>>
>>> # 在所有工作节点上
>>> @rpc.functions.async_execution
>>> def async_add_chained(to, x, y, z):
>>> # 此函数在 "worker1" 上运行,并在通过 `then(cb)` API 安装回调时立即返回。同时,`rpc_async` 到 "worker2" 可以并发运行。当 `rpc_async` 的返回值到达 "worker1" 时,"worker1" 将相应地运行 lambda 函数并设置之前返回的 `Future` 的值,这将触发 RPC 将结果发送回 "worker0"。
>>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>> lambda fut: fut.wait() + z
>>> )
>>>
>>> # 在 worker0 上
>>> # xdoctest: +SKIP
>>> ret = rpc.rpc_sync(
>>> "worker1",
>>> async_add_chained,
>>> args=("worker2", torch.ones(2), 1, 1)
>>> )
>>> print(ret) # 打印 tensor([3., 3.])
当与 TorchScript 装饰器结合使用时,此装饰器必须是外层的。
>>> from torch import Tensor
>>> from torch.futures import Future
>>> from torch.distributed import rpc
>>>
>>> # 省略设置和关闭 RPC
>>>
>>> # 在所有工作节点上
>>> @torch.jit.script
>>> def script_add(x: Tensor, y: Tensor) -> Tensor:
>>> return x + y
>>>
>>> @rpc.functions.async_execution
>>> @torch.jit.script
>>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
>>> return rpc.rpc_async(to, script_add, (x, y))
>>>
>>> # 在 worker0 上
>>> ret = rpc.rpc_sync(
>>> "worker1",
>>> async_add,
>>> args=("worker2", torch.ones(2), 1)
>>> )
>>> print(ret) # 打印 tensor([2., 2.])
当与静态或类方法结合使用时,此装饰器必须是内部的。
>>> from torch.distributed import rpc
>>>
>>> # 省略设置和关闭 RPC
>>>
>>> # 在所有工作节点上
>>> class AsyncExecutionClass:
>>>
>>> @staticmethod
>>> @rpc.functions.async_execution
>>> def static_async_add(to, x, y, z):
>>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>> lambda fut: fut.wait() + z
>>> )
>>>
>>> @classmethod
>>> @rpc.functions.async_execution
>>> def class_async_add(cls, to, x, y, z):
>>> ret_fut = torch.futures.Future()
>>> rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>> lambda fut: ret_fut.set_result(fut.wait() + z)
>>> )
>>> return ret_fut
>>>
>>> @rpc.functions.async_execution
>>> def bound_async_add(self, to, x, y, z):
>>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
>>> lambda fut: fut.wait() + z
>>> )
>>>
>>> # 在 worker0 上
>>> ret = rpc.rpc_sync(
>>> "worker1",
>>> AsyncExecutionClass.static_async_add,
>>> args=("worker2", torch.ones(2), 1, 2)
>>> )
>>> print(ret) # 打印 tensor([4., 4.])
>>>
>>> ret = rpc.rpc_sync(
>>> "worker1",
>>> AsyncExecutionClass.class_async_add,
>>> args=("worker2", torch.ones(2), 1, 2)
>>> )
>>> print(ret) # 打印 tensor([4., 4.])
此装饰器也适用于 RRef 助手,即 :meth:`torch.distributed.rpc.RRef.rpc_sync`、:meth:`torch.distributed.rpc.RRef.rpc_async` 和 :meth:`torch.distributed.rpc.RRef.remote`。
>>> from torch.distributed import rpc
>>>
>>> # 重用上面的 AsyncExecutionClass 类
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
>>> print(ret) # 打印 tensor([4., 4.])
>>>
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
>>> print(ret) # 打印 tensor([4., 4.])
>>>
>>> rref = rpc.remote("worker1", AsyncExecutionClass)
>>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
>>> print(ret) # 打印 tensor([4., 4.])
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)
# 无法声明和使用函数对象的属性 (mypy#2087)
wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined]
return wrapper