gather_cpu_state_dict¶
- torchtune.training.gather_cpu_state_dict(sharded_sd: Dict[str, DTensor], is_rank_zero: bool, device: Optional[device] = None) Dict[str, Any][source]¶
将分片状态字典转换为CPU上的完整状态字典 仅在rank0上返回非空结果,以避免CPU内存峰值
- Parameters:
sharded_sd (Dict[str, DTensor]) – DTensors的分片状态字典
is_rank_zero (bool) – 用于检查进程是否在rank 0上的标志
device (可选[torch.device]) – 用于分片张量的设备。默认值:无
- Returns:
CPU上的状态字典
- Return type:
字典[str, 任意]