MuZeroBuffer
- class lzero.mcts.buffer.game_buffer_muzero.MuZeroGameBuffer(cfg: dict)[源代码]
基类:
GameBuffer
- 概述:
MuZero 策略的特定游戏缓冲区。
- _abc_impl = <_abc._abc_data object>
- _compute_target_policy_non_reanalyzed(policy_non_re_context: List[Any], policy_shape: int | None) ndarray [源代码]
- 概述:
从政策的非重新分析上下文中准备政策目标
- 参数:
policy_non_re_context (-) – 包含以下内容的列表:- pos_in_game_segment_list - child_visits - game_segment_lens - action_mask_segment - to_play_segment
policy_shape (-) – self._cfg.model.action_space_size
- 返回:
batch_target_policies_non_re
- _compute_target_policy_reanalyzed(policy_re_context: List[Any], model: Any) ndarray [源代码]
- 概述:
从重新分析的政策背景中准备政策目标
- 参数:
policy_re_context (-) – 待重新分析的政策背景列表
- 返回:
batch_target_policies_re
- _compute_target_reward_value(reward_value_context: List[Any], model: Any) Tuple[Any, Any] [源代码]
- 概述:
从奖励和价值的环境中准备奖励和价值目标。
- 参数:
( (- model) – obj:’list’): 奖励值上下文
( – obj:’torch.tensor’):目标模型的模型
- 返回:
obj:’np.ndarray): batch of value prefix - batch_target_values (:obj:’np.ndarray): batch of value estimation
- 返回类型:
batch_value_prefixs (
- _make_batch(batch_size: int, reanalyze_ratio: float) Tuple[Any] [源代码]
- 概述:
首先通过
_sample_orig_data()
获取原始数据样本,然后准备一批数据的上下文:reward_value_context: 重新分析的价值目标的上下文 policy_re_context: 重新分析的策略目标的上下文 policy_non_re_context: 未重新分析的策略目标的上下文 current_batch: 批次的输入
- 参数:
batch_size (-) – 从回放缓冲区中获取的orig_data的批量大小。
reanalyze_ratio (-) – 重新分析政策的比例(值为100%重新分析)
- 返回:
reward_value_context, policy_re_context, policy_non_re_context, current_batch
- 返回类型:
context (
Tuple
)
- _make_batch_for_reanalyze(batch_size: int) Tuple[Any] [源代码]
- 概述:
首先通过
_sample_orig_data()
获取原始数据样本,然后准备一批数据的上下文:reward_value_context: 重新分析的价值目标的上下文 policy_re_context: 重新分析的策略目标的上下文 policy_non_re_context: 未重新分析的策略目标的上下文 current_batch: 批次的输入
- 参数:
batch_size (-) – 从回放缓冲区中获取的orig_data的批量大小。
- 返回:
reward_value_context, policy_re_context, policy_non_re_context, current_batch
- 返回类型:
context (
Tuple
)
- _prepare_policy_non_reanalyzed_context(batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int]) List[Any] [源代码]
- 概述:
准备非重新分析部分中计算策略目标的策略上下文,只需返回自我对弈中的策略。
- 参数:
batch_index_list (-) – 回放缓冲区中采样的小批次的开始过渡索引
game_segment_list (-) – 游戏段落列表
pos_in_game_segment_list (-) – 游戏中的列表过渡索引
- 返回:
pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment
- 返回类型:
policy_non_re_context (
list
)
- _prepare_policy_reanalyzed_context(batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str]) List[Any] [源代码]
- 概述:
准备重新分析部分中计算政策目标的政策背景。
- 参数:
( (- pos_in_game_segment_list) – obj:’list’): 在重放缓冲区中开始过渡索引
( – obj:’list’): 游戏段落列表
( – obj:’list’): 在一个游戏历史中过渡索引的位置
- 返回:
policy_obs_list, policy_mask, pos_in_game_segment_list, indices, child_visits, game_segment_lens, action_mask_segment, to_play_segment
- 返回类型:
policy_re_context (
list
)
- _prepare_reward_value_context(batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], total_transitions: int) List[Any] [源代码]
- 概述:
为重新分析部分的TD值目标计算准备奖励和价值的环境。
- 参数:
batch_index_list (-) – 回放缓冲区中采样的小批次的开始过渡索引
game_segment_list (-) – 游戏段落列表
pos_in_game_segment_list (-) – 游戏段落中的过渡索引列表
total_transitions (-) – 收集的转换次数
- 返回:
value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment
- 返回类型:
reward_value_context (
list
)
- _preprocess_to_play_and_action_mask(game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, unroll_steps=None)
- 概述:
- 为
value_obs_list
中的目标观察准备 to_play 和 action_mask to_play: {list: game_segment_batch_size * (num_unroll_steps+1)}
action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)}
- 为
- _push_game_segment(data: Any, meta: dict | None = None) None
- 概述:
将数据及其元信息推入缓冲区。保存游戏片段。
- 参数:
data (-) – 将要推入缓冲区的数据(一个游戏片段)。
meta (-) – Meta information, e.g. priority, count, staleness. - done (
bool
): whether the game is finished. - unroll_plus_td_steps (int
): if the game is not finished, we only save the transitions that can be computed - priorities (list
): the priorities corresponding to the transitions in the game history
- 返回:
推送的数据。
- 返回类型:
buffered_data (
BufferedData
)
- _remove(excess_game_segment_index: List[int]) None
- 概述:
删除索引中的游戏片段 [0: excess_game_segment_index]
- 参数:
excess_game_segment_index (-) – 数据索引
- _sample_orig_data(batch_size: int) Tuple
- 概述:
- 包含的原始数据示例:
game_segment_list: 游戏段列表 pos_in_game_segment_list: 游戏中的过渡索引(相对索引) batch_index_list: 重放缓冲区中采样的小批量过渡的起始索引 weights_list: 与优先级相关的权重 make_time: 批次生成的时间(用于在删除数据时正确更新重放缓冲区)
- 参数:
batch_size (-) – 批量大小
beta (-) – 在PER中浮动参数以计算优先级
- _sample_orig_data_episode(batch_size: int) Tuple
- 概述:
- 训练批次样本的原始数据,包括:
game_segment_list: 游戏段落列表。
pos_in_game_segment_list: 游戏片段内的过渡索引。
batch_index_list: 在回放缓冲区中采样的小批量数据的起始转换的索引。
weights_list: 每个采样转换的权重,用于优先级排序。
make_time: 时间戳,指示批次的创建时间(对于管理重放缓冲区更新很有用)。
- 参数:
batch_size (-) – 用于批次抽取的样本数量。
beta (-) – 用于优先经验回放 (PER) 的参数,调整样本的重要性。
- _sample_orig_reanalyze_batch(batch_size: int) Tuple
- 概述:
此函数从重放缓冲区中采样一批游戏片段以进行重新分析。它基于每个游戏片段的 reanalyze_time 进行优先级采样,重新分析频率较高的片段优先级较低。
该函数返回一个包含采样游戏段信息元组,包括它们在每个段中的位置以及批处理创建的时间。
- 参数:
batch_size (-) – 在此批次中抽取的样本数量。
- 返回:
一个包含以下元素的元组: - game_segment_list:采样的游戏片段列表。 - pos_in_game_segment_list:表示每个转换在其对应游戏片段中位置的索引列表。 - batch_index_list:重放缓冲区中采样游戏片段的索引。 - make_time:指示批次创建时间的(在此实现中设置为 0 的时间戳)列表。
- 返回类型:
元组
- 关键细节:
优先采样:游戏片段根据使用每个片段的 reanalyze_time 计算的概率分布进行采样。被重新分析次数较多的片段被选中的可能性较小。
片段切片:每个选定的游戏片段根据 num_unroll_steps 参数以固定间隔进行采样。每个选定的片段最多采样 samples_per_segment 个转换。
处理额外样本:如果 batch_size 不能被每个片段的样本数整除,则会额外采样片段以弥补差额。
重新分析时间更新:每个采样的游戏段落的 reanalyze_time 属性被增加,以反映它已被再次选中进行重新分析。
- 抛出:
- ValueError – 如果 game_segment_length 太小,无法容纳 num_unroll_steps。
- _sample_orig_reanalyze_data(batch_size: int) Tuple
- 概述:
- 包含的原始数据示例:
game_segment_list: 游戏段列表 pos_in_game_segment_list: 游戏中的过渡索引(相对索引) batch_index_list: 重放缓冲区中采样的小批量过渡的起始索引 weights_list: 与优先级相关的权重 make_time: 批次生成的时间(用于在删除数据时正确更新重放缓冲区)
- 参数:
batch_size (-) – 批量大小
beta (-) – 在PER中浮动参数以计算优先级
- config = {'mini_infer_size': 10240, 'reanalyze_outdated': True, 'reanalyze_ratio': 0, 'replay_buffer_size': 1000000, 'sample_type': 'transition', 'use_root_value': False}
- classmethod default_config() EasyDict
- get_num_of_episodes() int
- get_num_of_game_segments() int
- get_num_of_transitions() int
- push_game_segments(data_and_meta: Any) None
- 概述:
将 game_segments 数据及其元信息推入缓冲区。保存一个游戏段
- 参数:
data_and_meta (-) –
data (
Any
): The data (game segments) which will be pushed into buffer.meta (
dict
): Meta information, e.g. priority, count, staleness.
- reanalyze_buffer(batch_size: int, policy: MuZeroPolicy | EfficientZeroPolicy | SampledEfficientZeroPolicy) List[Any] [源代码]
- 概述:
从
GameBuffer
中获取样本数据,并为训练准备当前和目标批次。
- 参数:
batch_size (-) – 批量大小。
policy (-) – 政策。
- 返回:
列车数据列表,包括 current_batch 和 target_batch。
- 返回类型:
train_data (
List
)
- remove_oldest_data_to_fit() None
- 概述:
如果回放缓冲区已满,则删除一些最旧的数据。
- sample(batch_size: int, policy: MuZeroPolicy | EfficientZeroPolicy | SampledEfficientZeroPolicy) List[Any] [源代码]
- 概述:
从
GameBuffer
中获取样本数据,并为训练准备当前和目标批次。
- 参数:
batch_size (-) – 批量大小。
policy (-) – 政策。
- 返回:
列车数据列表,包括 current_batch 和 target_batch。
- 返回类型:
train_data (
List
)
- update_priority(train_data: List[ndarray], batch_priorities: Any) None [源代码]
- 概述:
更新训练数据的优先级。
- 参数:
train_data (-) – 训练数据更新优先级。
batch_priorities (-) – 更新优先级。
备注
train_data = [current_batch, target_batch] current_batch = [obs_list, action_list, improved_policy_list(仅在Gumbel MuZero中), mask_list, batch_index_list, weights, make_time_list]