Shortcuts

torch.distributed.optim.post_localSGD_optimizer 的源代码

import warnings

import torch
import torch.distributed.algorithms.model_averaging.averagers as averagers


[docs]class PostLocalSGDOptimizer(torch.optim.Optimizer): r""" 包装任意 :class:`torch.optim.Optimizer` 并运行 `post-local SGD `_, 该优化器在每一步运行本地优化器。 在预热阶段之后,它在应用本地优化器之后定期平均参数。 参数: optim: 本地优化器。 averager: 用于运行 post-localSGD 算法的模型平均器实例。 示例:: >>> # xdoctest: +SKIP("undefined variables") >>> import torch >>> import torch.distributed as dist >>> import torch.distributed.algorithms.model_averaging.averagers as averagers >>> import torch.nn as nn >>> from torch.distributed.optim import PostLocalSGDOptimizer >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( >>> PostLocalSGDState, >>> post_localSGD_hook, >>> ) >>> >>> model = nn.parallel.DistributedDataParallel( >>> module, device_ids=[rank], output_device=rank >>> ) >>> >>> # 注册一个 post-localSGD 通信钩子。 >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) >>> model.register_comm_hook(state, post_localSGD_hook) >>> >>> # 创建一个 post-localSGD 优化器,包装本地优化器。 >>> # 注意 ``PostLocalSGDOptimizer`` 中使用的 ``warmup_steps`` 必须与 >>> # ``PostLocalSGDState`` 中使用的 ``start_localSGD_iter`` 相同。 >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01) >>> opt = PostLocalSGDOptimizer( >>> optim=local_optim, >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100) >>> ) >>> >>> # 在前 100 步中,DDP 在每一步运行全局梯度平均。 >>> # 在 100 步之后,DDP 在每个子组内(默认情况下是节点内)运行梯度平均, >>> # 并且 post-localSGD 优化器在应用本地优化器之后每 4 步运行一次全局模型平均。 >>> for step in range(0, 200): >>> opt.zero_grad() >>> loss = loss_fn(output, labels) >>> loss.backward() >>> opt.step() """ def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager): self.optim = optim self.param_groups = self.optim.param_groups self.averager = averager @property def state(self): return self.optim.state def __repr__(self): return self.optim.__repr__()
[docs] def state_dict(self): r""" 这与 :class:`torch.optim.Optimizer` 的 :meth:`state_dict` 相同, 但添加了一个额外的条目来记录模型平均器的步骤到检查点 以确保重新加载不会导致不必要的再次预热。 """ optim_state_dict = self.optim.state_dict() optim_state_dict["step"] = self.averager.step return optim_state_dict
[docs] def load_state_dict(self, state_dict): r""" 这与 :class:`torch.optim.Optimizer` 的 :meth:`load_state_dict` 相同, 但还恢复了模型平均器的步骤值到 提供的 ``state_dict`` 中保存的值。 如果 ``state_dict`` 中没有 ``"step"`` 条目, 它将发出警告并将模型平均器的步骤初始化为 0。 """ self.optim.load_state_dict(state_dict) if "step" in state_dict: self.averager.step = state_dict["step"] else: warnings.warn( "Loaded state dict does not contain a step counter for an averager. " "Setting step counter to 0." ) self.averager.step = 0
[docs] def step(self): r""" 执行单个优化步骤(参数更新)。 """ self.optim.step() self.averager.average_parameters(params=self.param_groups)
def zero_grad(self, set_to_none: bool = True): # type: ignore[override] self.optim.zero_grad(set_to_none=set_to_none) def add_param_group(self, param_group): self.optim.add_param_group(param_group)