GameBuffer

class lzero.mcts.buffer.game_buffer.GameBuffer(cfg: dict)[源代码]

基类:ABC, object

概述:

MuZeroPolicy、EfficientZeroPolicy、SampledEfficientZeroPolicy、GumbelMuZeroPolicy 的基础游戏缓冲类。

__init__(cfg: dict)[源代码]
_abc_impl = <_abc._abc_data object>
abstract _compute_target_policy_non_reanalyzed(policy_non_re_context: List[Any], policy_shape: int | None) ndarray[源代码]
概述:

从政策的非重新分析上下文中准备政策目标

参数:

policy_non_re_context (-) – 包含以下内容的列表:- pos_in_game_segment_list - child_visits - game_segment_lens - action_mask_segment - to_play_segment

返回:

  • batch_target_policies_non_re

abstract _compute_target_policy_reanalyzed(policy_re_context: List[Any], model: Any) ndarray[源代码]
概述:

从重新分析的政策背景中准备政策目标

参数:

policy_re_context (-) – 待重新分析的政策背景列表

返回:

  • batch_target_policies_re

abstract _compute_target_reward_value(reward_value_context: List[Any], model: Any) List[ndarray][源代码]
概述:

从奖励和价值的环境中准备奖励和价值目标。

参数:
  • ( (- model) – obj:’list’): 奖励值上下文

  • ( – obj:’torch.tensor’):目标模型的模型

返回:

obj:’np.ndarray): batch of value prefix - batch_target_values (:obj:’np.ndarray): batch of value estimation

返回类型:

  • batch_value_prefixs (

abstract _make_batch(orig_data: Any, reanalyze_ratio: float) Tuple[Any][源代码]
概述:

prepare the context of a batch reward_value_context: the context of reanalyzed value targets policy_re_context: the context of reanalyzed policy targets policy_non_re_context: the context of non-reanalyzed policy targets current_batch: the inputs of batch

参数:
  • orig_data – 来自重放缓冲区的任何批次上下文

  • reanalyze_ratio – 重新分析策略的浮动比率(值为100%重新分析)

返回:

reward_value_context, policy_re_context, policy_non_re_context, current_batch

返回类型:

  • context (Tuple)

abstract _prepare_policy_non_reanalyzed_context(batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int]) List[Any][源代码]
概述:

准备非重新分析部分中计算策略目标的策略上下文,只需返回自我对弈中的策略。

参数:
  • batch_index_list (-) – 重放缓冲区中采样小批次的起始过渡索引

  • game_segment_list (-) – 游戏段落列表

  • pos_in_game_segment_list (-) – 游戏中的列表过渡索引

返回:

state_index_lst, child_visits, game_segment_lens, action_mask_segment, to_play_segment

返回类型:

  • policy_non_re_context (list)

abstract _prepare_policy_reanalyzed_context(batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str]) List[Any][源代码]
概述:

准备重新分析部分中计算政策目标的政策背景。

参数:
  • ( (- pos_in_game_segment_list) – obj:’list’): 在重放缓冲区中开始过渡索引

  • ( – obj:’list’): 游戏段落列表

  • ( – obj:’list’): 在一个游戏历史中过渡索引的位置

返回:

policy_obs_lst, policy_mask, state_index_lst, indices, child_visits, game_segment_lens, action_mask_segment, to_play_segment

返回类型:

  • policy_re_context (list)

abstract _prepare_reward_value_context(batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any], total_transitions: int) List[Any][源代码]
概述:

为重新分析部分的TD值目标计算准备奖励和价值的环境。

参数:
  • batch_index_list (-) – 重放缓冲区中采样小批次的起始过渡索引

  • game_segment_list (-) – 游戏段落列表

  • pos_in_game_segment_list (-) – 游戏段落中的过渡索引列表

  • total_transitions (-) – 收集的过渡次数

返回:

value_obs_lst, value_mask, state_index_lst, rewards_lst, game_segment_lens, td_steps_lst, action_mask_segment, to_play_segment

返回类型:

  • reward_value_context (list)

_preprocess_to_play_and_action_mask(game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list, unroll_steps=None)[源代码]
概述:
value_obs_list 中的目标观察值准备 to_play 和 action_mask。
  • to_play: {list: game_segment_batch_size * (num_unroll_steps+1)}

  • action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)}

_push_game_segment(data: Any, meta: dict | None = None) None[源代码]
概述:

将数据及其元信息推入缓冲区。保存游戏段。

参数:
  • data (-) – 将要推入缓冲区的数据(一个游戏片段)。

  • meta (-) – Meta information, e.g. priority, count, staleness. - done (bool): whether the game is finished. - unroll_plus_td_steps (int): if the game is not finished, we only save the transitions that can be computed - priorities (list): the priorities corresponding to the transitions in the game history

返回:

推送的数据。

返回类型:

  • buffered_data (BufferedData)

_remove(excess_game_segment_index: List[int]) None[源代码]
概述:

删除索引中的游戏片段 [0: excess_game_segment_index]

参数:

excess_game_segment_index (-) – 数据索引

_sample_orig_data(batch_size: int) Tuple[源代码]
概述:
包含的原始数据示例:

game_segment_list: 游戏片段列表 pos_in_game_segment_list: 游戏中的过渡索引(相对索引) batch_index_list: 重放缓冲区中采样的小批量过渡的起始索引 weights_list: 与优先级相关的权重 make_time: 批次生成的时间(用于在删除数据时正确更新重放缓冲区)

参数:
  • batch_size (-) – 批量大小

  • beta (-) – 在PER中浮动参数以计算优先级

_sample_orig_data_episode(batch_size: int) Tuple[源代码]
概述:
训练批次样本的原始数据,包括:
  • game_segment_list: 游戏段落列表。

  • pos_in_game_segment_list: 游戏段落内的过渡索引。

  • batch_index_list: 在重放缓冲区中采样的小批量开始转换的索引。

  • weights_list: 每个采样转换的权重,用于优先级排序。

  • make_time: 时间戳,指示批处理创建的时间(用于管理重放缓冲区更新很有用)。

参数:
  • batch_size (-) – 为批次抽取的样本数量。

  • beta (-) – 用于优先经验回放 (PER) 的参数,调整样本的重要性。

_sample_orig_reanalyze_batch(batch_size: int) Tuple[源代码]
概述:

此函数从重放缓冲区中采样一批游戏段以进行重新分析。它基于每个游戏段的 reanalyze_time 进行优先级采样,重新分析次数较多的段将获得较低的优先级。

该函数返回一个包含采样游戏段信息元组,包括它们在每个段中的位置以及批处理创建的时间。

参数:

batch_size (-) – 本次批次中要抽取的样本数量。

返回:

包含以下元素的元组:- game_segment_list:采样的游戏片段列表。- pos_in_game_segment_list:表示每个转换在其对应游戏片段中位置的索引列表。- batch_index_list:重放缓冲区中采样游戏片段的索引。- make_time:时间戳列表(在此实现中设置为 0),指示批次创建的时间。

返回类型:

  • 元组

关键细节:
  1. 优先采样:游戏片段根据使用每个片段的 reanalyze_time 计算的概率分布进行采样。被重新分析更频繁的片段被选中的可能性较小。

  2. 片段切片:每个选定的游戏片段根据 num_unroll_steps 参数确定的固定间隔进行采样。每个选定的片段最多采样 samples_per_segment 个转换。

  3. 处理额外样本:如果 batch_size 不能被每个段的样本数整除,则会额外采样一些段来弥补差额。

  4. 重新分析时间更新:每个采样游戏段的 reanalyze_time 属性被增加,以反映它已被再次选中进行重新分析。

抛出:

- ValueError – 如果 game_segment_length 太小,无法容纳 num_unroll_steps

_sample_orig_reanalyze_data(batch_size: int) Tuple[源代码]
概述:
包含的原始数据示例:

game_segment_list: 游戏片段列表 pos_in_game_segment_list: 游戏中的过渡索引(相对索引) batch_index_list: 重放缓冲区中采样的小批量过渡的起始索引 weights_list: 与优先级相关的权重 make_time: 批次生成的时间(用于在删除数据时正确更新重放缓冲区)

参数:
  • batch_size (-) – 批量大小

  • beta (-) – 在PER中浮动参数以计算优先级

config = {'mini_infer_size': 10240, 'reanalyze_outdated': True, 'reanalyze_ratio': 0, 'replay_buffer_size': 1000000, 'sample_type': 'transition', 'use_root_value': False}
classmethod default_config() EasyDict[源代码]
get_num_of_episodes() int[源代码]
get_num_of_game_segments() int[源代码]
get_num_of_transitions() int[源代码]
push_game_segments(data_and_meta: Any) None[源代码]
概述:

将 game_segments 数据及其元信息推入缓冲区。保存一个游戏段

参数:

data_and_meta (-) –

  • data (Any): The data (game segments) which will be pushed into buffer.

  • meta (dict): Meta information, e.g. priority, count, staleness.

remove_oldest_data_to_fit() None[源代码]
概述:

如果回放缓冲区已满,则删除一些最旧的数据。

abstract sample(batch_size: int, policy: MuZeroPolicy | EfficientZeroPolicy | SampledEfficientZeroPolicy | GumbelMuZeroPolicy) List[Any][源代码]
概述:

GameBuffer 中获取样本数据,并为训练准备当前和目标批次。

参数:
  • batch_size (-) – 批量大小。

  • policy (-) – 政策。

返回:

列车数据的列表,包括 current_batch 和 target_batch。

返回类型:

  • train_data (List)

abstract update_priority(train_data: List[ndarray | None] | None, batch_priorities: Any | None) None[源代码]
概述:

更新训练数据的优先级。

参数:
  • train_data (-) – 训练数据更新优先级。

  • batch_priorities (-) – 更新优先级。