Shard Optimizer States with ZeroRedundancyOptimizer¶
Created On: Feb 26, 2021 | Last Updated: Oct 20, 2021 | Last Verified: Not Verified
In this tutorial, you will learn:
The high-level idea of ZeroRedundancyOptimizer.
How to use ZeroRedundancyOptimizer in distributed training and its impact.
Requirements¶
PyTorch 1.8+
What is ZeroRedundancyOptimizer?¶
The idea of ZeroRedundancyOptimizer comes from the DeepSpeed/ZeRO project and Marian, which shard optimizer states across distributed data-parallel processes to reduce the per-process memory footprint. In the Getting Started With Distributed Data Parallel tutorial, we have shown how to use DistributedDataParallel (DDP) to train models. In that tutorial, each process keeps a dedicated replica of the optimizer. Since DDP has already synchronized gradients in the backward pass, all optimizer replicas will operate on the same parameter and gradient values in every iteration, and this is how DDP keeps model replicas in the same state. Oftentimes, optimizers also maintain local states. For example, the Adam optimizer uses per-parameter exp_avg and exp_avg_sq states. As a result, the Adam optimizer's memory consumption is at least twice the model size. Given this observation, we can reduce the optimizer memory footprint by sharding optimizer states across DDP processes. More specifically, instead of creating per-param states for all parameters, each optimizer instance in different DDP processes only keeps optimizer states for a shard of all model parameters. The optimizer step() function only updates the parameters in its shard and then broadcasts its updated parameters to all other peer DDP processes, so that all model replicas still land in the same state.
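To make the "at least twice the model size" observation concrete, the following minimal sketch (added for illustration, not part of the original tutorial; the layer and batch sizes are arbitrary) inspects the state that torch.optim.Adam keeps after one step:
import torch
import torch.nn as nn

# Adam allocates its per-parameter exp_avg and exp_avg_sq tensors lazily,
# on the first call to step().
model = nn.Linear(2000, 2000)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model(torch.randn(8, 2000)).sum().backward()
optimizer.step()

param_numel = sum(p.numel() for p in model.parameters())
state_numel = sum(
    t.numel()
    for state in optimizer.state.values()
    for t in state.values()
    if torch.is_tensor(t)
)
# exp_avg and exp_avg_sq are each parameter-sized, so the optimizer state
# holds roughly twice as many elements as the model itself.
print(f"parameter elements: {param_numel}, optimizer state elements: {state_numel}")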
How to use ZeroRedundancyOptimizer?¶
The code below demonstrates how to use ZeroRedundancyOptimizer. The majority of the code is similar to the simple DDP example shown in the Distributed Data Parallel notes. The main difference is the if-else clause in the example function, which wraps the optimizer construction and switches between ZeroRedundancyOptimizer and Adam.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

    print(f"params sum is: {sum(model.parameters()).sum()}")

def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()
The output is shown below. When enabling ZeroRedundancyOptimizer with Adam, the peak memory consumption of the optimizer step() is half of vanilla Adam's memory consumption. This agrees with our expectation, as we are sharding the Adam optimizer states across two processes. The output also shows that, with ZeroRedundancyOptimizer, the model parameters still end up with the same values after one iteration (the parameter sums are the same with and without ZeroRedundancyOptimizer).
=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
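As a rough sanity check of these numbers (an estimate added for illustration, not part of the original tutorial), the optimizer-state footprint can be derived from the model definition. The estimates will not match the measured peaks exactly because of allocator granularity and temporary buffers:
# 20 x nn.Linear(2000, 2000) in fp32 (4 bytes per element)
num_params = 20 * (2000 * 2000 + 2000)   # ~80.04M parameters
param_mb = num_params * 4 / 1e6          # ~320 MB, close to the ~335 MB peak after model creation
adam_state_mb = 2 * param_mb             # exp_avg + exp_avg_sq: ~640 MB of unsharded Adam state
sharded_state_mb = adam_state_mb / 2     # ~320 MB per rank when sharded across 2 processes
print(param_mb, adam_state_mb, sharded_state_mb)
These estimates line up with the measured deltas above: step() adds roughly 705 MB (1697 − 992) without sharding and roughly 369 MB (1361 − 992) with ZeroRedundancyOptimizer.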