Shard Optimizer States with ZeroRedundancyOptimizer

Created On: Feb 26, 2021 | Last Updated: Oct 20, 2021 | Last Verified: Not Verified

In this tutorial, you will learn what ZeroRedundancyOptimizer is and how to use it in distributed training.

What is ZeroRedundancyOptimizer?

The idea of ZeroRedundancyOptimizer comes from the DeepSpeed/ZeRO project and Marian, which shard optimizer states across distributed data-parallel processes to reduce the per-process memory footprint. The Getting Started With Distributed Data Parallel tutorial shows how to use DistributedDataParallel (DDP) to train models. In that tutorial, each process keeps a dedicated replica of the optimizer. Because DDP already synchronizes gradients in the backward pass, all optimizer replicas operate on the same parameter and gradient values in every iteration, and this is how DDP keeps model replicas in the same state.

Optimizers often maintain local states as well. For example, the Adam optimizer uses per-parameter exp_avg and exp_avg_sq states. As a result, the Adam optimizer's memory consumption is at least twice the model size. Given this observation, we can reduce the optimizer memory footprint by sharding optimizer states across DDP processes. More specifically, instead of creating per-parameter states for all parameters, the optimizer instance in each DDP process keeps optimizer states only for its own shard of the model parameters. The optimizer step() function updates only the parameters in its shard and then broadcasts the updated parameters to all peer DDP processes, so that all model replicas still land in the same state.
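The shard-update-broadcast cycle described above can be illustrated with a toy, framework-free simulation. This is only a sketch under simplifying assumptions: parameters are plain Python floats, momentum SGD stands in for Adam, the "broadcast" is a loop over in-memory lists, and the helper names partition and sharded_step are hypothetical. It is not how ZeroRedundancyOptimizer is implemented internally.

```python
# Toy simulation of optimizer state sharding (pure Python, no torch).
# partition() and sharded_step() are hypothetical helpers for illustration,
# not part of torch.distributed.optim.

def partition(param_sizes, world_size):
    """Greedily assign each parameter index to the currently lightest rank."""
    shards = [[] for _ in range(world_size)]
    loads = [0] * world_size
    # Visit larger parameters first so the shards end up roughly balanced.
    for idx in sorted(range(len(param_sizes)), key=lambda i: -param_sizes[i]):
        rank = loads.index(min(loads))
        shards[rank].append(idx)
        loads[rank] += param_sizes[idx]
    return [sorted(s) for s in shards], loads


def sharded_step(replicas, grads, states, shards, lr=0.1, beta=0.9):
    """One momentum-SGD step in which rank r updates only the params in shards[r]."""
    for rank, owned in enumerate(shards):
        for idx in owned:
            # Only the owning rank holds (and updates) the state for idx.
            momentum = beta * states[rank].get(idx, 0.0) + grads[idx]
            states[rank][idx] = momentum
            new_value = replicas[rank][idx] - lr * momentum
            # Simulated broadcast: every replica receives the owner's value.
            for replica in replicas:
                replica[idx] = new_value


if __name__ == "__main__":
    sizes = [4, 1, 3, 2]                 # element counts of four toy parameters
    shards, loads = partition(sizes, world_size=2)
    replicas = [[1.0, 2.0, 3.0, 4.0] for _ in range(2)]  # one model copy per rank
    grads = [0.5, -1.0, 0.25, 2.0]       # identical on all ranks after all-reduce
    states = [{} for _ in range(2)]      # per-rank optimizer state
    sharded_step(replicas, grads, states, shards)
    print("shards:", shards, "loads:", loads)
    print("replicas in sync:", replicas[0] == replicas[1])
    print("state entries per rank:", [len(s) for s in states])
```

Each simulated rank ends up holding state for only its own shard, yet every replica sees the same parameter values after the step, which is the property the paragraph above relies on.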

How to use ZeroRedundancyOptimizer

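The following code demonstrates how to use ZeroRedundancyOptimizer. Most of the code is similar to the simple DDP example presented in the Distributed Data Parallel notes. The main difference is the if-else clause in the example function, which wraps the optimizer construction and toggles between ZeroRedundancyOptimizer and the Adam optimizer.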

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

def print_peak_memory(prefix, device):
    if device == 0:
        print(f"{prefix}: {torch.cuda.max_memory_allocated(device) // 1e6}MB ")

def example(rank, world_size, use_zero):
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    # create default process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # create local model
    model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)])
    print_peak_memory("Max memory allocated after creating local model", rank)

    # construct DDP model
    ddp_model = DDP(model, device_ids=[rank])
    print_peak_memory("Max memory allocated after creating DDP", rank)

    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    if use_zero:
        optimizer = ZeroRedundancyOptimizer(
            ddp_model.parameters(),
            optimizer_class=torch.optim.Adam,
            lr=0.01
        )
    else:
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)

    # forward pass
    outputs = ddp_model(torch.randn(20, 2000).to(rank))
    labels = torch.randn(20, 2000).to(rank)
    # backward pass
    loss_fn(outputs, labels).backward()

    # update parameters
    print_peak_memory("Max memory allocated before optimizer step()", rank)
    optimizer.step()
    print_peak_memory("Max memory allocated after optimizer step()", rank)

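    # Print a checksum of the parameters so the two runs can be compared.
    # Python's sum() adds the tensors elementwise (biases broadcast against
    # the weight matrices), so this is a fingerprint rather than the plain
    # sum of all parameter elements.
    print(f"params sum is: {sum(model.parameters()).sum()}")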



def main():
    world_size = 2
    print("=== Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, True),
        nprocs=world_size,
        join=True)

    print("=== Not Using ZeroRedundancyOptimizer ===")
    mp.spawn(example,
        args=(world_size, False),
        nprocs=world_size,
        join=True)

if __name__=="__main__":
    main()

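The output is shown below. When ZeroRedundancyOptimizer is enabled with Adam, the optimizer step() raises peak memory by about 369 MB (from 992 MB to 1361 MB), roughly half of the 705 MB added by vanilla Adam (from 992 MB to 1697 MB). This agrees with our expectation, as we are sharding the Adam optimizer states across two processes. The output also shows that, with ZeroRedundancyOptimizer, the model parameters still end up with the same values after one iteration (the parameter sum is the same with and without ZeroRedundancyOptimizer).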

=== Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1361.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
=== Not Using ZeroRedundancyOptimizer ===
Max memory allocated after creating local model: 335.0MB
Max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 992.0MB
Max memory allocated after optimizer step(): 1697.0MB
params sum is: -3453.6123046875
params sum is: -3453.6123046875
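The figures above can be sanity-checked with a rough back-of-the-envelope estimate. The sketch below assumes float32 parameters (4 bytes each) and counts only Adam's two per-parameter buffers, exp_avg and exp_avg_sq. It ignores temporary buffers and allocator overhead, so it is expected to land near the measured numbers rather than match them exactly.

```python
# Rough memory estimate for the example model: 20 x nn.Linear(2000, 2000).
# Assumption: float32 parameters and states, 4 bytes per element.
BYTES_PER_FLOAT32 = 4
NUM_LAYERS, IN_FEATURES, OUT_FEATURES = 20, 2000, 2000
WORLD_SIZE = 2

# Each Linear layer has a weight matrix plus a bias vector.
params_per_layer = IN_FEATURES * OUT_FEATURES + OUT_FEATURES
total_params = NUM_LAYERS * params_per_layer

model_mb = total_params * BYTES_PER_FLOAT32 / 1e6
adam_state_mb = 2 * model_mb                 # exp_avg + exp_avg_sq
sharded_state_mb = adam_state_mb / WORLD_SIZE

# Peak-memory growth across optimizer.step(), taken from the printed output.
measured_full_mb = 1697 - 992
measured_sharded_mb = 1361 - 992

print(f"parameters:                {total_params:,}")
print(f"model size:                {model_mb:.2f} MB")
print(f"Adam states, unsharded:    {adam_state_mb:.2f} MB")
print(f"Adam states, per rank:     {sharded_state_mb:.2f} MB")
print(f"measured step() growth:    {measured_full_mb} MB vs {measured_sharded_mb} MB")
print(f"measured saving per rank:  {measured_full_mb - measured_sharded_mb} MB")
```

Under these assumptions the estimated per-rank saving (about 320 MB) is close to the measured difference of 336 MB between the two runs, and the estimated model size is close to the 335 MB reported after creating the local model.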