torch.distributed.optim.post_localSGD_optimizer 的源代码
import warnings
import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers
[docs]class PostLocalSGDOptimizer(torch.optim.Optimizer):
r"""
包装任意 :class:`torch.optim.Optimizer` 并运行 `post-local SGD `_,
该优化器在每一步运行本地优化器。
在预热阶段之后,它在应用本地优化器之后定期平均参数。
参数:
optim: 本地优化器。
averager: 用于运行 post-localSGD 算法的模型平均器实例。
示例::
>>> # xdoctest: +SKIP("undefined variables")
>>> import torch
>>> import torch.distributed as dist
>>> import torch.distributed.algorithms.model_averaging.averagers as averagers
>>> import torch.nn as nn
>>> from torch.distributed.optim import PostLocalSGDOptimizer
>>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
>>> PostLocalSGDState,
>>> post_localSGD_hook,
>>> )
>>>
>>> model = nn.parallel.DistributedDataParallel(
>>> module, device_ids=[rank], output_device=rank
>>> )
>>>
>>> # 注册一个 post-localSGD 通信钩子。
>>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100)
>>> model.register_comm_hook(state, post_localSGD_hook)
>>>
>>> # 创建一个 post-localSGD 优化器,包装本地优化器。
>>> # 注意 ``PostLocalSGDOptimizer`` 中使用的 ``warmup_steps`` 必须与
>>> # ``PostLocalSGDState`` 中使用的 ``start_localSGD_iter`` 相同。
>>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01)
>>> opt = PostLocalSGDOptimizer(
>>> optim=local_optim,
>>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100)
>>> )
>>>
>>> # 在前 100 步中,DDP 在每一步运行全局梯度平均。
>>> # 在 100 步之后,DDP 在每个子组内(默认情况下是节点内)运行梯度平均,
>>> # 并且 post-localSGD 优化器在应用本地优化器之后每 4 步运行一次全局模型平均。
>>> for step in range(0, 200):
>>> opt.zero_grad()
>>> loss = loss_fn(output, labels)
>>> loss.backward()
>>> opt.step()
"""
def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager):
self.optim = optim
self.param_groups = self.optim.param_groups
self.averager = averager
@property
def state(self):
return self.optim.state
def __repr__(self):
return self.optim.__repr__()
[docs] def state_dict(self):
r"""
这与 :class:`torch.optim.Optimizer` 的 :meth:`state_dict` 相同,
但添加了一个额外的条目来记录模型平均器的步骤到检查点
以确保重新加载不会导致不必要的再次预热。
"""
optim_state_dict = self.optim.state_dict()
optim_state_dict["step"] = self.averager.step
return optim_state_dict
[docs] def load_state_dict(self, state_dict):
r"""
这与 :class:`torch.optim.Optimizer` 的 :meth:`load_state_dict` 相同,
但还恢复了模型平均器的步骤值到
提供的 ``state_dict`` 中保存的值。
如果 ``state_dict`` 中没有 ``"step"`` 条目,
它将发出警告并将模型平均器的步骤初始化为 0。
"""
self.optim.load_state_dict(state_dict)
if "step" in state_dict:
self.averager.step = state_dict["step"]
else:
warnings.warn(
"Loaded state dict does not contain a step counter for an averager. "
"Setting step counter to 0."
)
self.averager.step = 0
[docs] def step(self):
r"""
执行单个优化步骤(参数更新)。
"""
self.optim.step()
self.averager.average_parameters(params=self.param_groups)
def zero_grad(self, set_to_none: bool = True): # type: ignore[override]
self.optim.zero_grad(set_to_none=set_to_none)
def add_param_group(self, param_group):
self.optim.add_param_group(param_group)