Shortcuts

加载预训练模型并恢复训练

在使用DI-engine进行强化学习实验时,加载预训练的ckpt文件以从检查点恢复训练是一个常见需求。本文详细解释了如何使用DI-engine加载预训练模型并无缝恢复训练,以cartpole_ppo_config.py为例。

加载预训练模型

配置 load_ckpt_before_run

要加载预训练模型,首先需要在配置文件中指定预训练的ckpt文件的路径。此路径通过load_ckpt_before_run字段进行配置。

示例代码:

from easydict import EasyDict

cartpole_ppo_config = dict(
    exp_name='cartpole_ppo_seed0',
    env=dict(
        collector_env_num=8,
        evaluator_env_num=5,
        n_evaluator_episode=5,
        stop_value=195,
    ),
    policy=dict(
        cuda=False,
        action_space='discrete',
        model=dict(
            obs_shape=4,
            action_shape=2,
            action_space='discrete',
            encoder_hidden_size_list=[64, 64, 128],
            critic_head_hidden_size=128,
            actor_head_hidden_size=128,
        ),
        learn=dict(
            epoch_per_collect=2,
            batch_size=64,
            learning_rate=0.001,
            value_weight=0.5,
            entropy_weight=0.01,
            clip_ratio=0.2,
            # ======== Path to the pretrained checkpoint (ckpt) ========
            learner=dict(hook=dict(load_ckpt_before_run='/path/to/your/ckpt/iteration_100.pth.tar')),
            resume_training=False,
        ),
        collect=dict(
            n_sample=256,
            unroll_len=1,
            discount_factor=0.9,
            gae_lambda=0.95,
        ),
        eval=dict(evaluator=dict(eval_freq=100, ), ),
    ),
)
cartpole_ppo_config = EasyDict(cartpole_ppo_config)
main_config = cartpole_ppo_config

在上述示例中,load_ckpt_before_run 明确指定了预训练模型的路径 /path/to/your/ckpt/iteration_100.pth.tar。当你运行这段代码时,DI-engine 会自动从该路径加载模型权重,并从中继续训练。

模型加载过程

模型加载过程主要发生在entry路径下的主文件中。下面,我们以serial_entry_onpolicy.py文件为例来解释这个过程。

加载预训练模型的关键操作是通过DI-engine的hook机制实现的:

# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.learn.get('resume_training', False):
    collector.envstep = learner.collector_envstep

load_ckpt_before_run 不为空时,DI-engine 会自动调用 learnerbefore_run 钩子函数,从指定路径加载预训练模型。具体实现可以在 DI-engine 的 learner_hook.py 中找到。

策略本身的检查点保存和加载功能通过_load_state_dict_learn_state_dict_learn方法实现。例如,在PPO策略中,可以在以下位置找到实现:

从检查点恢复训练

恢复时管理日志和TensorBoard路径

默认情况下,DI-engine 会为每个实验创建一个新的日志路径,以避免覆盖之前的训练数据和 TensorBoard 日志。但是,如果你想在恢复训练时将日志和 TensorBoard 数据保存在同一目录中,可以通过在配置文件中设置 resume_training=True 来配置此功能(其默认值为 False)。

示例代码:

learn=dict(
    ...  # Other parts of the code
    learner=dict(hook=dict(load_ckpt_before_run='/path/to/your/ckpt/iteration_100.pth.tar')),
    resume_training=True,
)

resume_training=True 时,DI-engine 会将新的日志和 TensorBoard 数据保存在原始路径中。

关键代码:

# Note that the default value of renew_dir is True. When resume_training=True, renew_dir is set to False to ensure the consistency of log paths.
cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=not cfg.policy.learn.get('resume_training', False))

同时,从加载的ckpt文件中恢复的train_itercollector.envstep将被恢复,使得训练可以从上一个检查点无缝继续。

恢复迭代/步骤计数当恢复时

从检查点恢复训练时,训练itersteps将从检查点中保存的最后一次迭代和步骤计数中恢复。这确保了训练过程从正确的点继续,保持了训练进度的完整性。

首次训练(预训练)结果:

以下图表显示了第一次训练(预训练)的evaluator结果,x轴分别为itersteps

../_images/cartpole_ppo_evaluator_iter_pretrain.png ../_images/cartpole_ppo_evaluator_step_pretrain.png

第二次训练(恢复)结果:

以下图表显示了第二次训练(恢复)的evaluator结果,分别在x轴上显示了itersteps

../_images/cartpole_ppo_evaluator_iter_resume.png ../_images/cartpole_ppo_evaluator_step_resume.png

这些图表清楚地表明,训练在恢复后从停止的地方继续,评估指标在相同的迭代/步骤中显示出一致性。

结论

在使用DI-engine进行强化学习实验时,加载预训练模型并从检查点恢复训练对于确保稳定、长期的训练至关重要。从本文提供的示例和解释中,我们可以观察到以下内容:

  1. 加载预训练模型通过load_ckpt_before_run字段进行配置,并在训练前通过hook机制自动加载。

  2. 恢复训练可以通过设置resume_training=True来实现,确保无缝的日志管理和训练进度继续。

  3. 在实际实验中,正确管理日志路径和检查点数据可以防止冗余训练和数据丢失,提高实验的效率和可重复性。

我们希望本文能为使用DI-engine的实验提供清晰的指导。