Shortcuts

gather_cpu_state_dict

torchtune.training.gather_cpu_state_dict(sharded_sd: Dict[str, DTensor], is_rank_zero: bool, device: Optional[device] = None) Dict[str, Any][source]

将分片状态字典转换为CPU上的完整状态字典 仅在rank0上返回非空结果,以避免CPU内存峰值

Parameters:
  • sharded_sd (Dict[str, DTensor]) – DTensors的分片状态字典

  • is_rank_zero (bool) – 用于检查进程是否在rank 0上的标志

  • device (可选[torch.device]) – 用于分片张量的设备。默认值:无

Returns:

CPU上的状态字典

Return type:

字典[str, 任意]