如何在 LightZero 中自定义你的算法?
LightZero 是一个 MCTS+RL 强化学习框架,提供了一组高级 API,使用户能够在其中自定义他们的算法。以下是在 LightZero 中自定义算法的一些步骤和注意事项。
基本步骤
1. Understand the Framework Structure
在开始编写您的自定义算法之前,您需要对 LightZero 框架的结构有一个基本的了解。LightZero 流程图如下所示。
仓库的文件夹主要由两部分组成:lzero 和 zoo。lzero 文件夹包含了 LightZero 框架工作流程所需的核心模块。zoo 文件夹提供了一组预定义的环境(envs)及其相应的配置(config)文件。lzero 文件夹包括几个核心模块,包括策略、模型、工作者和入口。这些模块共同工作,以实现复杂的强化学习算法。
在这种架构中,策略模块负责实现算法的决策逻辑,例如在智能体-环境交互期间的动作选择以及如何根据收集的数据更新策略。模型模块负责实现算法所需的神经网络结构。
worker 模块由两个类组成:Collector 和 Evaluator。Collector 类的实例处理代理-环境交互,以收集训练所需的数据,而 Evaluator 类的实例评估当前策略的性能。
入口模块负责初始化环境、模型、策略等,其主循环实现了数据收集、模型训练和策略评估等核心过程。
这些模块之间存在紧密的交互。具体来说,入口模块调用工作模块的收集器和评估器来执行数据收集和算法评估。策略模块的决策函数由收集器和评估器调用,以确定代理在特定环境中的行为。模型模块中实现的神经网络模型嵌入到策略对象中,用于交互过程中的动作生成和训练过程中的更新。
在策略模块中,你可以找到各种算法的实现。例如,MuZero 策略在 muzero.py 文件中实现。
2. Create a New Policy File
在 lzero/policy 目录下创建一个新的 Python 文件。这个文件将包含你的算法实现。例如,如果你的算法名为 MyAlgorithm,你可以创建一个名为 my_algorithm.py 的文件。
3. Implement Your Policy
在你的策略文件中,你需要定义一个类来实现你的策略。这个类应该继承自DI-engine中的Policy类,并实现所需的方法。以下是一个策略类的基本框架:
@POLICY_REGISTRY.register('my_algorithm')
class MyAlgorithmPolicy(Policy):
"""
Overview:
The policy class for MyAlgorithm.
"""
config = dict(
# Add your config here
)
def __init__(self, cfg, **kwargs):
super().__init__(cfg, **kwargs)
# Initialize your policy here
def default_model(self) -> Tuple[str, List[str]]:
# Set the default model name and the import path so that the default model can be loaded during policy initialization
def _init_learn(self):
# Initialize the learn mode here
def _forward_learn(self, data):
# Implement the forward function for learning mode here
def _init_collect(self):
# Initialize the collect mode here
def _forward_collect(self, data, **kwargs):
# Implement the forward function for collect mode here
def _init_eval(self):
# Initialize the eval mode here
def _forward_eval(self, data, **kwargs):
# Implement the forward function for eval mode here
数据收集与模型评估
在 default_model 中,设置当前策略使用的默认模型的类名及其对应的引用路径。
函数 _init_collect 和 _init_eval 负责实例化动作选择策略,并且各自的策略实例将由 _forward_collect 和 _forward_eval 函数调用。
函数 _forward_collect 接受环境的当前状态,并通过调用 _init_collect 中实例化的策略来选择一个步骤动作。该函数返回所选的动作列表和其他相关信息。在训练期间,此函数通过由 Entry 文件创建的 Collector 对象的 collector.collect 方法调用。
toctree是一个 reStructuredText :dfn:指令,这是一个非常多功能的标记。指令可以有参数、选项和内容。
策略学习
函数 _init_learn 使用策略的相关参数(如学习率、更新频率、优化器类型,这些参数从配置文件中传入)初始化网络模型、优化器和其他训练过程中所需的对象。
toctree是一个 reStructuredText :dfn:指令,这是一个非常多功能的标记。指令可以有参数、选项和内容。
4. Register Your Policy
为了让 LightZero 识别您的策略,您需要在策略类上方使用 @POLICY_REGISTRY.register(’my_algorithm’) 装饰器来注册您的策略。这样,LightZero 可以通过名称 ‘my_algorithm’ 引用您的策略。具体来说,在实验的配置文件中,通过 create_config 部分指定相应的算法:
create_config = dict(
...
policy=dict(
type='my_algorithm',
import_names=['lzero.policy.my_algorithm'],
),
...
)
在这里,type 应设置为已注册的策略名称,而 import_names 应设置为策略包的位置。
5. Possible Other Modifications
模型:The LightZero model.common 包提供了一些常见的网络结构,例如将2D图像映射到潜在空间表示的RepresentationNetwork和在MCTS中用于预测概率和节点值的PredictionNetwork。如果自定义策略需要特定的网络模型,则需要在model文件夹下实现相应的模型。例如,MuZero算法的模型保存在muzero_model.py文件中,该文件实现了MuZero算法所需的DynamicsNetwork,并最终通过调用model.common包中的现有网络结构创建了MuZeroModel。
Worker: LightZero 为 AlphaZero 和 MuZero 提供了相应的 worker。后续的算法如 EfficientZero 和 GumbelMuzero 继承了 MuZero 的 worker。如果你的算法在数据收集方面有不同的逻辑,你需要实现相应的 worker。例如,如果你的算法需要对收集的转换进行预处理,可以在收集器的 collect 函数下添加这一段,其中 get_train_sample 函数实现了具体的数据处理过程。
if timestep.done:
# Prepare trajectory data.
transitions = to_tensor_transitions(self._traj_buffer[env_id])
# Use ``get_train_sample`` to process the data.
train_sample = self._policy.get_train_sample(transitions)
return_data.extend(train_sample)
self._traj_buffer[env_id].clear()
6. Test Your Policy
在实施您的策略后,确保其正确性和有效性至关重要。为此,您应该编写一些单元测试来验证您的策略是否正常工作。例如,您可以测试策略是否能在特定环境中执行,以及策略的输出是否与预期结果相符。您可以参考DI-engine中的文档以获取如何编写单元测试的指导。您可以将您的测试添加到lzero/policy/tests中。在编写测试时,尽量考虑所有可能的场景和边界条件,以确保您的策略在各种情况下都能正常运行。
以下是 LightZero 中单元测试的一个示例。在这个示例中,我们测试了 inverse_scalar_transform 和 InverseScalarTransform 方法。这两种方法都反转了给定值的变换,但它们的实现方式不同。在单元测试中,我们将这两种方法应用于相同的数据集,并比较输出结果。如果结果相同,则测试通过。
import pytest
import torch
from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform
@pytest.mark.unittest
def test_scaling_transform():
import time
logit = torch.randn(16, 601)
start = time.time()
output_1 = inverse_scalar_transform(logit, 300)
print('t1', time.time() - start)
handle = InverseScalarTransform(300)
start = time.time()
output_2 = handle(logit)
print('t2', time.time() - start)
assert output_1.shape == output_2.shape == (16, 1)
assert (output_1 == output_2).all()
在单元测试文件中,你需要用 @pytest.mark.unittest 标记测试,以将它们包含在 Python 测试框架中。这允许你通过在命令行中输入 pytest -sv xxx.py 直接运行单元测试文件。-sv 是一个命令选项,当使用时,在测试执行期间将详细信息打印到终端,以便于检查。
7. Comprehensive Testing and Running
在确保策略的基本功能后,你需要使用像cartpole这样的经典环境对你的策略进行全面的正确性和收敛性测试。这是为了验证你的策略不仅在单元测试中能有效工作,而且在实际游戏环境中也能有效工作。
您可以通过参考 cartpole_muzero_config.py 编写相关的配置文件和入口程序。在测试过程中,请注意记录策略的性能数据,例如每轮的得分、策略的收敛速度等,以便进行分析和改进。
8. Contribution
完成上述所有步骤后,如果您希望将您的策略贡献到 LightZero 仓库,您可以在官方仓库上提交一个 Pull Request。在提交之前,请确保您的代码符合仓库的编码标准,所有测试都已通过,并且有足够的文档和注释来解释您的代码和策略。
在PR的描述中,详细解释你的策略,包括其工作原理、你的实现方法以及它在测试中的表现。这将有助于他人理解你的贡献并加快PR审查过程。
考虑因素
确保您的代码符合 Python PEP8 编码标准。
在实现诸如 _forward_learn、_forward_collect 和 _forward_eval 等方法时,请确保正确处理输入和返回的数据。
在编写您的策略时,请确保您考虑了不同类型的环境。您的策略应能够处理各种环境。
在实现您的策略时,尽量使您的代码尽可能模块化,以便他人能够理解和重用您的代码。
编写清晰的文档和注释,描述您的策略如何工作以及您的代码如何实现此策略。努力保持内容的核心含义,同时提升其专业性和流畅性。