Shortcuts

torch.distributed.optim.optimizer 的源代码

import logging

from collections import defaultdict
from threading import Lock
from typing import List, Optional

import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.distributed.rpc import RRef
from .utils import functional_optim_map

__all__ = ["DistributedOptimizer"]

logger = logging.getLogger(__name__)


# XXX: 我们在这里定义了一个 _ScriptModuleOptimizer,以显式地
# 将 FunctionalOptimizer 类编译为 TorchScript
# 这是因为 ScriptClass 实例仍然存在于
# Python 中,除非你显式地将其编译为 ScriptModule 的属性
# 或在 ScriptFunction 中传递它
# _ScriptLocalOptimizerInterface 作为 Optimizer ScriptModules 的通用
# 接口类型。
#
# TODO (wanchaol): 一旦我们添加了 TorchScript
# 类引用语义,删除此内容
@jit.interface
class _ScriptLocalOptimizerInterface:
    def step(self, autograd_ctx_id: int) -> None:
        pass


class _ScriptLocalOptimizer(nn.Module):
    # TorchScript 不支持多线程并发编译。
    # request_callback 可能会调用并发编译,因此我们
    # 使用锁来序列化编译
    compile_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        super().__init__()
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    @jit.export
    def step(self, autograd_ctx_id: int):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
        # 使用功能优化器步骤应用梯度列表
        grads: List[Optional[Tensor]] = [
            all_local_grads[p] if p in all_local_grads else None
            for p in self._local_params
        ]

        self.optim.step(grads)


# TODO (wanchaol): 一旦我们在 distributed.optim 中将所有内容转换为功能优化器,删除/合并此内容
class _LocalOptimizer:
    # 理想情况下,我们只需要为处理相同参数的 _LocalOptimizer 实例共享一个锁。我们在这里
    # 做了一个简化的假设,即如果每个 worker 有多个 _LocalOptimizer 实例,它们将
    # 优化相同的参数(例如,每个数据并行训练器将创建自己的 _LocalOptimizer 实例,但
    # 它们将在每个 worker 上优化相同的参数)
    global_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    def step(self, autograd_ctx_id):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)

        with _LocalOptimizer.global_lock:
            for param, grad in all_local_grads.items():
                param.grad = grad
            self.optim.step()


def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))


def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)


# new/step 函数与 _ScriptLocalOptimizer 结合,提供无 GIL 的优化器
def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)

    with <span