GameBuffer
- class lzero.mcts.buffer.game_buffer.GameBuffer(cfg: dict)[源代码]
基类:
ABC
,object
- 概述:
MuZeroPolicy、EfficientZeroPolicy、SampledEfficientZeroPolicy、GumbelMuZeroPolicy 的基础游戏缓冲类。
- _abc_impl = <_abc._abc_data object>
- abstract _compute_target_policy_non_reanalyzed(policy_non_re_context: List[Any], policy_shape: int | None) ndarray [源代码]
- 概述:
从政策的非重新分析上下文中准备政策目标
- 参数:
policy_non_re_context (-) – 包含以下内容的列表:- pos_in_game_segment_list - child_visits - game_segment_lens - action_mask_segment - to_play_segment
- 返回:
batch_target_policies_non_re
- abstract _compute_target_policy_reanalyzed(policy_re_context: List[Any], model: Any) ndarray [源代码]
- 概述:
从重新分析的政策背景中准备政策目标
- 参数:
policy_re_context (-) – 待重新分析的政策背景列表
- 返回:
batch_target_policies_re
- abstract _compute_target_reward_value(reward_value_context: List[Any], model: Any) List[ndarray] [源代码]
- 概述:
从奖励和价值的环境中准备奖励和价值目标。
- 参数:
( (- model) – obj:’list’): 奖励值上下文
( – obj:’torch.tensor’):目标模型的模型
- 返回:
obj:’np.ndarray): batch of value prefix - batch_target_values (:obj:’np.ndarray): batch of value estimation
- 返回类型:
batch_value_prefixs (
- abstract _make_batch(orig_data: Any, reanalyze_ratio: float) Tuple[Any] [源代码]
- 概述:
prepare the context of a batch reward_value_context: the context of reanalyzed value targets policy_re_context: the context of reanalyzed policy targets policy_non_re_context: the context of non-reanalyzed policy targets current_batch: the inputs of batch
- 参数:
orig_data – 来自重放缓冲区的任何批次上下文
reanalyze_ratio – 重新分析策略的浮动比率(值为100%重新分析)
- 返回:
reward_value_context, policy_re_context, policy_non_re_context, current_batch
- 返回类型:
context (
Tuple
)
- abstract _prepare_policy_non_reanalyzed_context(batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int]) List[Any] [源代码]
- 概述:
准备非重新分析部分中计算策略目标的策略上下文,只需返回自我对弈中的策略。
- 参数:
batch_index_list (-) – 重放缓冲区中采样小批次的起始过渡索引
game_segment_list (-) – 游戏段落列表
pos_in_game_segment_list (-) – 游戏中的列表过渡索引
- 返回:
state_index_lst, child_visits, game_segment_lens, action_mask_segment, to_play_segment
- 返回类型:
policy_non_re_context (
list
)
- abstract _prepare_policy_reanalyzed_context(batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str]) List[Any] [源代码]
- 概述:
准备重新分析部分中计算政策目标的政策背景。
- 参数:
( (- pos_in_game_segment_list) – obj:’list’): 在重放缓冲区中开始过渡索引
( – obj:’list’): 游戏段落列表
( – obj:’list’): 在一个游戏历史中过渡索引的位置
- 返回:
policy_obs_lst, policy_mask, state_index_lst, indices, child_visits, game_segment_lens, action_mask_segment, to_play_segment
- 返回类型:
policy_re_context (
list
)
- abstract _prepare_reward_value_context(batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], total_transitions: int) List[Any] [源代码]
- 概述:
为重新分析部分的TD值目标计算准备奖励和价值的环境。
- 参数:
batch_index_list (-) – 重放缓冲区中采样小批次的起始过渡索引
game_segment_list (-) – 游戏段落列表
pos_in_game_segment_list (-) – 游戏段落中的过渡索引列表
total_transitions (-) – 收集的过渡次数
- 返回:
value_obs_lst, value_mask, state_index_lst, rewards_lst, game_segment_lens, td_steps_lst, action_mask_segment, to_play_segment
- 返回类型:
reward_value_context (
list
)
- _preprocess_to_play_and_action_mask(game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, unroll_steps=None)[源代码]
- 概述:
- 为
value_obs_list
中的目标观察值准备 to_play 和 action_mask。 to_play: {list: game_segment_batch_size * (num_unroll_steps+1)}
action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)}
- 为
- _push_game_segment(data: Any, meta: dict | None = None) None [源代码]
- 概述:
将数据及其元信息推入缓冲区。保存游戏段。
- 参数:
data (-) – 将要推入缓冲区的数据(一个游戏片段)。
meta (-) – Meta information, e.g. priority, count, staleness. - done (
bool
): whether the game is finished. - unroll_plus_td_steps (int
): if the game is not finished, we only save the transitions that can be computed - priorities (list
): the priorities corresponding to the transitions in the game history
- 返回:
推送的数据。
- 返回类型:
buffered_data (
BufferedData
)
- _remove(excess_game_segment_index: List[int]) None [源代码]
- 概述:
删除索引中的游戏片段 [0: excess_game_segment_index]
- 参数:
excess_game_segment_index (-) – 数据索引
- _sample_orig_data(batch_size: int) Tuple [源代码]
- 概述:
- 包含的原始数据示例:
game_segment_list: 游戏片段列表 pos_in_game_segment_list: 游戏中的过渡索引(相对索引) batch_index_list: 重放缓冲区中采样的小批量过渡的起始索引 weights_list: 与优先级相关的权重 make_time: 批次生成的时间(用于在删除数据时正确更新重放缓冲区)
- 参数:
batch_size (-) – 批量大小
beta (-) – 在PER中浮动参数以计算优先级
- _sample_orig_data_episode(batch_size: int) Tuple [源代码]
- 概述:
- 训练批次样本的原始数据,包括:
game_segment_list: 游戏段落列表。
pos_in_game_segment_list: 游戏段落内的过渡索引。
batch_index_list: 在重放缓冲区中采样的小批量开始转换的索引。
weights_list: 每个采样转换的权重,用于优先级排序。
make_time: 时间戳,指示批处理创建的时间(用于管理重放缓冲区更新很有用)。
- 参数:
batch_size (-) – 为批次抽取的样本数量。
beta (-) – 用于优先经验回放 (PER) 的参数,调整样本的重要性。
- _sample_orig_reanalyze_batch(batch_size: int) Tuple [源代码]
- 概述:
此函数从重放缓冲区中采样一批游戏段以进行重新分析。它基于每个游戏段的 reanalyze_time 进行优先级采样,重新分析次数较多的段将获得较低的优先级。
该函数返回一个包含采样游戏段信息元组,包括它们在每个段中的位置以及批处理创建的时间。
- 参数:
batch_size (-) – 本次批次中要抽取的样本数量。
- 返回:
包含以下元素的元组:- game_segment_list:采样的游戏片段列表。- pos_in_game_segment_list:表示每个转换在其对应游戏片段中位置的索引列表。- batch_index_list:重放缓冲区中采样游戏片段的索引。- make_time:时间戳列表(在此实现中设置为 0),指示批次创建的时间。
- 返回类型:
元组
- 关键细节:
优先采样:游戏片段根据使用每个片段的 reanalyze_time 计算的概率分布进行采样。被重新分析更频繁的片段被选中的可能性较小。
片段切片:每个选定的游戏片段根据 num_unroll_steps 参数确定的固定间隔进行采样。每个选定的片段最多采样 samples_per_segment 个转换。
处理额外样本:如果 batch_size 不能被每个段的样本数整除,则会额外采样一些段来弥补差额。
重新分析时间更新:每个采样游戏段的 reanalyze_time 属性被增加,以反映它已被再次选中进行重新分析。
- 抛出:
- ValueError – 如果 game_segment_length 太小,无法容纳 num_unroll_steps。
- _sample_orig_reanalyze_data(batch_size: int) Tuple [源代码]
- 概述:
- 包含的原始数据示例:
game_segment_list: 游戏片段列表 pos_in_game_segment_list: 游戏中的过渡索引(相对索引) batch_index_list: 重放缓冲区中采样的小批量过渡的起始索引 weights_list: 与优先级相关的权重 make_time: 批次生成的时间(用于在删除数据时正确更新重放缓冲区)
- 参数:
batch_size (-) – 批量大小
beta (-) – 在PER中浮动参数以计算优先级
- config = {'mini_infer_size': 10240, 'reanalyze_outdated': True, 'reanalyze_ratio': 0, 'replay_buffer_size': 1000000, 'sample_type': 'transition', 'use_root_value': False}
- push_game_segments(data_and_meta: Any) None [源代码]
- 概述:
将 game_segments 数据及其元信息推入缓冲区。保存一个游戏段
- 参数:
data_and_meta (-) –
data (
Any
): The data (game segments) which will be pushed into buffer.meta (
dict
): Meta information, e.g. priority, count, staleness.
- abstract sample(batch_size: int, policy: MuZeroPolicy | EfficientZeroPolicy | SampledEfficientZeroPolicy | GumbelMuZeroPolicy) List[Any] [源代码]
- 概述:
从
GameBuffer
中获取样本数据,并为训练准备当前和目标批次。
- 参数:
batch_size (-) – 批量大小。
policy (-) – 政策。
- 返回:
列车数据的列表,包括 current_batch 和 target_batch。
- 返回类型:
train_data (
List
)