Shortcuts

训练脚本

如果你的训练脚本适用于 torch.distributed.launch,它将继续适用于 torchrun,但有以下区别:

  1. 无需手动传递 RANK, WORLD_SIZE, MASTER_ADDR, 和 MASTER_PORT

  2. rdzv_backendrdzv_endpoint 可以被提供。对于大多数用户来说,这将设置为 c10d(参见 rendezvous)。默认的 rdzv_backend 创建一个非弹性的rendezvous,其中 rdzv_endpoint 持有主地址。

  3. 确保你的脚本中有 load_checkpoint(path)save_checkpoint(path) 逻辑。当任意数量的 工作节点失败时,我们会使用相同的程序参数重新启动所有工作节点,因此你将丢失到最近检查点为止的进度 (参见 弹性启动)。

  4. use_env 标志已被移除。如果你之前是通过解析 --local-rank 选项来获取本地排名,你需要从环境变量 LOCAL_RANK 中获取本地排名(例如 int(os.environ["LOCAL_RANK"]))。

下面是一个训练脚本的示例,该脚本在每个epoch结束时进行检查点保存,因此在失败时可能丢失的最坏情况下的进度是一个完整epoch的训练量。

def main():
     args = parse_args(sys.argv[1:])
     state = load_checkpoint(args.checkpoint_path)
     initialize(state)

     # torch.distributed.run 确保这将工作
     # 通过导出所有初始化进程组所需的env vars
     torch.distributed.init_process_group(backend=args.backend)

     for i in range(state.epoch, state.total_num_epochs)
          for batch in iter(state.dataset)
              train(batch, state.model)

          state.epoch += 1
          save_checkpoint(state)

有关符合torchelastic标准的训练脚本的实际示例,请访问我们的示例页面。