Source code for torch.distributed.optim.optimizer
import logging
from collections import defaultdict
from threading import Lock
from typing import List, Optional
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.jit as jit
import torch.nn as nn
from torch import Tensor
from torch.distributed.rpc import RRef
from .utils import functional_optim_map
__all__ = ["DistributedOptimizer"]
logger = logging.getLogger(__name__)
# XXX: we define a _ScriptModuleOptimizer here to explicitly
# compile the FunctionalOptimizer class into TorchScript.
# This is because a ScriptClass instance still lives in
# Python unless you explicitly compile it as an attribute
# of a ScriptModule or pass it to a ScriptFunction.
# _ScriptLocalOptimizerInterface serves as the common
# interface type for Optimizer ScriptModules.
#
# TODO (wanchaol): remove this once we add TorchScript
# class reference semantics
@jit.interface
class _ScriptLocalOptimizerInterface:
    def step(self, autograd_ctx_id: int) -> None:
        pass
class _ScriptLocalOptimizer(nn.Module):
    # TorchScript does not support multithreaded concurrent compilation.
    # request_callback might invoke concurrent compilation, so we
    # serialize compilation with a lock.
    compile_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        super().__init__()
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    @jit.export
    def step(self, autograd_ctx_id: int):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)
        # apply the functional optimizer step with a list of gradients
        grads: List[Optional[Tensor]] = [
            all_local_grads[p] if p in all_local_grads else None
            for p in self._local_params
        ]

        self.optim.step(grads)
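# Illustrative sketch (not part of the original module): `self.optim` above is
# expected to be a "functional" optimizer (see `functional_optim_map`), i.e. its
# step() takes the gradients explicitly instead of reading them from param.grad.
# A minimal, hypothetical SGD-like version of that interface, reusing the
# module-level imports above, could look like this:
class _SketchFunctionalSGD:  # hypothetical name, for illustration only
    def __init__(self, params: List[Tensor], lr: float = 0.01):
        self.params = params
        self.lr = lr

    def step(self, gradients: List[Optional[Tensor]]) -> None:
        # gradients[i] lines up with self.params[i]; None means that parameter
        # received no gradient under the current autograd context.
        with torch.no_grad():
            for p, g in zip(self.params, gradients):
                if g is not None:
                    p.add_(g, alpha=-self.lr)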
# TODO (wanchaol): remove/merge this once everything in distributed.optim
# has been converted to functional optimizers
class _LocalOptimizer:
    # Ideally we would only need to share a lock among _LocalOptimizer
    # instances that handle the same parameters. We make the simplifying
    # assumption here that if there is more than one _LocalOptimizer
    # instance per worker, they all optimize the same parameters (e.g.
    # each data-parallel trainer creates its own _LocalOptimizer instance,
    # but they all optimize the same parameters on each worker).
    global_lock = Lock()

    def __init__(self, optim_cls, local_params_rref, *args, **kwargs):
        self._local_params = [rref.local_value() for rref in local_params_rref]
        self.optim = optim_cls(self._local_params, *args, **kwargs)

    def step(self, autograd_ctx_id):
        all_local_grads = dist_autograd.get_gradients(autograd_ctx_id)

        with _LocalOptimizer.global_lock:
            for param, grad in all_local_grads.items():
                param.grad = grad
            self.optim.step()
def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs))
def _local_optimizer_step(local_optim_rref, autograd_ctx_id):
    local_optim = local_optim_rref.local_value()
    local_optim.step(autograd_ctx_id)
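# Hedged sketch (not part of the original module): the two helpers above are
# module-level functions so that they can serve as RPC targets; the
# DistributedOptimizer defined later in this file drives them roughly along the
# lines of the hypothetical function below (`owner_worker`, `param_rrefs`, and
# `context_id` are illustrative names only).
def _sketch_remote_sgd_step(owner_worker: str, param_rrefs: List[RRef], context_id: int) -> None:
    # Create the optimizer on the worker that owns the parameters ...
    optim_rref = rpc.rpc_sync(
        owner_worker,
        _new_local_optimizer,
        args=(torch.optim.SGD, param_rrefs),
        kwargs={"lr": 0.05},
    )
    # ... then ask that worker to apply the gradients recorded under the given
    # distributed autograd context.
    rpc.rpc_sync(owner_worker, _local_optimizer_step, args=(optim_rref, context_id))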
# new/step functions combined with _ScriptLocalOptimizer to provide a GIL-free optimizer
def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs):
    optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)

    with _ScriptLocalOptimizer.compile_lock:
        script_optim = jit.script(optim)
        return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface)
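# Hedged end-to-end usage sketch, adapted from the documented DistributedOptimizer
# example (not part of the original module). It assumes RPC has already been
# initialized and that a worker named "worker1" owns the remote parameters:
#
#     import torch
#     import torch.distributed.autograd as dist_autograd
#     import torch.distributed.rpc as rpc
#     from torch import optim
#     from torch.distributed.optim import DistributedOptimizer
#
#     with dist_autograd.context() as context_id:
#         # Forward pass.
#         rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
#         rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
#         loss = rref1.to_here() + rref2.to_here()
#
#         # Backward pass under the distributed autograd context.
#         dist_autograd.backward(context_id, [loss.sum()])
#
#         # DistributedOptimizer creates one local optimizer per parameter-owning
#         # worker (using the helpers above) and step() fans out over RPC.
#         dist_optim = DistributedOptimizer(optim.SGD, [rref1, rref2], lr=0.05)
#         dist_optim.step(context_id)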