FullyShardedDataParallel¶
- class torch.distributed.fsdp.FullyShardedDataParallel(module, process_group=None, sharding_strategy=None, cpu_offload=None, auto_wrap_policy=None, backward_prefetch=BackwardPrefetch.BACKWARD_PRE, mixed_precision=None, ignored_modules=None, param_init_fn=None, device_id=None, sync_module_states=False, forward_prefetch=False, limit_all_gathers=True, use_orig_params=False, ignored_states=None, device_mesh=None)[源代码]¶
一个用于在数据并行工作器之间分片模块参数的包装器。
这是受到Xu 等人以及DeepSpeed的 ZeRO Stage 3 的启发。 FullyShardedDataParallel 通常简称为 FSDP。
有关高级笔记,请参阅FSDP 笔记。
示例:
>>> import torch >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> torch.cuda.set_device(device_id) >>> sharded_module = FSDP(my_module) >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) >>> loss = x.sum() >>> loss.backward() >>> optim.step()
警告
优化器必须在模块被FSDP包装之后进行初始化,因为FSDP会以一种可能不保留原始参数变量的方式对模块的参数进行分片和转换。因此,之前初始化的优化器可能会有对参数的陈旧引用。
警告
如果目标CUDA设备的ID为
dev_id,则应满足以下任一条件:(1)module应已放置在该设备上,(2)应使用torch.cuda.set_device(dev_id)设置设备,或(3) 应将dev_id传递给device_id构造函数参数。此FSDP实例的计算设备将是该目标设备。对于(1)和(3),FSDP初始化始终在GPU上进行。对于(2),FSDP初始化发生在module的当前设备上,该设备可能是CPU。警告
FSDP 目前在使用 CPU 卸载时不支持在
no_sync()外部进行梯度累积。尝试这样做会导致结果不正确,因为 FSDP 将使用新减少的梯度,而不是与任何现有梯度进行累积。警告
在构造后更改原始参数变量名称将导致未定义行为。
警告
传入
sync_module_states=True标志要求module在GPU上或使用device_id参数来指定一个CUDA设备,FSDP将在FSDP构造函数中将module移动到该设备。这是因为sync_module_states=True需要GPU通信。警告
截至 PyTorch 1.12,FSDP 仅对共享参数提供有限支持(例如,将一个
Linear层的权重设置为另一个层的权重)。特别是,共享参数的模块必须作为同一 FSDP 单元的一部分进行包装。如果您的用例需要增强的共享参数支持,请在 https://github.com/pytorch/pytorch/issues/77724 上联系我们。警告
FSDP 在冻结参数(即设置
param.requires_grad=False)方面有一些限制。对于use_orig_params=False,每个 FSDP 实例必须管理全部冻结或全部非冻结的参数。对于use_orig_params=True,FSDP 支持混合冻结 和非冻结,但我们建议不要这样做,因为这样梯度内存使用量将高于预期(即,相当于没有冻结这些参数)。这意味着理想情况下,冻结的参数 应该被隔离到它们自己的nn.Module中,并分别用 FSDP 包装。注意
尝试运行包含在FSDP实例中的子模块的前向传递是不支持的,并且会导致错误。这是因为子模块的参数将被分片,但它本身不是一个FSDP实例,因此它的前向传递不会适当地收集完整的参数。这可能会在尝试仅运行编码器-解码器模型的编码器时发生,并且编码器未在其自己的FSDP实例中包装。要解决此问题,请将子模块包装在其自己的FSDP单元中。
注意
FSDP 将输入张量移动到
forward方法到 GPU 计算设备,因此用户不需要手动将它们从 CPU 移动。警告
用户不应在未使用
summon_full_params()上下文的情况下,修改前向和后向之间的参数,因为这些修改可能不会持久化。此外,对于use_orig_params=False,在前向和后向之间访问原始参数可能会引发非法内存访问。警告
对于
use_orig_params=True,ShardingStrategy.SHARD_GRAD_OP在前向传播后暴露的是未分片的参数,而不是分片的参数,因为它不会释放未分片的参数,这与ShardingStrategy.FULL_SHARD不同。需要注意的是,由于梯度总是分片或None,ShardingStrategy.SHARD_GRAD_OP在前向传播后不会暴露带有未分片参数的分片梯度。如果你想检查梯度,可以尝试summon_full_params()并设置with_grads=True。警告
FSDP 在正向和反向计算期间,由于自动求导相关的原因,将托管模块的参数替换为
torch.Tensor视图。 如果你的模块的正向依赖于保存的参数引用,而不是每次迭代时重新获取引用,那么它将不会看到 FSDP 新创建的视图,自动求导将无法正常工作。注意
使用
limit_all_gathers=True,您可能会在 FSDP 前向传播中看到一个间隙,其中 CPU 线程没有发出任何内核。这是有意为之,并显示了速率限制器的作用。以这种方式同步 CPU 线程可以防止为后续的 all-gathers 过度分配内存,并且它不应实际延迟 GPU 内核执行。注意
当使用
sharding_strategy=ShardingStrategy.HYBRID_SHARD进行分片处理组为节点内且复制处理组为节点间时,设置NCCL_CROSS_NIC=1可以帮助在一些集群设置中改善复制处理组上的 all-reduce 时间。警告
FSDP 由于其注册反向钩子的方式,不支持双反向传播。
- Parameters
模块 (nn.Module) – 这是要使用FSDP包装的模块。
process_group (可选[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]) – 这是模型分片所在的进程组,因此也是FSDP用于全收集和减少分散集体通信的进程组。如果
None,则FSDP使用默认进程组。对于混合分片策略,例如ShardingStrategy.HYBRID_SHARD,用户可以传入一个进程组元组,分别表示要分片和复制的组。如果None,则FSDP为用户构建进程组,以在节点内分片并在节点间复制。(默认值:None)sharding_strategy (可选[ShardingStrategy]) – 这配置了分片策略,可能会在内存节省和通信开销之间进行权衡。详情请参见
ShardingStrategy。(默认值:FULL_SHARD)cpu_offload (可选[CPUOffload]) – 这配置了CPU卸载。如果设置为
None,则不会发生CPU卸载。详情请参见CPUOffload。 (默认值:None)auto_wrap_policy (可选[联合[可调用[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]) –
这指定了将FSDP应用于
module子模块的策略, 这是为了实现通信和计算的重叠,从而影响性能。如果None,则FSDP仅应用于module,用户应手动将FSDP应用于父模块 (自底向上进行)。为了方便,这直接接受ModuleWrapPolicy,允许用户指定要包装的 模块类(例如,transformer块)。否则, 这应该是一个可调用对象,它接受三个参数module: nn.Module,recurse: bool,和nonwrapped_numel: int,并应返回一个bool,指定 如果recurse=False,则传入的module是否应应用FSDP, 或者如果recurse=True,则遍历是否应继续进入 模块的子树。用户可以向可调用对象添加额外的 参数。torch.distributed.fsdp.wrap.py中的size_based_auto_wrap_policy给出了一个示例可调用对象, 如果子树中的参数超过 100M numel,则将FSDP应用于模块。我们建议在应用FSDP后打印模型 并根据需要进行调整。示例:
>>> def custom_auto_wrap_policy( >>> module: nn.Module, >>> recurse: bool, >>> nonwrapped_numel: int, >>> # 额外的自定义参数 >>> min_num_params: int = int(1e8), >>> ) -> bool: >>> return nonwrapped_numel >= min_num_params >>> # 配置自定义的`min_num_params` >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
backward_prefetch (可选[BackwardPrefetch]) – 这配置了所有-gathers的显式反向预取。如果
None,则FSDP不进行反向预取,并且在反向传递中没有通信和计算的重叠。有关详细信息,请参阅BackwardPrefetch。(默认值:BACKWARD_PRE)mixed_precision (可选[MixedPrecision]) – 此配置为FSDP设置原生混合精度。如果设置为
None,则不使用混合精度。否则,可以设置参数、缓冲区和梯度缩减的数据类型。详情请参见MixedPrecision。(默认值:None)ignored_modules (可选[可迭代[torch.nn.Module]]) – 该实例忽略其自身参数和子模块参数及缓冲区的模块。
ignored_modules中的任何模块都不应是FullyShardedDataParallel实例,并且如果子模块是已构造的FullyShardedDataParallel实例,则嵌套在此实例下时不会被忽略。此参数可用于在使用auto_wrap_policy或在 FSDP 不管理参数分片时避免在模块粒度上对特定参数进行分片。(默认值:None)param_init_fn (可选[可调用[[nn.Module], 无]]) –
一个
可调用[torch.nn.Module] -> 无,指定如何将当前位于元设备上的模块初始化到实际设备上。从 v1.12 开始,FSDP 通过is_meta检测具有参数或缓冲区的模块,并应用指定的param_init_fn或调用nn.Module.reset_parameters()。对于这两种情况,实现应仅初始化模块的参数/缓冲区,而不是其子模块的参数/缓冲区。这是为了避免重新初始化。此外,FSDP 还通过 torchdistX 的 (https://github.com/pytorch/torchdistX)deferred_init()API 支持延迟初始化,其中延迟模块通过调用指定的param_init_fn或 torchdistX 的默认materialize_module()进行初始化。如果指定了param_init_fn,则它将应用于所有元设备模块,这意味着它可能需要根据模块类型进行处理。FSDP 在参数展平和分片之前调用初始化函数。示例:
>>> module = MyModule(device="meta") >>> def my_init_fn(module: nn.Module): >>> # 例如,根据模块类型进行初始化 >>> ... >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) >>> print(next(fsdp_model.parameters()).device) # 当前 CUDA 设备 >>> # 使用 torchdistX >>> module = deferred_init.deferred_init(MyModule, device="cuda") >>> # 将通过 deferred_init.materialize_module() 进行初始化。 >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy)
device_id (可选[联合[int, torch.device]]) – 一个
int或torch.device指定在哪个CUDA设备上进行FSDP初始化,包括模块初始化 如果需要和参数分片。如果module在CPU上,则应指定此项以提高初始化速度。如果设置了默认的CUDA设备(例如通过torch.cuda.set_device), 则用户可以传递torch.cuda.current_device到此。 (默认值:None)sync_module_states (bool) – 如果
True,则每个FSDP模块将 从rank 0广播模块参数和缓冲区,以确保它们在各个rank之间复制(为 此构造函数增加通信开销)。这可以帮助通过load_state_dict以 内存高效的方式加载state_dict检查点。请参阅FullStateDictConfig以获取此示例。(默认值:False)forward_prefetch (bool) – 如果
True,则FSDP显式地在当前前向计算之前预取下一个前向传递的所有收集操作。这仅对CPU密集型工作负载有用,在这种情况下,提前发出下一个所有收集操作可能会提高重叠度。这应该仅用于静态图模型,因为预取操作遵循第一次迭代的执行顺序。(默认值:False)limit_all_gathers (bool) – 如果
True,则FSDP显式同步CPU线程以确保仅从两个连续的FSDP实例(当前正在运行的实例和预取了all-gather的下一个实例)使用GPU内存。如果False,则FSDP允许CPU线程在没有额外同步的情况下发出all-gather。(默认值:True)我们通常将此功能称为“速率限制器”。此标志仅应在特定CPU绑定工作负载且内存压力较低的情况下设置为False,在这种情况下,CPU线程可以积极发出所有内核而无需担心GPU内存使用情况。use_orig_params (bool) – 将其设置为
True时,FSDP 使用module的原始参数。FSDP 通过nn.Module.named_parameters()向用户公开这些原始参数,而不是 FSDP 的内部FlatParameter。这意味着优化器步骤在原始参数上运行,从而启用每个原始参数的超参数。FSDP 保留原始参数变量并在非分片和分片形式之间操作它们的数据,它们始终是底层非分片或分片FlatParameter的视图。使用当前算法,分片形式始终为 1D,丢失原始张量结构。原始参数可能在其数据中全部、部分或没有数据存在于给定秩中。在无数据的情况下,其数据将类似于大小为 0 的空张量。用户不应编写依赖于给定原始参数在其分片形式中存在数据的程序。True是使用torch.compile()所必需的。将其设置为False时,FSDP 通过nn.Module.named_parameters()向用户公开其内部FlatParameter。(默认值:False)ignored_states (可选[可迭代[torch.nn.Parameter]], 可选[可迭代[torch.nn.Module]]) – 被忽略的参数或模块,这些参数或模块将不会由这个FSDP实例管理,意味着这些参数不会被分片,它们的梯度也不会在各个rank之间进行减少。此参数与现有的
ignored_modules参数统一,我们可能会在不久后弃用ignored_modules。为了向后兼容,我们保留了ignored_states和ignored_modules`,但FSDP只允许其中一个被指定为非None。
- apply(fn)[源代码]¶
递归地将
fn应用于每个子模块(由.children()返回)以及自身。典型用途包括初始化模型的参数(另请参阅 torch.nn.init)。
与
torch.nn.Module.apply相比,此版本在应用fn之前还会收集完整的参数。它不应在另一个summon_full_params上下文中调用。- Parameters
fn (
Module-> None) – 应用于每个子模块的函数- Returns
自身
- Return type
- 模块
- clip_grad_norm_(max_norm, norm_type=2.0)[源代码]¶
裁剪所有参数的梯度范数。
范数是基于所有参数的梯度计算的,这些梯度被视为一个单一向量,并且梯度是就地修改的。
- Parameters
- Returns
参数的总范数(视为单个向量)。
- Return type
注意
如果每个FSDP实例都使用
NO_SHARD,这意味着没有梯度在各个rank之间分片,那么你可以直接使用torch.nn.utils.clip_grad_norm_()。注意
如果至少有一些FSDP实例使用分片策略(即除了
NO_SHARD之外的策略),那么你应该使用这个方法而不是torch.nn.utils.clip_grad_norm_(),因为这个方法处理了梯度在各个rank之间分片的事实。注意
返回的总范数将具有跨所有参数/梯度的“最大”dtype,如PyTorch的类型提升语义所定义。例如,如果所有参数/梯度使用低精度dtype,则返回的范数的dtype将是该低精度dtype,但如果至少存在一个使用FP32的参数/梯度,则返回的范数的dtype将是FP32。
警告
这需要在所有rank上调用,因为它使用了集体通信。
- static flatten_sharded_optim_state_dict(sharded_optim_state_dict, model, optim)[源代码]¶
展平一个分片优化器状态字典。
该API与
shard_full_optim_state_dict()类似。唯一的区别是输入的sharded_optim_state_dict应该来自sharded_optim_state_dict()。因此,每个rank上都会有all-gather调用来收集ShardedTensor。- Parameters
sharded_optim_state_dict (Dict[str, Any]) – 对应于未展平参数的分片优化器状态字典,并持有分片优化器状态。
模型 (torch.nn.Module) – 参考
shard_full_optim_state_dict().optim (torch.optim.Optimizer) – 用于
模型参数的优化器。
- Returns
- Return type
- static fsdp_modules(module, root_only=False)[源代码]¶
返回所有嵌套的FSDP实例。
这可能包括
module本身,并且仅在root_only=True时包括FSDP根模块。- Parameters
模块 (torch.nn.Module) – 根模块,可能是也可能不是
FSDP模块。root_only (bool) – 是否只返回FSDP根模块。 (默认值:
False)
- Returns
嵌套在输入
module中的 FSDP 模块。- Return type
- static full_optim_state_dict(model, optim, optim_input=None, rank0_only=True, group=None)[源代码]¶
返回完整的优化器状态字典。
在 rank 0 上整合完整的优化器状态并将其返回为
dict,遵循torch.optim.Optimizer.state_dict()的约定,即包含键"state"和"param_groups"。model中包含的FSDP模块的展平参数被映射回其未展平的参数。警告
这需要在所有rank上调用,因为它使用了集体通信。然而,如果
rank0_only=True,那么状态字典仅在rank 0上填充,所有其他rank返回一个空的dict。警告
与
torch.optim.Optimizer.state_dict()不同,此方法使用完整的参数名称作为键,而不是参数ID。注意
类似于在
torch.optim.Optimizer.state_dict()中,优化器状态字典中包含的张量没有被克隆,因此可能会有别名惊喜。为了最佳实践,建议立即保存返回的优化器状态字典,例如使用torch.save()。- Parameters
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例),其参数 被传递到优化器optim中。optim (torch.optim.Optimizer) – 用于
模型参数的优化器。optim_input (可选[联合[列表[字典[str, 任意]], 可迭代[torch.nn.Parameter]]]) – 传递给优化器
optim的输入,表示参数组列表或参数的可迭代对象; 如果None,则此方法假设输入为model.parameters()。此参数已弃用,不再需要传递它。(默认值:None)rank0_only (bool) – 如果为
True,仅在rank 0上保存填充的dict;如果为False,则在所有rank上保存。(默认值:True)组 (dist.ProcessGroup) – 模型的进程组,如果使用默认进程组则为
None。(默认值:None)
- Returns
包含优化器状态的字典,用于
model的原始未展平参数,并包含键“state”和“param_groups”,遵循torch.optim.Optimizer.state_dict()的约定。如果rank0_only=True,则非零等级返回一个空字典。- Return type
Dict[str, Any]
- static get_state_dict_type(module)[源代码]¶
获取以
module为根的 FSDP 模块的状态字典类型和相应的配置。目标模块不一定是FSDP模块。
- Returns
包含当前设置的 state_dict_type 和 state_dict / optim_state_dict 配置的
StateDictSettings。- Raises
如果 StateDictSettings 不同,则抛出 AssertionError –
FSDP 子模块不同。 –
- Return type
- named_buffers(*args, **kwargs)[源代码]¶
返回一个遍历模块缓冲区的迭代器,同时生成缓冲区的名称和缓冲区本身。
拦截缓冲区名称并移除在
summon_full_params()上下文管理器内部的所有 FSDP 特定扁平缓冲区前缀的出现。
- named_parameters(*args, **kwargs)[源代码]¶
返回一个遍历模块参数的迭代器,同时生成参数的名称和参数本身。
拦截参数名称并移除在
summon_full_params()上下文管理器内部的所有FSDP特定扁平化参数前缀的出现。
- no_sync()[源代码]¶
禁用跨FSDP实例的梯度同步。
在此上下文中,梯度将被累积在模块变量中,这些变量将在退出上下文后的第一次前向-后向传递中同步。这应该仅在根FSDP实例上使用,并将递归应用于所有子FSDP实例。
注意
这可能会导致更高的内存使用,因为FSDP会累积完整的模型梯度(而不是梯度分片),直到最终同步。
注意
当与CPU卸载一起使用时,在上下文管理器内部,梯度不会被卸载到CPU。相反,它们只会在最终同步后立即被卸载。
- Return type
- static optim_state_dict(model, optim, optim_state_dict=None, group=None)[源代码]¶
转换对应于分片模型的优化器的状态字典。
给定的状态字典可以转换为以下三种类型之一: 1) 完整的优化器状态字典,2) 分片优化器状态字典,3) 本地优化器状态字典。
对于完整的优化器状态字典,所有状态都是未展平且未分片的。 可以通过
state_dict_type()指定仅Rank0和仅CPU,以避免OOM。对于分片优化器状态字典,所有状态都是未展平但分片的。 可以通过
state_dict_type()指定仅CPU以进一步节省内存。对于本地 state_dict,不会进行转换。但状态将从 nn.Tensor 转换为 ShardedTensor 以表示其分片性质(这尚未支持)。
示例:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # 保存检查点 >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # 加载检查点 >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict)
- Parameters
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例),其参数 被传递到优化器optim中。optim (torch.optim.Optimizer) – 用于
模型参数的优化器。optim_state_dict (Dict[str, Any]) – 要转换的目标优化器状态字典。如果值为 None,将使用 optim.state_dict()。( 默认值:
None)组 (dist.ProcessGroup) – 模型在跨参数分片或使用默认进程组时的进程组。( 默认值:
None)
- Returns
包含优化器状态的
dictmodel。优化器状态的分片基于state_dict_type。- Return type
Dict[str, Any]
- static optim_state_dict_to_load(model, optim, optim_state_dict, is_named_optimizer=False, load_directly=False, group=None)[源代码]¶
转换优化器状态字典,使其可以加载到与FSDP模型关联的优化器中。
给定一个通过
optim_state_dict()转换的optim_state_dict,它会被转换为可以加载到optim的扁平化优化器状态字典,其中optim是model的优化器。model必须由FullyShardedDataParallel进行分片。>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> from torch.distributed.fsdp import StateDictType >>> from torch.distributed.fsdp import FullStateDictConfig >>> from torch.distributed.fsdp import FullOptimStateDictConfig >>> # 保存检查点 >>> model, optim = ... >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> state_dict = model.state_dict() >>> original_osd = optim.state_dict() >>> optim_state_dict = FSDP.optim_state_dict( >>> model, >>> optim, >>> optim_state_dict=original_osd >>> ) >>> save_a_checkpoint(state_dict, optim_state_dict) >>> # 加载检查点 >>> model, optim = ... >>> state_dict, optim_state_dict = load_a_checkpoint() >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.FULL_STATE_DICT, >>> FullStateDictConfig(rank0_only=False), >>> FullOptimStateDictConfig(rank0_only=False), >>> ) >>> model.load_state_dict(state_dict) >>> optim_state_dict = FSDP.optim_state_dict_to_load( >>> model, optim, optim_state_dict >>> ) >>> optim.load_state_dict(optim_state_dict)
- Parameters
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例),其参数 被传递到优化器optim中。optim (torch.optim.Optimizer) – 用于
模型参数的优化器。optim_state_dict (Dict[str, Any]) – 要加载的优化器状态。
is_named_optimizer (bool) – 此优化器是否为NamedOptimizer或KeyedOptimizer。仅当
optim是TorchRec的KeyedOptimizer或torch.distributed的NamedOptimizer时,才设置为True。load_directly (bool) – 如果设置为True,此API将在返回结果之前调用optim.load_state_dict(result)。否则,用户需要负责调用
optim.load_state_dict()(默认值:False)组 (dist.ProcessGroup) – 模型在跨参数分片或使用默认进程组时的进程组。( 默认值:
None)
- Return type
- register_comm_hook(state, hook)[源代码]¶
注册一个通信钩子。
这是一个增强功能,为用户提供了一个灵活的钩子,用户可以指定FSDP如何在多个工作节点之间聚合梯度。 这个钩子可以用于实现几种算法,如 GossipGrad和梯度压缩, 这些算法在训练时涉及不同的通信策略来进行参数同步,使用
FullyShardedDataParallel。警告
FSDP通信钩子应在运行初始前向传递之前注册,并且只能注册一次。
- Parameters
状态 (对象) –
传递给钩子以在训练过程中维护任何状态信息。 示例包括梯度压缩中的错误反馈, GossipGrad中下一个要通信的对等方等。 它由每个工作节点本地存储, 并由工作节点上的所有梯度张量共享。
钩子 (可调用对象) – 可调用对象,具有以下签名之一: 1)
hook: Callable[torch.Tensor] -> None: 此函数接收一个Python张量,该张量表示 与该FSDP单元所包装的模型对应的所有变量的完整、展平、未分片的梯度 (未被其他FSDP子单元包装的变量)。 然后执行所有必要的处理并返回None; 2)hook: Callable[torch.Tensor, torch.Tensor] -> None: 此函数接收两个Python张量,第一个张量表示 与该FSDP单元所包装的模型对应的所有变量的完整、展平、未分片的梯度 (未被其他FSDP子单元包装的变量)。后者 表示一个预先分配的张量,用于存储归约后分片梯度的一部分。 在这两种情况下,可调用对象执行所有必要的处理并返回None。 具有签名1的可调用对象预计处理NO_SHARD情况下的梯度通信。 具有签名2的可调用对象预计处理分片情况下的梯度通信。
- static rekey_optim_state_dict(optim_state_dict, optim_state_key_type, model, optim_input=None, optim=None)[源代码]¶
重新键优化器状态字典
optim_state_dict以使用键类型optim_state_key_type。这可以用于实现具有FSDP实例的模型与没有FSDP实例的模型的优化器状态字典之间的兼容性。
要重新生成一个FSDP完整优化器状态字典(即从
full_optim_state_dict())以使用参数ID并可加载到 一个非包装模型中:>>> wrapped_model, wrapped_optim = ... >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) >>> nonwrapped_model, nonwrapped_optim = ... >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) >>> nonwrapped_optim.load_state_dict(rekeyed_osd)
要将非包装模型中的普通优化器状态字典重新键入,以便可以加载到包装模型中:
>>> nonwrapped_model, nonwrapped_optim = ... >>> osd = nonwrapped_optim.state_dict() >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) >>> wrapped_model, wrapped_optim = ... >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) >>> wrapped_optim.load_state_dict(sharded_osd)
- Returns
使用由
optim_state_key_type指定的参数键重新键入的优化器状态字典。- Return type
Dict[str, Any]
- static scatter_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None, group=None)[源代码]¶
将完整的优化器状态字典从排名0散布到所有其他排名。
返回每个rank上的分片优化器状态字典。 返回值与
shard_full_optim_state_dict()相同,并且在rank 0上,第一个参数应该是full_optim_state_dict()的返回值。示例:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) # 仅在rank 0上非空 >>> # 定义新模型,可能具有不同的world size >>> new_model, new_optim, new_group = ... >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) >>> new_optim.load_state_dict(sharded_osd)
注意
可以使用
shard_full_optim_state_dict()和scatter_full_optim_state_dict()来获取分片优化器状态字典。假设完整的优化器状态字典位于CPU内存中,前者要求每个rank在CPU内存中拥有完整的字典,每个rank独立地对字典进行分片而不进行任何通信,而后者仅要求rank 0在CPU内存中拥有完整的字典,rank 0将每个分片移动到GPU内存(用于NCCL)并适当地进行通信。因此,前者具有更高的聚合CPU内存成本,而后者具有更高的通信成本。- Parameters
full_optim_state_dict (可选[字典[str, 任意]]) – 对应于未展平参数的优化器状态字典,并在rank 0上保存完整的非分片优化器状态;该参数在非零rank上被忽略。
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例),其参数对应于full_optim_state_dict中的优化器状态。optim_input (可选[联合[列表[字典[str, 任意]], 可迭代[torch.nn.Parameter]]]) – 传递给优化器的输入,表示参数组列表或参数的可迭代对象; 如果为
None,则此方法假设输入为model.parameters()。此参数已弃用,不再需要传递它。(默认值:None)optim (可选[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比
optim_input更推荐的参数。(默认值:None)组 (dist.ProcessGroup) – 模型的进程组,如果使用默认进程组则为
None。(默认值:None)
- Returns
完整的优化器状态字典现在重新映射为 扁平化的参数,而不是未扁平化的参数,并且 仅限于包含此rank的优化器状态部分。
- Return type
Dict[str, Any]
- static set_state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[源代码]¶
设置目标模块的所有后代FSDP模块的
state_dict_type。还可以为模型的和优化器的状态字典提供(可选)配置。目标模块不一定是FSDP模块。如果目标模块是FSDP模块,其
state_dict_type也将被更改。注意
此API应仅针对顶级(根)模块调用。
注意
此API使用户能够透明地使用传统的
state_dictAPI 在根 FSDP 模块被另一个nn.Module包装的情况下保存模型检查点。例如, 以下代码将确保在所有非 FSDP 实例上调用state_dict,同时分发到 sharded_state_dict 实现 对于 FSDP:示例:
>>> model = DDP(FSDP(...)) >>> FSDP.set_state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), >>> ) >>> param_state_dict = model.state_dict() >>> optim_state_dict = FSDP.optim_state_dict(model, optim)
- Parameters
模块 (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 要设置的
state_dict_type。state_dict_config (可选[StateDictConfig]) – 目标
state_dict_type的配置。optim_state_dict_config (可选[OptimStateDictConfig]) – 优化器状态字典的配置。
- Returns
一个包含模块先前状态字典类型和配置的StateDictSettings。
- Return type
- static shard_full_optim_state_dict(full_optim_state_dict, model, optim_input=None, optim=None)[源代码]¶
分片一个完整的优化器状态字典。
将
full_optim_state_dict中的状态重新映射为扁平化的参数,而不是未扁平化的参数,并且仅限于此 rank 的优化器状态部分。 第一个参数应该是full_optim_state_dict()的返回值。示例:
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> model, optim = ... >>> full_osd = FSDP.full_optim_state_dict(model, optim) >>> torch.save(full_osd, PATH) >>> # 定义新模型,可能具有不同的世界大小 >>> new_model, new_optim = ... >>> full_osd = torch.load(PATH) >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) >>> new_optim.load_state_dict(sharded_osd)
注意
可以使用
shard_full_optim_state_dict()和scatter_full_optim_state_dict()来获取分片优化器状态字典。假设完整的优化器状态字典位于CPU内存中,前者要求每个rank在CPU内存中拥有完整的字典,每个rank单独对字典进行分片而无需任何通信,而后者仅要求rank 0在CPU内存中拥有完整的字典,rank 0将每个分片移动到GPU内存(用于NCCL)并适当地将其通信给各个rank。因此,前者具有更高的聚合CPU内存成本,而后者具有更高的通信成本。- Parameters
full_optim_state_dict (Dict[str, Any]) – 对应于未展平参数的优化器状态字典,并持有完整的非分片优化器状态。
模型 (torch.nn.Module) – 根模块(可能是也可能不是
FullyShardedDataParallel实例),其参数对应于full_optim_state_dict中的优化器状态。optim_input (可选[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]) – 传递给优化器的输入,表示参数组列表或参数的可迭代对象; 如果为
None,则此方法假设输入为model.parameters()。此参数已弃用,不再需要传递它。(默认值:None)optim (可选[torch.optim.Optimizer]) – 将加载此方法返回的状态字典的优化器。这是比
optim_input更推荐的参数。(默认值:None)
- Returns
完整的优化器状态字典现在重新映射为 扁平化的参数,而不是未扁平化的参数,并且 仅限于包含此rank的优化器状态部分。
- Return type
Dict[str, Any]
- static sharded_optim_state_dict(model, optim, group=None)[源代码]¶
返回优化器状态字典的分片形式。
该API类似于
full_optim_state_dict(),但此API将所有非零维度的状态分块为ShardedTensor以节省内存。 此API仅应在模型state_dict通过上下文管理器with state_dict_type(SHARDED_STATE_DICT):导出时使用。有关详细用法,请参阅
full_optim_state_dict()。警告
返回的状态字典包含
ShardedTensor, 不能直接被常规的optim.load_state_dict使用。
- static state_dict_type(module, state_dict_type, state_dict_config=None, optim_state_dict_config=None)[源代码]¶
设置目标模块的所有后代FSDP模块的
state_dict_type。此上下文管理器具有与
set_state_dict_type()相同的功能。请阅读set_state_dict_type()的文档以获取详细信息。示例:
>>> model = DDP(FSDP(...)) >>> with FSDP.state_dict_type( >>> model, >>> StateDictType.SHARDED_STATE_DICT, >>> ): >>> checkpoint = model.state_dict()
- Parameters
模块 (torch.nn.Module) – 根模块。
state_dict_type (StateDictType) – 要设置的
state_dict_type。state_dict_config (可选[StateDictConfig]) – 目标
state_dict_type的模型state_dict配置。optim_state_dict_config (可选[OptimStateDictConfig]) – 目标
state_dict_type的优化器state_dict配置。
- Return type
- static summon_full_params(module, recurse=True, writeback=True, rank0_only=False, offload_to_cpu=False, with_grads=False)[源代码]¶
使用此上下文管理器公开FSDP实例的完整参数。
在模型进行前向/后向传播后,获取参数以进行额外处理或检查可能很有用。它可以接受一个非FSDP模块,并根据
recurse参数,召唤所有包含的FSDP模块及其子模块的完整参数。注意
这可以用于内部FSDPs。
注意
这不能在正向或反向传播中使用。也不能在此上下文中启动正向和反向传播。
注意
参数将在上下文管理器退出后恢复到它们的本地分片,存储行为与前向传播相同。
注意
可以修改完整的参数,但只有与本地参数分片对应的部分在上下文管理器退出后会持久化(除非
writeback=False,在这种情况下更改将被丢弃)。在FSDP不进行参数分片的情况下,目前仅当world_size == 1,或NO_SHARD配置时,修改将持久化,无论writeback如何。注意
此方法适用于本身不是FSDP但可能包含多个独立FSDP单元的模块。在这种情况下,给定的参数将应用于所有包含的FSDP单元。
警告
请注意,
rank0_only=True与writeback=True目前不支持同时使用,并且会引发错误。这是因为模型参数的形状在上下文中会因不同等级而异,并且在退出上下文时写入它们可能会导致不同等级之间的不一致。警告
请注意,
offload_to_cpu和rank0_only=False会导致完整参数被冗余地复制到同一台机器上的 GPU 的 CPU 内存中,这可能会导致 CPU 内存不足的风险。建议使用offload_to_cpu并设置rank0_only=True。- Parameters
递归 (bool, 可选) – 递归地为嵌套的FSDP实例召唤所有参数(默认值:True)。
回写 (布尔值, 可选) – 如果
False,在上下文管理器退出后,对参数的修改将被丢弃; 禁用此功能可能会稍微提高效率(默认值:True)rank0_only (bool, 可选) – 如果
True,完整参数仅在全局排名0上具体化。这意味着在上下文中,只有排名0将拥有完整参数,而其他排名将拥有分片参数。请注意,在writeback=True的情况下设置rank0_only=True是不支持的,因为模型参数形状在上下文中将因排名不同而不同,并且在退出上下文时写入它们可能导致排名间的不一致。offload_to_cpu (bool, Optional) – 如果
True,完整参数将被卸载到CPU。请注意,此卸载目前仅在参数被分片时发生(仅在world_size = 1或NO_SHARD配置时不会发生)。建议将offload_to_cpu与rank0_only=True一起使用,以避免将模型参数的冗余副本卸载到相同的CPU内存中。with_grads (bool, 可选) – 如果
True,梯度也会与参数一起取消分片。目前,这仅在传递use_orig_params=True到 FSDP 构造函数和offload_to_cpu=False到此方法时才支持。 (默认值:False)
- Return type
- class torch.distributed.fsdp.BackwardPrefetch(value)[源代码]¶
这配置了显式的反向预取,通过在反向传播过程中允许通信和计算重叠,从而提高了吞吐量,但代价是稍微增加了内存使用。
BACKWARD_PRE: 这使得重叠最多,但内存使用量也最大。这会在当前参数集的梯度计算之前预取下一组参数。这重叠了下一个all-gather和当前梯度计算,并且在峰值时,它在内存中保存当前参数集、下一组参数和当前梯度集。BACKWARD_POST: 这减少了重叠但需要较少的内存使用。这会在当前参数集的梯度计算之后预取下一组参数。这重叠了当前的reduce-scatter和下一轮梯度计算,并在为下一组参数分配内存之前释放当前参数集,仅在内存峰值时保留下一组参数和当前梯度集。FSDP 的
backward_prefetch参数接受None,这将完全禁用反向预取。这不会重叠,也不会增加内存使用。通常情况下,我们不推荐此设置,因为它可能会显著降低吞吐量。
更多技术背景:对于使用NCCL后端的单个进程组,任何集体操作,即使是从不同的流发出的,也会争夺相同的每个设备的NCCL流,这意味着集体操作发出的相对顺序对于重叠操作很重要。两个反向预取值对应于不同的发出顺序。
- class torch.distributed.fsdp.ShardingStrategy(value)[源代码]¶
这指定了用于分布式训练的分片策略,由
FullyShardedDataParallel使用。FULL_SHARD: 参数、梯度和优化器状态被分片。 对于参数,此策略在正向传播前解片(通过全收集),在正向传播后重新分片,在反向传播计算前解片,在反向传播计算后重新分片。对于梯度,它在反向传播计算后同步并分片(通过减少分散)。分片的优化器状态在每个等级本地更新。SHARD_GRAD_OP: 在计算过程中,梯度和优化器状态被分片,此外,参数在计算外部也被分片。对于参数,此策略在正向传播之前取消分片,在正向传播之后不重新分片,仅在反向传播计算之后重新分片。分片的优化器状态在每个rank本地更新。在no_sync()内部,参数在反向传播计算之后不会重新分片。NO_SHARD: 参数、梯度和优化器状态不会被分片,而是像PyTorch的DistributedDataParallelAPI一样在各个rank之间复制。对于梯度,此策略在反向计算后通过all-reduce同步它们。未分片的优化器状态在每个rank上本地更新。HYBRID_SHARD: 在节点内应用FULL_SHARD,并在节点间复制参数。这减少了通信量,因为昂贵的全收集和减少分散操作仅在节点内进行,这对于中等大小的模型可以提高性能。_HYBRID_SHARD_ZERO2: 在节点内应用SHARD_GRAD_OP,并在节点间复制参数。这类似于HYBRID_SHARD,但这种方法可能提供更高的吞吐量,因为未分片的参数在正向传递后不会被释放,从而节省了预反向中的所有收集操作。
- class torch.distributed.fsdp.MixedPrecision(param_dtype=None, reduce_dtype=None, buffer_dtype=None, keep_low_precision_grads=False, cast_forward_inputs=False, cast_root_forward_inputs=True, _module_classes_to_ignore=(<class 'torch.nn.modules.batchnorm._BatchNorm'>, ))[源代码]¶
这配置了FSDP-原生的混合精度训练。
- Variables
param_dtype(可选[torch.dtype])– 这指定了模型参数在前向和反向传播期间的dtype,从而指定了前向和反向计算的dtype。在前向和反向传播之外,分片的参数保持在全精度(例如,用于优化器步骤),并且对于模型检查点,参数总是以全精度保存。(默认值:
None)reduce_dtype (可选[torch.dtype]) – 这指定了梯度缩减(即reduce-scatter或all-reduce)的数据类型。如果这是
None但param_dtype不是None,则这采用param_dtype的值,仍然以低精度运行梯度缩减。这允许与param_dtype不同,例如强制梯度缩减以全精度运行。(默认值:None)buffer_dtype (可选[torch.dtype]) – 这指定了缓冲区的数据类型。FSDP 不会对缓冲区进行分片。相反,FSDP 在第一次前向传递中将它们转换为
buffer_dtype,并在之后保持该数据类型。对于模型检查点,缓冲区以全精度保存,除了LOCAL_STATE_DICT。(默认值:None)keep_low_precision_grads (bool) – 如果
False,则在反向传播后,FSDP会将梯度提升为全精度,以准备优化器步骤。如果True,则FSDP会保持梯度在使用梯度缩减时使用的数据类型,如果使用支持低精度运行的自定义优化器,这可以节省内存。(默认值:False)cast_forward_inputs (bool) – 如果
True,则此FSDP模块将其前向参数和关键字参数转换为param_dtype。这是为了确保参数和输入数据类型在前向计算中匹配,因为许多操作都需要这一点。当仅对部分但不是所有FSDP模块应用混合精度时,可能需要将其设置为True,在这种情况下,混合精度的FSDP子模块需要重新转换其输入。(默认值:False)cast_root_forward_inputs (bool) – 如果
True,则根FSDP模块将其前向args和kwargs转换为param_dtype,覆盖cast_forward_inputs的值。对于非根FSDP模块,这不会执行任何操作。(默认值:True)_module_classes_to_ignore (Sequence[Type[torch.nn.modules.module.Module]]) – (Sequence[Type[nn.Module]]): 这指定了在使用
auto_wrap_policy时忽略混合精度的模块类:这些类的模块将单独应用FSDP,且混合精度被禁用(这意味着最终的FSDP构造将偏离指定的策略)。如果未指定auto_wrap_policy,则此设置无效。此API是实验性的,可能会发生变化。(默认值:(_BatchNorm,))
注意
此API是实验性的,可能会发生变化。
注意
仅浮点张量会被转换为其指定的数据类型。
注意
在
summon_full_params中,参数被强制为全精度,但缓冲区不会。注意
即使输入是低精度(如
float16或bfloat16),层归一化和批归一化仍然以float32进行累积。禁用 FSDP 对这些归一化模块的混合精度仅意味着仿射参数保持在float32。然而,这会导致这些归一化模块的单独全收集和减少分散操作,这可能效率不高,因此如果工作负载允许,用户应优先考虑对这些模块仍然应用混合精度。注意
默认情况下,如果用户传递的模型包含任何
_BatchNorm模块并指定了auto_wrap_policy,那么批归一化模块将单独应用 FSDP,并且禁用混合精度。请参阅_module_classes_to_ignore参数。注意
MixedPrecision默认情况下cast_root_forward_inputs=True并且cast_forward_inputs=False。对于根 FSDP 实例, 其cast_root_forward_inputs优先于其cast_forward_inputs。对于非根 FSDP 实例,它们的cast_root_forward_inputs值被忽略。默认设置对于典型情况已经足够, 即每个 FSDP 实例具有相同的MixedPrecision配置,并且只需要在模型前向传递开始时将输入转换为param_dtype。注意
对于具有不同
MixedPrecision配置的嵌套FSDP实例,我们建议为每个实例设置单独的cast_forward_inputs值,以配置是否在每个实例的前向传播之前进行输入转换。在这种情况下,由于转换发生在每个FSDP实例的前向传播之前,父FSDP实例应在其FSDP子模块之前运行其非FSDP子模块,以避免由于不同的MixedPrecision配置而导致激活数据类型发生变化。示例:
>>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) >>> model[1] = FSDP( >>> model[1], >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), >>> ) >>> model = FSDP( >>> model, >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), >>> )
上述展示了一个工作示例。另一方面,如果将
model[1]替换为model[0],意味着使用不同MixedPrecision的子模块首先运行其前向传播,那么model[1]将会错误地看到float16激活值,而不是bfloat16激活值。
- class torch.distributed.fsdp.CPUOffload(offload_params=False)[源代码]¶
这配置了CPU卸载。
- Variables
offload_params (bool) – 这指定了是否在参数不参与计算时将其卸载到CPU。如果
True,则这也会将梯度卸载到CPU,意味着优化器步骤在CPU上运行。
- class torch.distributed.fsdp.StateDictConfig(offload_to_cpu=False)[源代码]¶
StateDictConfig是所有state_dict配置类的基类。用户应实例化一个子类(例如FullStateDictConfig)以配置 FSDP 支持的相应state_dict类型的设置。- Variables
offload_to_cpu (bool) – 如果
True,则FSDP将状态字典值卸载到CPU,如果False,则FSDP将它们保留在GPU上。 (默认值:False)
- class torch.distributed.fsdp.FullStateDictConfig(offload_to_cpu=False, rank0_only=False)[源代码]¶
FullStateDictConfig是一个配置类,旨在与StateDictType.FULL_STATE_DICT一起使用。我们建议在保存完整状态字典时启用offload_to_cpu=True和rank0_only=True,以分别节省 GPU 内存和 CPU 内存。此配置类 旨在通过state_dict_type()上下文管理器使用,如下所示:>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP >>> fsdp = FSDP(model, auto_wrap_policy=...) >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): >>> state = fsdp.state_dict() >>> # `state` 在非 rank 0 上将为空,而在 rank 0 上将包含 CPU 张量。 >>> # 重新加载检查点以进行推理、微调、迁移学习等: >>> model = model_fn() # 初始化模型以准备用 FSDP 包装 >>> if dist.get_rank() == 0: >>> # 仅在 rank 0 上加载检查点以避免内存冗余 >>> state_dict = torch.load("my_checkpoint.pt") >>> model.load_state_dict(state_dict) >>> # 所有 rank 像往常一样初始化 FSDP 模块。`sync_module_states` 参数 >>> # 将加载的检查点状态从 rank 0 传递给其他 rank。 >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True) >>> # 在此之后,所有 rank 都具有加载了检查点的 FSDP 模型。
- Variables
rank0_only (bool) – 如果
True,则只有rank 0保存完整的state dict,非零rank保存一个空的dict。如果False,则所有rank都保存完整的state dict。(默认值:False)
- class torch.distributed.fsdp.ShardedStateDictConfig(offload_to_cpu=False, _use_dtensor=False)[源代码]¶
ShardedStateDictConfig是一个配置类,用于与StateDictType.SHARDED_STATE_DICT一起使用。- Variables
_use_dtensor (布尔值) – 如果
True,则 FSDP 将状态字典值保存为DTensor,如果False,则 FSDP 将它们保存为ShardedTensor。(默认值:False)
警告
_use_dtensor是ShardedStateDictConfig的一个私有字段,它被 FSDP 用来确定状态字典值的类型。用户不应手动修改_use_dtensor。
- class torch.distributed.fsdp.OptimStateDictConfig(offload_to_cpu=True)[源代码]¶
OptimStateDictConfig是所有optim_state_dict配置类的基类。 用户应实例化一个子类(例如FullOptimStateDictConfig)以配置 FSDP 支持的相应optim_state_dict类型的设置。- Variables
offload_to_cpu (bool) – 如果
True,则FSDP将状态字典的张量值卸载到CPU,如果False,则FSDP将它们保留在原始设备上(除非启用了参数CPU卸载,否则为GPU)。(默认值:True)
- class torch.distributed.fsdp.FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)[源代码]¶
- Variables
rank0_only (bool) – 如果
True,则只有rank 0保存完整的state dict,非零rank保存一个空的dict。如果False,则所有rank都保存完整的state dict。(默认值:False)
- class torch.distributed.fsdp.ShardedOptimStateDictConfig(offload_to_cpu=True, _use_dtensor=False)[源代码]¶
ShardedOptimStateDictConfig是一个配置类,旨在与StateDictType.SHARDED_STATE_DICT一起使用。- Variables
_use_dtensor (布尔值) – 如果
True,则FSDP将状态字典值保存为DTensor,如果False,则FSDP将它们保存为ShardedTensor。(默认值:False)
警告
_use_dtensor是ShardedOptimStateDictConfig的一个私有字段,它被 FSDP 用来确定状态字典值的类型。用户不应手动修改_use_dtensor。
- class torch.distributed.fsdp.StateDictSettings(state_dict_type: torch.distributed.fsdp.api.StateDictType, state_dict_config: torch.distributed.fsdp.api.StateDictConfig, optim_state_dict_config: torch.distributed.fsdp.api.OptimStateDictConfig)[源代码]¶