Shortcuts

第一个强化学习程序

强化学习是一种在众多机器学习算法中用于创建决策智能人工代理的有前途的算法。 CartPole 是介绍强化学习的理想学习环境,使用 DQN 算法可以使 CartPole 在非常短的时间内收敛(保持平衡)。我们将介绍基于 CartPole + DQN 的 DI-engine 的使用。

../_images/cartpole_cmp.gif

使用配置文件

DI-engine 使用一个全局配置文件来控制环境和策略的所有变量,每个变量都有相应的默认配置,可以在 cartpole_dqn_config 中找到,在本教程中我们直接使用默认配置:

from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config
from ding.config import compile_config

cfg = compile_config(main_config, create_cfg=create_config, auto=True)

初始化环境

在强化学习中,训练过程和评估过程在收集环境数据的策略上可能存在差异,例如,训练过程倾向于为n步收集训练一个周期,而评估过程则需要完成整个游戏以获得分数。我们建议如下分别初始化收集和评估环境。

from ding.envs import DingEnvWrapper, BaseEnvManagerV2

collector_env = BaseEnvManagerV2(
    env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
    cfg=cfg.env.manager
)
evaluator_env = BaseEnvManagerV2(
    env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
    cfg=cfg.env.manager
)

注意

DingEnvWrapper 是 DI-engine 对不同环境库的统一封装。BaseEnvManagerV2 是管理多个环境的统一外部接口。因此,你可以使用 BaseEnvManagerV2 来并行收集多个环境。

选择策略

DI-engine 涵盖了大多数强化学习策略,使用它们只需要选择正确的策略和模型。 由于 DQN 是离策略的,我们还需要实例化一个缓冲模块。

from ding.model import DQN
from ding.policy import DQNPolicy
from ding.data import DequeBuffer

model = DQN(**cfg.policy.model)
buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
policy = DQNPolicy(cfg.policy, model=model)

构建管道

通过DI-engine提供的各种中间件,我们可以轻松构建整个管道:

from ding.framework import task
from ding.framework.context import OnlineRLContext
from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, eps_greedy_handler, CkptSaver

with task.start(async_mode=False, ctx=OnlineRLContext()):
    # Evaluating, we place it on the first place to get the score of the random model as a benchmark value
    task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
    task.use(eps_greedy_handler(cfg))  # Decay probability of explore-exploit
    task.use(StepCollector(cfg, policy.collect_mode, collector_env))  # Collect environmental data
    task.use(data_pusher(cfg, buffer_))  # Push data to buffer
    task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))  # Train the model
    task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))  # Save the model
    # In the evaluation process, if the model is found to have exceeded the convergence value, it will end early here
    task.run()

运行代码

完整的示例可以在DQN示例中找到,并且可以通过python dqn.py运行。 此外,我们还提供了从DI-engine安装到训练的Colab运行示例供参考。

../_images/train_dqn.gif

现在你已经完成了使用DI-engine的第一个强化学习任务,你可以尝试在示例目录中查看更多算法,或者继续阅读文档以更深入地了解DI-engine的算法系统设计最佳实践