FSDP 笔记¶
FSDP 预取细节¶
对于与前向
计算重叠的前向
全收集操作,有两种可能的机制:
隐式前向预取(始终启用)
显式正向预取(
forward_prefetch=True
)
隐式 forward
预取指的是依赖于从单独的 CUDA 流发出 all-gathers,以允许 all-gather 与之前发出的 forward
计算重叠(从 CPU 的角度来看)。例如,如果我们有层 0 all-gather -> 层 0 forward
计算 -> 层 1 all-gather -> …,那么层 1 all-gather 可以与层 0 forward
计算重叠,即使 CPU 线程在之后发出它。(第一个 all-gather 将无法与任何内容重叠。)
显式 forward
预取指的是改变 CPU 线程的发送顺序:例如,层 0 全收集 -> 层 1 全收集 -> 层 0 forward
计算 -> …。在急切模式下,通常无法知道在执行层 0 时下一层是哪一层(例如示例中的层 1)。因此,显式 forward
预取应仅用于每次迭代执行顺序固定的模型(我们有时称之为“静态图”)。不满足此约束的模型的一个例子是 FLAVA。
显式 forward
预取仅节省发出图层 forward
计算内核所需的时间,代价是必须在当前输出张量仍在使用时分配下一个 all-gather 的输出张量。通过在当前 forward
计算内核之前发出下一个 all-gather,可以在 GPU 上更早地开始下一个 all-gather。对于大多数 LLM 工作负载,情况并非如此,因此没有启用 forward_prefetch=True
的动力。
相比之下,对于backward
,我们必须使用显式的backward
预取,否则通信和计算之间将不会有重叠。原因是我们在所有收集和减少分散操作中使用了一个单一的NCCL进程组(部分原因是在较早的NCCL版本中,在同一设备上对相同等级使用多个进程组是不安全的)。一个单一的NCCL进程组意味着一个单一的内部NCCL流,在这个流上减少分散和所有收集操作是串行运行的。因此,除非我们显式地重新排序CPU的发送顺序为下一个所有收集 -> 当前减少分散,否则当前的减少分散将阻塞下一个所有收集,从而阻塞下一个backward
计算,阻止当前的减少分散与通信重叠。
通信负载大小¶
在FSDP中,通信包括:
在
forward
中对参数进行 all-gather在
backward
中对参数进行 all-gather在
backward
中对梯度进行 reduce-scatter
如果使用激活检查点(checkpoint()
),则不会有额外的通信,因为参数在backward
期间无论如何都会被预取。
在FSDP设计中,每个等级的通信负载确定如下:每次调用
FullyShardedDataParallel
创建一个通信组,该组由
module.parameters()
中的参数组成,除了已经分配给嵌套的 FullyShardedDataParallel
实例的参数。例如,对于Llama,如果你将 FullyShardedDataParallel
应用于每个
transformer块以及根模块,那么每个transformer块都有一个通信组,最后还有一个包含初始嵌入和最终线性的通信组。
每个通信组对应一个单独的all-gather调用和一个单独的reduce-scatter调用。通过这种方式,你如何应用 FullyShardedDataParallel
决定了通信的大小。通常,将FSDP应用于每个transformer块对于LLMs来说是一个很好的启发式方法,考虑到当前的设计,很难做得比这更好。
让我们考虑一个例子,我们有一个基于Transformer的模型在8个GPU上分片,其中分片仅发生在Transformer块级别,每个Transformer块包含1.6B参数,并且参数为fp32(每个4字节)。这意味着一旦分片,每个Transformer块在每个rank上将包含0.2B参数。
在
forward
过程中,将以0.2*4 = 0.8GB
的块在all-gather中进行通信反向传播过程将每次通信2次
0.8GB
(1次全收集和1次减少分散)
换句话说,每次通信的有效载荷为0.8GB
,总共会有3次通信。如果模型由10个transformer块组成,那么总共会有30次通信,总计30*0.8=24GB
。
正式化每个通信每个等级的有效载荷大小为
total_transformer_block_params_in_B*dtype_bytes/num_gpus
(GBs)。
请注意,在这个示例中,我们没有包括嵌入所需的额外通信,这也应该被考虑在内。并且数学计算将取决于输入和输出嵌入是否绑定。如果它们没有绑定,将会有两倍的通信量。
FSDP 缓冲区大小¶
首先,我们来讨论为通信分配的缓冲区:
forward
目前需要2倍的all-gather缓冲区大小。原因如下:
如在FSDP预取细微差别中所解释的,在显式forward
预取(forward_prefetch=True`) 情况 下 第 0 层 全收集 -> 第 0 层 前向 计算 -> 第 1
层 全收集 需要 2 个 全收集大小 的 缓冲区,因为 一个 缓冲区 用于 当前 ``forward
,而另一个用于预取。
虽然隐式的 forward
预取(forward_prefetch=False
,默认)情况下,理论上同一条序列只需要1个缓冲区,但实际上仍然是2倍于all-gather大小的缓冲区。原因是,在扁平参数FSDP设计中,我们不会从all-gather缓冲区中复制出来。用于计算的参数直接从all-gather缓冲区中查看(事实上,“扁平参数”的主要好处正是这个原因)。在这种情况下,当“第1层all-gather”与“第0层前向计算”重叠时,“第0层前向计算”使用的是从“第0层all-gather”缓冲区中查看的参数。
一个自然的问题是,什么时候你会希望forward_prefetch=False
?对于静态图模型(如大多数LLMs),有一个主要的技术原因。实际上,我们为了一些CPU密集型的内部模型快速添加了这个选项,并且在单元测试中没有测试过每一条代码路径,所以我们对此不太有信心。forward_prefetching=False
可能会稍微容易理解一些,因为我们不需要检查记录的前向顺序作为可能的“失败模式”;一个模块的所有收集操作总是可以在其自己的record_function
标签下的分析器跟踪中找到。
backward
目前至少需要2倍的全收集缓冲区大小,并且可能需要更多。原因如下:
当前的FSDP设计使用recordStream
来管理在一个流中产生的分配并在另一个流中消费的分配,这可能导致比预期更多的内存使用。这种“非确定性”的程度取决于GPU内核时间相对于CPU的时间。limit_all_gathers=True
参数是一种缓解措施——更多详情请参阅FSDP & CUDACachingAllocator的讨论。
现有 FSDP 与 autograd 的工作方式:
现有的FSDP对
flat_param
进行全收集,这是自动求导的叶子节点。它调用
torch.split
来获取对应于其组成原始参数的flat_param
的1D视图。它对每个1D分割调用
torch.view
以重新查看为ND。这意味着在
backward
中,我们最终得到ViewBackward
(ND -> 1D)和SplitWithSizesBackward
(这是一个 concat)。特别是,每个单独的梯度都是作为单独的分配计算的,并且显式地进行 concat 以构建 reduce-scatter 输入缓冲区。这意味着在峰值内存点处,reduce-scatter 的缓冲区大小实际上是 2 倍。
总之,对于 backward
,它大约是 reduce-scatter 的 2 倍缓冲区大小加上任何 recordStream
效果。
其次,让我们讨论一下额外的缓冲区:
一旦从所有等级收集了分片参数,它们需要一个额外的缓冲区,大小为 total_transformer_block_params_in_B*dtype_bytes,用于完整参数 - 因此继续前面的例子,如果每个transformer块是1.6B参数,并且参数是fp32格式,那么它将是 1.6*4=6.4GB 缓冲区。
并且需要两个这样的缓冲区,因为当前有一个正在使用,另一个正在预取。
总结一下,我们有:
2倍的通信缓冲区大小为
total_transformer_block_params_in_B*dtype_bytes/num_gpus
2倍未分片的transformer块参数缓冲区
``total_transformer_block_params_in_B*dtype_bytes
或者如果你一直在跟随这个例子:
2*1.6*4/8=1.6GB
2**1.6*4=12.8GB
以及总共的 14.4GB
。
现在让我们简要讨论一下当我们从计算中省略这些嵌入时会发生什么:
根据我们在讨论中提到的规则,您在笔记中以“通信缓冲区大小确定如下”开头的部分,我们可以分析如下:
假设我们将FSDP应用于根模块(例如,
Transformer
类)。进一步假设我们将FSDP应用于每个transformer块(例如,TransformerBlock
类)。最常见的是,嵌入和最终的线性投影是根
Transformer
类的直接子类。根据我们的规则,这意味着嵌入和最终的线性投影被分配给根
Transformer
的平面参数。我们有一个_另一个_特殊规则,即根节点在正向传播后不会释放其参数,因为它们无论如何都会在反向传播中立即全部收集。
将这些结合起来,这意味着根节点的flat参数,包括嵌入和最终投影,都会被全部收集以开始前向传播,并保持在GPU内存中直到反向传播结束。
如果嵌入和最终线性层没有共享权重,那么我们可以进一步将FSDP应用于嵌入层和最终线性层。对于共享权重的参数,我们要求它们成为同一个扁平参数的一部分(否则会被重复计算)。这将允许嵌入层在使用后被释放,并且只在反向传播的末尾进行全收集。
希望这能更好地说明问题——每个FSDP模块在其
module.parameters
中被分配参数,除了那些已经分配给另一个嵌套FSDP模块的参数,并且FSDP模块的forward
定义了其参数的“活动”间隔。因此,嵌套的nn.Module
结构可以影响all-gather/free调度,从而影响内存/吞吐量性能。