Shortcuts

开始使用 CommDebugMode

创建于:2024年8月19日 | 最后更新:2024年10月8日 | 最后验证:2024年11月5日

作者: Anshul Sinha

在本教程中,我们将探讨如何使用CommDebugMode与PyTorch的DistributedTensor(DTensor)进行调试,通过跟踪分布式训练环境中的集体操作。

先决条件

  • Python 3.8 - 3.11

  • PyTorch 2.2 或更高版本

什么是CommDebugMode以及它为什么有用

随着模型规模的不断增加,用户正在寻求利用各种并行策略组合来扩展分布式训练。然而,现有解决方案之间缺乏互操作性,这带来了重大挑战,主要是由于缺乏能够桥接这些不同并行策略的统一抽象。为了解决这个问题,PyTorch提出了DistributedTensor(DTensor),它抽象了分布式训练中张量通信的复杂性,提供了无缝的用户体验。然而,在处理现有并行解决方案和使用DTensor等统一抽象开发并行解决方案时,缺乏对底层集体通信发生的内容和时间的透明度,可能会使高级用户难以识别和解决问题。为了解决这一挑战,CommDebugMode,一个Python上下文管理器,将作为DTensor的主要调试工具之一,使用户能够在使用DTensor时查看集体操作发生的时间和原因,从而有效解决这一问题。

使用 CommDebugMode

以下是您如何使用CommDebugMode的方法:

# The model used in this example is a MLPModule applying Tensor Parallel
comm_mode = CommDebugMode()
    with comm_mode:
        output = model(inp)

# print the operation level collective tracing information
print(comm_mode.generate_comm_debug_tracing_table(noise_level=0))

# log the operation level collective tracing information to a file
comm_mode.log_comm_debug_tracing_table_to_file(
    noise_level=1, file_name="transformer_operation_log.txt"
)

# dump the operation level collective tracing information to json file,
# used in the visual browser below
comm_mode.generate_json_dump(noise_level=2)

这是在噪声级别为0时MLPModule的输出结果:

Expected Output:
    Global
      FORWARD PASS
        *c10d_functional.all_reduce: 1
        MLPModule
          FORWARD PASS
            *c10d_functional.all_reduce: 1
            MLPModule.net1
            MLPModule.relu
            MLPModule.net2
              FORWARD PASS
                *c10d_functional.all_reduce: 1

要使用CommDebugMode,您必须将运行模型的代码包装在CommDebugMode中,并调用您想要使用的API来显示数据。您还可以使用noise_level参数来控制显示信息的详细程度。以下是每个噪声级别显示的内容:

0. Prints module-level collective counts
1. Prints DTensor operations (not including trivial operations), module sharding information
2. Prints tensor operations (not including trivial operations)
3. Prints all operations

在上面的例子中,你可以看到集体操作 all_reduce 在 MLPModule 的前向传递中发生了一次。此外,你可以使用 CommDebugMode 来精确定位 all-reduce 操作发生在 MLPModule 的第二个线性层中。

以下是交互式模块树可视化,您可以使用它来上传自己的JSON转储:

CommDebugMode Module Tree
Drag file here

结论

在本教程中,我们学习了如何使用CommDebugMode来调试使用PyTorch的通信集合的分布式张量和并行解决方案。您可以在嵌入式可视化浏览器中使用自己的JSON输出。

有关CommDebugMode的更多详细信息,请参阅 comm_mode_features_example.py

优云智算