训练API
deepspeed.initialize() 在其第一个参数中返回一个类型为 DeepSpeedEngine 的 训练引擎。此引擎用于推进训练:
for step, batch in enumerate(data_loader):
#forward() method
loss = model_engine(batch)
#runs backpropagation
model_engine.backward(loss)
#weight update
model_engine.step()
前向传播
反向传播
优化器步骤
梯度累积
模型保存
此外,当创建DeepSpeed检查点时,会添加一个脚本zero_to_fp32.py,该脚本可用于将fp32主权重重建为单个pytorch state_dict文件。