Shortcuts

torch.distributed.fsdp.api 的源代码

"""
此文件包含FSDP的公共API,例如用于构造函数参数的类。
"""

from dataclasses import dataclass
from enum import auto, Enum

from typing import Optional, Sequence, Type

import torch
from torch.nn.modules.batchnorm import _BatchNorm

__all__ = [
    "ShardingStrategy",
    "BackwardPrefetch",
    "MixedPrecision",
    "CPUOffload",
    "StateDictType",
    "StateDictConfig",
    "FullStateDictConfig",
    "LocalStateDictConfig",
    "ShardedStateDictConfig",
    "OptimStateDictConfig",
    "FullOptimStateDictConfig",
    "LocalOptimStateDictConfig",
    "ShardedOptimStateDictConfig",
    "StateDictSettings",
]


[docs]class ShardingStrategy(Enum): """ 这指定了用于分布式训练的分片策略 :class:`FullyShardedDataParallel`。 - ``FULL_SHARD``: 参数、梯度和优化器状态被分片。 对于参数,此策略在正向之前解分片(通过全收集),在正向之后重新分片,在反向计算之前解分片,在反向计算之后重新分片。对于梯度,它在反向计算之后同步并分片(通过减少散射)。分片的优化器状态在每个rank上本地更新。 - ``SHARD_GRAD_OP``: 梯度和优化器状态在计算期间被分片,此外,参数在计算外部被分片。对于参数,此策略在正向之前解分片,在正向之后不重新分片,仅在反向计算之后重新分片。分片的优化器状态在每个rank上本地更新。在 ``no_sync()`` 内部,参数在反向计算之后不会重新分片。 - ``NO_SHARD``: 参数、梯度和优化器状态不被分片,而是跨rank复制,类似于PyTorch的 :class:`DistributedDataParallel` API。对于梯度,此策略在反向计算之后同步它们(通过全减少)。未分片的优化器状态在每个rank上本地更新。 - ``HYBRID_SHARD``: 在节点内应用 ``FULL_SHARD``,并在节点间复制参数。这减少了通信量,因为昂贵的全收集和减少散射仅在节点内进行,对于中等大小的模型可能更高效。 - ``_HYBRID_SHARD_ZERO2``: 在节点内应用 ``SHARD_GRAD_OP``,并在节点间复制参数。这类似于 ``HYBRID_SHARD``,但可能会提供更高的吞吐量,因为未分片的参数在正向传递后不会被释放,从而节省了反向前的前向全收集。 """ FULL_SHARD = auto() SHARD_GRAD_OP = auto() NO_SHARD = auto() HYBRID_SHARD = auto() _HYBRID_SHARD_ZERO2 = auto()
[docs]class BackwardPrefetch(Enum): """ 这配置了显式的反向预取,通过在反向传递中启用通信和计算的重叠来提高吞吐量,代价是稍微增加了内存使用。 - ``BACKWARD_PRE``: 这启用了最多的重叠,但增加了最多的内存使用。这在前一组参数的梯度计算之前预取下一组参数。这重叠了*下一个全收集*和*当前梯度计算*,并且在峰值时,它持有当前一组参数、下一组参数和当前一组梯度在内存中。 - ``BACKWARD_POST``: 这启用了较少的重叠,但需要较少的内存使用。这在前一组参数的梯度计算之后预取下一组参数。这重叠了*当前减少散射*和*下一个梯度计算*,并且在分配下一组参数的内存之前释放当前一组参数,仅在峰值时持有下一组参数和当前一组梯度在内存中。 - FSDP的 ``backward_prefetch`` 参数接受 ``None``,这完全禁用了反向预取。这没有重叠,并且不会增加内存使用。通常,我们不推荐此设置,因为它可能会显著降低吞吐量。 更多技术背景:对于使用NCCL后端的单个进程组,任何集合,即使是从不同的流发出的,也会争夺相同的每个设备的NCCL流,这意味着集合发出的相对顺序对于重叠很重要。两个反向预取值对应于不同的发出顺序。 """ # 注意:对于两种模式,定义“当前”和“下一个”的顺序在当前实现中并不总是精确的。目标错误的预取只是意味着参数内存比需要的时间更早分配,可能会增加峰值内存使用,但不会影响正确性。 BACKWARD_PRE = auto() BACKWARD_POST = auto()
[docs]@dataclass class MixedPrecision: """ 这配置了FSDP原生的混合精度训练。 属性: param_dtype (Optional[torch.dtype]): 这指定了模型参数在前向和反向期间的dtype,因此也是前向和反向计算的dtype。在前向和反向之外,*分片的*参数保持在全精度(例如,用于优化器步骤),并且在模型检查点中,参数总是以全精度保存。(默认值: ``None``) reduce_dtype (Optional[torch.dtype]): 这指定了梯度减少(即减少散射或全减少)的dtype。如果这是 ``None`` 但 ``param_dtype`` 不是 ``None``,则这采用 ``param_dtype`` 值,仍然以低精度运行梯度减少。这允许与 ``param_dtype`` 不同,例如强制梯度减少以全精度运行。(默认值: ``None``) buffer_dtype (Optional[torch.dtype]): 这指定了缓冲区的dtype。FSDP不分片缓冲区。相反,FSDP在第一个前向传递中将它们转换为 ``buffer_dtype``,并在之后保持该dtype。对于模型检查点,缓冲区以全精度保存,除了 ``LOCAL_STATE_DICT``。(默认值: ``None``) keep_low_precision_grads (bool): 如果 ``False``,则FSDP在反向传递后将梯度提升到全精度,为优化器步骤做准备。如果 ``True``,则FSDP保持梯度在用于梯度减少的dtype中,如果使用支持低精度运行的自定义优化器,这可以节省内存。(默认值: ``False``) cast_forward_inputs (bool): 如果 ``True``,则此FSDP模块将其前向args和kwargs转换为 ``param_dtype``。这是为了确保参数和输入dtype在前向计算中匹配,因为许多操作需要这一点。当仅对部分但不是所有FSDP模块应用混合精度时,这可能需要设置为 ``True``,在这种情况下,混合精度的FSDP子模块需要重新转换其输入。(默认值: ``False``) cast_root_forward_inputs (bool): 如果 ``True``,则根FSDP模块将其前向args和kwargs转换为 ``param_dtype``,覆盖 ``cast_forward_inputs`` 的值。对于非根FSDP模块,这不做任何事情。(默认值: ``True``) _module_classes_to_ignore: (Sequence[Type[nn.Module]]): 这指定了在使用 ``auto_wrap_policy`` 时要忽略的模块类:这些类的模块将单独应用FSDP,且禁用混合精度(意味着最终的FSDP构造将偏离指定的策略)。如果未指定 ``auto_wrap_policy``,则这不做任何事情。此API是实验性的,可能会更改。(默认值: ``(_BatchNorm,)``) .. 注意:: 此API是实验性的,可能会更改。 .. 注意:: 只有浮点张量被转换为其指定的dtype。 .. 注意:: 在 ``summon_full_params`` 中,参数被强制为全精度,但缓冲区不是。 .. 注意:: 层归一化和批归一化即使在输入为低精度(如 ``float16`` 或 ``bfloat16``)时也以 ``float32`` 累积。禁用这些归一化模块的FSDP混合精度仅意味着仿射参数保持在 ``float32`` 中。然而,这为这些归一化模块带来了单独的全收集和减少散射,这可能效率低下,因此如果工作负载允许,用户应首选仍然对这些模块应用混合精度。 .. 注意:: 默认情况下,如果用户传递了一个包含任何 ``_BatchNorm`` 模块的模型并指定了一个 ``auto_wrap_policy``,则批归一化模块将单独应用FSDP,且禁用混合精度。参见 ``_module_classes_to_ignore`` 参数。 .. 注意:: ``MixedPrecision`` 默认 ``cast_root_forward_inputs=True`` 和 ``cast_forward_inputs=False``。对于根FSDP实例,其 ``cast_root_forward_inputs`` 优先于其 ``cast_forward_inputs``。对于非根FSDP实例,它们的 ``cast_root_forward_inputs`` 值被忽略。默认设置足以满足典型情况,即每个FSDP实例具有相同的 ``MixedPrecision`` 配置,并且只需要在模型的前向传递开始时将输入转换为 ``param_dtype``。 .. 注意:: 对于具有不同 ``MixedPrecision`` 配置的嵌套FSDP实例,我们建议设置单独的 ``cast_forward_inputs`` 值以配置在每个实例的前向之前是否转换输入。在这种情况下,由于转换发生在每个FSDP实例的前向之前,父FSDP实例应在其FSDP子模块之前运行其非FSDP子模块,以避免由于不同的 ``MixedPrecision`` 配置而改变激活dtype。 示例:: >>> # xdoctest: +SKIP("undefined variables") >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) >>> model[1] = FSDP( >>> model[1], >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), >>> ) >>> model = FSDP( >>> model, >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), >>> ) 上述示例显示了一个有效的示例。另一方面,如果 ``model[1]`` 被替换为 ``model[0]``,意味着使用不同 ``MixedPrecision`` 的子模块首先运行其前向,则 ``model[1]`` 将错误地看到 ``float16`` 激活而不是 ``bfloat16`` 激活。 """ param_dtype: Optional[torch.dtype] = None reduce_dtype: Optional<