训练API

deepspeed.initialize() 在其第一个参数中返回一个类型为 DeepSpeedEngine训练引擎。此引擎用于推进训练:

for step, batch in enumerate(data_loader):
    #forward() method
    loss = model_engine(batch)

    #runs backpropagation
    model_engine.backward(loss)

    #weight update
    model_engine.step()

前向传播

反向传播

优化器步骤

梯度累积

模型保存

此外,当创建DeepSpeed检查点时,会添加一个脚本zero_to_fp32.py,该脚本可用于将fp32主权重重建为单个pytorch state_dict文件。