训练脚本¶
如果你的训练脚本适用于 torch.distributed.launch
,它将继续适用于 torchrun
,但有以下区别:
无需手动传递
RANK
,WORLD_SIZE
,MASTER_ADDR
, 和MASTER_PORT
。rdzv_backend
和rdzv_endpoint
可以被提供。对于大多数用户来说,这将设置为c10d
(参见 rendezvous)。默认的rdzv_backend
创建一个非弹性的rendezvous,其中rdzv_endpoint
持有主地址。确保你的脚本中有
load_checkpoint(path)
和save_checkpoint(path)
逻辑。当任意数量的 工作节点失败时,我们会使用相同的程序参数重新启动所有工作节点,因此你将丢失到最近检查点为止的进度 (参见 弹性启动)。use_env
标志已被移除。如果你之前是通过解析--local-rank
选项来获取本地排名,你需要从环境变量LOCAL_RANK
中获取本地排名(例如int(os.environ["LOCAL_RANK"])
)。
下面是一个训练脚本的示例,该脚本在每个epoch结束时进行检查点保存,因此在失败时可能丢失的最坏情况下的进度是一个完整epoch的训练量。
def main():
args = parse_args(sys.argv[1:])
state = load_checkpoint(args.checkpoint_path)
initialize(state)
# torch.distributed.run 确保这将工作
# 通过导出所有初始化进程组所需的env vars
torch.distributed.init_process_group(backend=args.backend)
for i in range(state.epoch, state.total_num_epochs)
for batch in iter(state.dataset)
train(batch, state.model)
state.epoch += 1
save_checkpoint(state)
有关符合torchelastic标准的训练脚本的实际示例,请访问我们的示例页面。