
Asynchronous Saving with Distributed Checkpoint (DCP)

Created On: Jul 22, 2024 | Last Updated: Jul 22, 2024 | Last Verified: Nov 05, 2024

Authors: Lucas Pasqualin, Iris Zhang, Rodrigo Kumpera, Chien-Chin Huang

Checkpointing is often a bottleneck in the critical path for distributed training workloads, incurring larger and larger costs as both model and world sizes grow. One excellent strategy for offsetting this cost is to checkpoint in parallel, asynchronously. Below, we expand the save example from the Getting Started with Distributed Checkpoint Tutorial to show how this can be integrated quite easily with torch.distributed.checkpoint.async_save.

What you will learn
  • How to use DCP to generate checkpoints in parallel

  • Effective strategies to optimize performance

Prerequisites

Asynchronous Checkpointing Overview

Before getting started with asynchronous checkpointing, it's important to understand its differences and limitations as compared to synchronous checkpointing. Specifically:

  • Memory requirements - Asynchronous checkpointing works by first copying models into internal CPU buffers. This is helpful since it ensures model and optimizer weights are not changed while the model is still checkpointing, but it does raise CPU memory by a factor of checkpoint_size_per_rank x number_of_ranks. Additionally, users should take care to understand the memory constraints of their systems. Specifically, pinned memory implies the use of page-locked memory, which can be scarce compared to pageable memory.

  • Checkpoint Management - Since checkpointing is asynchronous, it is up to the user to manage concurrently running checkpoints. In general, users can employ their own management strategies by handling the future object returned from async_save. For most users, we recommend limiting checkpoints to one asynchronous request at a time, avoiding additional memory pressure per request. One such policy is sketched after this list.
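
As a concrete illustration of such a management policy, the sketch below keeps at most one asynchronous checkpoint in flight and simply skips a save while a previous request is still running. It reuses AppState, dcp, and CHECKPOINT_DIR from the full example that follows; num_steps, save_every, and train_one_step are hypothetical placeholders, and it assumes the future returned by async_save exposes the standard concurrent.futures interface (done()/result()).

# Illustrative management policy (not part of the DCP API): keep at most one
# asynchronous checkpoint in flight, skipping a save while one is still running.
checkpoint_future = None
for step in range(num_steps):                  # num_steps: hypothetical placeholder
    train_one_step(model, optimizer)           # train_one_step: hypothetical placeholder

    if step % save_every != 0:                 # save_every: hypothetical placeholder
        continue

    if checkpoint_future is not None and not checkpoint_future.done():
        continue  # previous checkpoint still staging/writing; skip this request

    state_dict = {"app": AppState(model, optimizer)}
    checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")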

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        # waits for checkpointing to finish if one exists, avoiding queuing more than one checkpoint request at a time
        if checkpoint_future is not None:
            checkpoint_future.result()

        state_dict = { "app": AppState(model, optimizer) }
        checkpoint_future = dcp.async_save(state_dict, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    # ensure the last checkpoint request finishes before tearing down the process group
    if checkpoint_future is not None:
        checkpoint_future.result()

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running async checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

Even more performance with Pinned Memory

If the above optimization is still not performant enough, you can take advantage of an additional optimization for GPU models which utilizes a pinned memory buffer for checkpoint staging. Specifically, this optimization attacks the main overhead of asynchronous checkpointing, which is the in-memory copying to checkpointing buffers. By maintaining a pinned memory buffer between checkpoint requests, users can take advantage of direct memory access to speed up this copy.

Note

The main drawback of this optimization is the persistence of the buffer in between checkpointing steps. Without the pinned memory optimization (as demonstrated above), any checkpointing buffers are released as soon as checkpointing is finished. With the pinned memory implementation, this buffer is maintained between steps, leading to the same peak memory pressure being sustained throughout the application's lifetime.
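
To see why page-locked host memory helps, the short, self-contained sketch below (independent of DCP, requiring a CUDA device; the tensor size is an illustrative assumption) compares staging a GPU tensor into a pageable CPU buffer versus a pre-allocated pinned buffer. On most systems the pinned copy is noticeably faster, and this device-to-host copy is exactly the staging cost that asynchronous checkpointing pays on the critical path.

import time

import torch

# Illustrative sketch (not part of the DCP example below): staging a GPU tensor
# into a persistent pinned (page-locked) CPU buffer enables a faster, DMA-driven copy.
device_tensor = torch.randn(64 * 1024 * 1024, device="cuda")  # ~256 MB of fp32 data (illustrative size)

pageable_buf = torch.empty(device_tensor.shape, dtype=device_tensor.dtype)                  # regular pageable CPU memory
pinned_buf = torch.empty(device_tensor.shape, dtype=device_tensor.dtype, pin_memory=True)   # page-locked CPU memory

def timed_copy(dst, src):
    torch.cuda.synchronize()
    start = time.perf_counter()
    dst.copy_(src, non_blocking=True)  # only truly asynchronous when dst is pinned
    torch.cuda.synchronize()           # wait for the copy so the timing is fair
    return (time.perf_counter() - start) * 1000

print(f"pageable copy: {timed_copy(pageable_buf, device_tensor):.1f} ms")
print(f"pinned copy:   {timed_copy(pinned_buf, device_tensor):.1f} ms")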

import os

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.checkpoint import FileSystemWriter as StorageWriter

CHECKPOINT_DIR = "checkpoint"


class AppState(Stateful):
    """This is a useful wrapper for checkpointing the Application State. Since this object is compliant
    with the Stateful protocol, DCP will automatically call state_dict/load_state_dict as needed in the
    dcp.save/load APIs.

    Note: We take advantage of this wrapper to handle calling distributed state dict methods on the model
    and optimizer.
    """

    def __init__(self, model, optimizer=None):
        self.model = model
        self.optimizer = optimizer

    def state_dict(self):
        # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict
        }

    def load_state_dict(self, state_dict):
        # sets our state dicts on the model and optimizer, now that we've loaded
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"]
        )

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(16, 16)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(16, 8)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def setup(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # initialize the process group
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)


def cleanup():
    dist.destroy_process_group()


def run_fsdp_checkpoint_save_example(rank, world_size):
    print(f"Running basic FSDP checkpoint saving example on rank {rank}.")
    setup(rank, world_size)

    # create a model and move it to GPU with id rank
    model = ToyModel().to(rank)
    model = FSDP(model)

    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

    # The storage writer defines our 'staging' strategy, where staging is considered the process of copying
    # checkpoints to in-memory buffers. By setting `cache_staged_state_dict=True`, we enable efficient memory copying
    # into a persistent buffer with pinned memory enabled.
    # Note: It's important that the writer persists in between checkpointing requests, since it maintains the
    # pinned memory buffer.
    writer = StorageWriter(cache_staged_state_dict=True, path=CHECKPOINT_DIR)
    checkpoint_future = None
    for step in range(10):
        optimizer.zero_grad()
        model(torch.rand(8, 16, device="cuda")).sum().backward()
        optimizer.step()

        state_dict = { "app": AppState(model, optimizer) }
        if checkpoint_future is not None:
            # waits for checkpointing to finish, avoiding queuing more than one checkpoint request at a time
            checkpoint_future.result()
        checkpoint_future = dcp.async_save(state_dict, storage_writer=writer, checkpoint_id=f"{CHECKPOINT_DIR}_step{step}")

    # ensure the last checkpoint request finishes before tearing down the process group
    if checkpoint_future is not None:
        checkpoint_future.result()

    cleanup()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    mp.spawn(
        run_fsdp_checkpoint_save_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )

Conclusion

In conclusion, we have learned how to use DCP's async_save() API to generate checkpoints off the critical training path. We have also learned about the additional memory and concurrency overhead introduced by using this API, as well as the additional optimization that utilizes pinned memory to speed things up even further.
