
Source code for torch.distributed.autograd


import sys
import torch


# Distributed autograd is available only when the C++ extension was built with
# distributed support and exposes the _dist_autograd_init hook.
def is_available():
    return hasattr(torch._C, "_dist_autograd_init")


# Eagerly initialize the C++ distributed autograd module; if the hook exists
# but initialization fails, surface the error at import time.
if is_available() and not torch._C._dist_autograd_init():
    raise RuntimeError("Failed to initialize torch.distributed.autograd")

# Re-export the Python-facing API from the C++ extension module.
if is_available():
    from torch._C._distributed_autograd import (
        get_gradients,
        backward,
        _init,
        _new_context,
        _release_context,
        _get_max_id,
        _is_valid_context,
        _retrieve_context,
        _current_context,
        _get_debug_info,
        DistAutogradContext,
    )


class context:
    '''
    Context object to wrap forward and backward passes when using
    distributed autograd. The ``context_id`` generated in the ``with``
    statement is required to uniquely identify a distributed backward pass
    on all workers. Each worker stores metadata associated with this
    ``context_id``, which is required to correctly execute a distributed
    autograd pass.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch.distributed.autograd as dist_autograd
        >>> with dist_autograd.context() as context_id:
        >>>     t1 = torch.rand((3, 3), requires_grad=True)
        >>>     t2 = torch.rand((3, 3), requires_grad=True)
        >>>     loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
        >>>     dist_autograd.backward(context_id, [loss])
    '''

    def __enter__(self):
        self.autograd_context = _new_context()
        return self.autograd_context._context_id()

    def __exit__(self, type, value, traceback):
        _release_context(self.autograd_context._context_id())
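For reference, the following is a minimal sketch (not part of the module source above) of how these pieces are typically combined on one worker: the context manager yields a context_id, backward() runs the distributed backward pass keyed by that id, and get_gradients() retrieves the gradients accumulated in that context rather than in each tensor's .grad field. The worker names, ranks, world size, and the MASTER_ADDR/MASTER_PORT rendezvous settings are illustrative assumptions; a peer process named "worker1" is assumed to have called rpc.init_rpc as well.

import os

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc

# Assumed rendezvous settings for a local two-worker example.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")

# This process plays "worker0"; a separate process is assumed to have
# started rpc.init_rpc("worker1", rank=1, world_size=2).
rpc.init_rpc("worker0", rank=0, world_size=2)

with dist_autograd.context() as context_id:
    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    # The RPC layer records autograd edges for this call under context_id
    # on both workers.
    loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()
    # Run the distributed backward pass identified by context_id.
    dist_autograd.backward(context_id, [loss])
    # Gradients are stored per context; look them up by tensor.
    grads = dist_autograd.get_gradients(context_id)
    print(grads[t1], grads[t2])

rpc.shutdown()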