Buffer 用户指南¶
Buffer 入门指南¶
缓冲区的基本概念
在离策略强化学习算法中,我们通常使用经验回放来提高样本效率并减少来自不同时间帧样本之间的相关性。 DI-engine 使用 DequeBuffer 来实现经验回放池的常见功能,如数据输入、采样等。用户可以通过以下代码创建 DequeBuffer 对象:
from ding.data import DequeBuffer
buffer = DequeBuffer(size=10)
在DI-engine中,我们使用dataclass作为缓冲区中数据以及其他一些组件的结构载体。 Dataclass是Python3的一个特性,它通过在类中指定字段(类的属性)的数据类型,使数据整洁且一致,这与Dict不同。 与Namedtuple相比,它可以用来设置默认值,以便在初始化阶段启用参数默认值,或者在使用过程中启用灵活的赋值操作。 接下来,我们将为用户介绍缓冲区的具体操作。
# Data is deposited and processed one sample at a time.
# In the middleware of DI-engine, the cache data type is usually a Dict, which records the obs, next_obs, actions, rewards, etc. of the samples.
for _ in range(10):
# The BufferedData object contains three fields: data, index and meta.
# "Data" is the data to be cached and "Meta" is its meta information (optional, defaults to None), both of which are passed into the buffer via the push method.
# "Index" indicates the index of the logical storage address of the data in the buffer, which is automatically generated by the buffer and does not need to be set manually by the user.
buffer.push('a', meta={})
# Data sampling processes multiple samples at a time, and the user needs to explicitly specify the number of samples. The parameter "replace" indicates whether to put back when sampling, and the default value is False.
# The sampling operation returns a data class object named BufferedData, e.g. BufferedData(data='a', index='67bdfadcd', meta={})
buffered_data = buffer.sample(3, replace=False)
data = [d.data for d in buffered_data]
使用Buffer完成在线培训
在前一小节中,我们介绍了数据在缓冲区中的实际存储结构,以及最基本的存储和采样操作。 事实上,在大多数任务中,用户不需要使用这些底层的原子操作。我们建议用户通过DI-engine封装的中间件调用缓冲区对象来完成训练。
from ding.framework import task
from ding.framework.middleware import data_pusher, OffPolicyLearner
task.use(data_pusher(cfg, buffer))
task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer))
使用Buffer加载专家数据
在模仿学习任务中,如SQIL和DQFD,我们需要在训练前加载一些专家经验。实际上,用户可以使用另一个缓冲区来保存专家数据,以SQIL为例(完整代码可以在./ding/example/sqil.py找到):
from ding.framework.middleware import sqil_data_pusher
buffer = DequeBuffer(size=10)
expert_buffer = DequeBuffer(size=10)
task.use(sqil_data_pusher(cfg, buffer_=buffer, expert=False))
task.use(sqil_data_pusher(cfg, buffer_=expert_buffer, expert=True))
task.use(OffPolicyLearner(cfg, policy.learn_mode, [(buffer, 0.5), (expert_buffer, 0.5)]))
缓冲区高级¶
在上一节中,我们提供了一个基本的缓冲区应用场景。接下来,我们将向您展示缓冲区的更全面功能。
优先经验回放(PER)
在某些算法中,需要优先经验回放。在DI-engine中,您可以使用PriorityExperienceReplay中间件来启用缓冲区优先经验回放功能。 如果用户在放入样本时启用了该功能,他们必须明确传递关于每个样本优先级的元信息,如下所示。优先采样增加了样本的经过时间。
from ding.data.buffer.middleware import PriorityExperienceReplay
buffer = DequeBuffer(size=10)
buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
for _ in range(10):
# The meta is essentially a Dict that complements the description of the sample and is empty by default.
buffer.push('a', meta={"priority": 2.0})
buffered_data = buffer.sample(3)
样本克隆
默认情况下,对于存储在缓冲区中的可变对象(如列表、np.array、torch.tensor等),采样操作实际上返回的是对该对象的引用。 如果用户随后对引用的内容进行更改,可能会导致样本池中的相应内容也发生变化。 在某些应用场景中,用户可能希望样本池中的数据保持不变,这可以通过使用clone_object中间件在采样时返回缓冲区中对象的副本来实现。 这样,对副本内容的修改不会影响缓冲区中的原始数据。样本克隆会显著增加采样的耗时。
from ding.data.buffer.middleware import clone_object
buffer = DequeBuffer(size=10)
buffer.use(clone_object())
分组采样
在某些特定环境或算法中,用户可能希望按整个回合收集、存储和处理样本。例如,在象棋、围棋或纸牌游戏中,玩家只在游戏结束时获得奖励,解决此类任务的算法通常希望处理整个游戏,而像Hindsight Experience Replay (HER)这样的算法需要采样完整的回合并按回合单位进行处理。在这种情况下,用户可以使用分组采样来实现这一目标。
通过原子操作自定义实现
上述需求可以通过一些原子操作来实现,以达到定制化和更高的灵活性。例如,在存储样本时,您可以在元数据中添加“episode”信息,以指定样本所属的剧集,而在采样时,您可以设置groupby="episode"以启用按剧集关键字进行分组采样。分组采样会显著增加采样时间。
buffer = DequeBuffer(size=10) # When storing data, the user needs to add grouping information to the meta, e.g., "episode" as the grouping keyword, and the corresponding value is the specific group buffer.push("a", {"episode": 1}) buffer.push("b", {"episode": 2}) buffer.push("c", {"episode": 2}) # Grouping according to the keyword "episode" requires that the number of different groups in the buffer is not less than the number of samples. grouped_data = buffer.sample(2, groupby="episode")
通过中间件实现
在DI-engine中,我们还通过data_pusher中间件提供了一个完整的组采样操作。以R2D2算法为例,样本的片段通过LSTM网络按顺序传递。 在数据收集中,每个环境实例对应一个唯一的决策轨迹,因此建议使用env_id作为区分不同片段的关键。 以下是使用组采样的R2D2代码,完整版本可以在./ding/example/r2d2.py找到:
buffer = DequeBuffer(size=10) # Here 'env' is used as the keyword for grouping, so that samples with the same env_id will be classified into the same group when sampling. task.use(data_pusher(cfg, buffer, group_by_env=True))
(可用选项) 在组采样之上,您还可以使用group_sample中间件来实现样本的后处理,例如:选择是否打乱同一组内的数据,以及设置每组数据的最大长度。
from ding.data.buffer.middleware import group_sample
buffer = DequeBuffer(size=10)
# The maximum length of each group of data is set to 3, keeping the original order within the group
buffer.use(group_sample(size_in_group=3, ordered_in_group=True))
删除多个使用样本
默认情况下,在Dequebuffer中,样本可能会被多个样本函数调用重复收集。如果不加以控制,训练性能将会下降,因为它会过多地拟合部分样本。 为了避免这个问题,我们可以使用use_time_check中间件来设置样本可以被使用的最大次数。
from ding.data.buffer.middleware import use_time_check
buffer = DequeBuffer(size=10)
# Set the maximum number of times a single sample can be used to 2
buffer.use(use_time_check(buffer, max_use=2))
中间件维护一个计数器来记录每个样本的选取次数,并将其写入元数据中的use_count字段。当样本被选取时,计数将增加1,直到超过设置的最大容忍次数并被删除。