Shortcuts

GTrXL

概述

Gated Transformer-XL,或称为GTrXL,首次在稳定强化学习的Transformer中提出,是一种从Transformer-XL架构改编而来的新颖强化学习框架。它主要引入了两种架构修改,以提高Transformer的稳定性和学习速度,包括:仅在子模块的输入流上放置层归一化,以及用门控层替换残差连接。所提出的架构在具有挑战性的记忆环境中超越了LSTM,并在多个记忆基准测试中取得了最先进的结果,超过了外部记忆架构的性能。

快速事实

  1. GTrXL 可以作为许多强化学习算法的骨干

  2. GTrXL 仅支持顺序观察。

  3. GTrXL 基于 Transformer-XL 并带有 门控连接

  4. GTrXL 的 DI-engine 实现基于 R2D2 算法。在原始论文中,它基于算法 V-MPO

关键方程或关键图表

Transformer-XL: 为了解决上下文碎片化问题,Transformer-XL 引入了循环的概念到深度自注意力网络中。 Transformer-XL 不是为每个新片段从头计算隐藏状态,而是重用先前片段中获得的隐藏状态。 重用的隐藏状态作为当前片段的记忆,从而在片段之间建立了循环连接。 因此,由于信息可以通过循环连接传播,建模非常长期的依赖关系成为可能。 为了在不引起时间混淆的情况下启用状态重用,Transformer-XL 提出了一种新的相对位置编码公式,该公式可以推广到比训练期间观察到的更长的注意力长度。

../_images/transformerXL_train_eval.png

身份映射重排序:将层归一化移动到子模块的输入流中。 这种重排序的一个关键好处是,它现在使得从第一层的变压器输入到最后一层后的变压器输出之间存在一个身份映射。 这与规范变压器形成对比,在规范变压器中,有一系列层归一化操作非线性地变换状态编码。 关于为什么身份映射重排序能改善结果的一个假设如下:假设初始化时子模块产生的值在期望上接近零,状态编码未经变换地传递给策略和价值头,使得代理能够在训练开始时学习马尔可夫策略 (即,网络被初始化使得 \(\pi(·|st,...,s1) ≈ \pi(·|st)\)\(V^\pi(s_t|s_{t-1},...,s_1) ≈ V^\pi(s_t|s_{t-1})\)), 从而忽略了来自注意力-XL记忆的过去观察的贡献。 在许多环境中,反应性行为需要在基于记忆的行为能够有效利用之前被学习。 例如,代理需要先学会如何行走,然后才能学会记住它走过的路径。 通过身份映射重排序,模型的前向传递可以计算为:

../_images/identity_map_reordering.png ../_images/gtrxl.png

门控层:用门控层替换残差连接。在研究的几种门控函数中,门控循环单元(GRU)表现最佳。 其强大的门控机制可以表示为:

\[\begin{split}\begin{aligned} r &= \sigma(W_r^{(l)} y + U_r^{(l)} x) \\ z &= \sigma(W_z^{(l)} y + U_z^{(l)} x - b_g^{(l)}) \\ \hat{h} &= \tanh(W_g^{(l)} y + U_g^{(l)} (r \odot x)) \\ g^{(l)}(x, y) &= (1-z)\odot x + z\odot \hat{h} \end{aligned}\end{split}\]

门控身份初始化:作者声称身份映射重排序有助于策略优化,因为它将代理初始化为接近马尔可夫策略或价值函数。 如果这确实是提高稳定性的原因,我们可以显式初始化各种门控机制,使其接近身份映射。 这就是适用门控层中偏置\(b_g^{(l)}\)的目的。作者在消融实验中证明,初始设置\(b_g^{(l)}>0\)会产生最佳结果。

扩展

GTrXL 可以与以下内容结合使用:

  • CoBERL (CoBERL: 对比BERT用于强化学习):

    对比BERT(CoBERL)是一种强化学习代理,它结合了新的对比损失和混合LSTM-Transformer架构,以应对提高RL数据效率的挑战。它使用双向掩码预测结合最近对比方法的泛化,来学习RL中Transformer的更好表示,而无需手工设计的数据增强。

  • R2D2 (分布式强化学习中的循环经验回放):

    循环重放分布式DQN(R2D2)展示了如何调整重放和强化学习目标,使其适用于具有循环架构的智能体。LSTM可以被替换或与门控变压器结合,以便我们能够利用分布式经验收集的优势,将循环智能体状态存储在重放缓冲区中,并在训练期间通过重放序列“预热”部分展开的网络。

实现

使用的网络接口 GTrXL 定义如下:

class ding.torch_utils.network.gtrxl.GTrXL(input_dim: int, head_dim: int = 128, embedding_dim: int = 256, head_num: int = 2, mlp_num: int = 2, layer_num: int = 3, memory_len: int = 64, dropout_ratio: float = 0.0, activation: Module = ReLU(), gru_gating: bool = True, gru_bias: float = 2.0, use_embedding_layer: bool = True)[源代码]
Overview:

GTrXL Transformer 实现,如“Stabilizing Transformer for Reinforcement Learning”中所述 (https://arxiv.org/abs/1910.06764)。

Interfaces:

__init__, forward, reset_memory, get_memory

forward(x: Tensor, batch_first: bool = False, return_mem: bool = True) Dict[str, Tensor][source]
Overview:

在GTrXL上执行前向传递。

Arguments:
  • x (torch.Tensor): 输入张量,形状为 (seq_len, bs, input_size)。

  • batch_first (bool, 可选): 如果输入数据的形状为 (bs, seq_len, input_size),将此参数设置为 True 以沿第一和第二维度进行转置,并获得形状 (seq_len, bs, input_size)。这不会影响输出内存。默认值为 False。 - return_mem (bool, 可选): 如果为 False,则仅返回输出张量而不返回字典。默认值为 True。

Returns:
  • x (Dict[str, torch.Tensor]): 一个包含形状为 (seq_len, bs, embedding_size) 的transformer输出和形状为 (layer_num, seq_len, bs, embedding_size) 的memory的字典。

get_memory()[source]
Overview:

返回GTrXL的内存。

Returns:
  • memory (Optional[torch.Tensor]): 输出的memory,如果memory尚未初始化则为None。形状为(layer_num, memory_len, bs, embedding_dim)。

reset_memory(batch_size: int | None = None, state: Tensor | None = None)[source]
Overview:

清除或设置GTrXL的内存。

Arguments:
  • batch_size (Optional[int]): 批量大小。默认值为 None。

  • 状态 (Optional[torch.Tensor]): 输入的记忆,形状为 (层数, 记忆长度, 批次大小, 嵌入维度)。默认值为 None。

我们基于R2D2的GTrXL的默认实现定义如下:

class ding.policy.r2d2_gtrxl.R2D2GTrXLPolicy(cfg: EasyDict, model: Module | None = None, enable_field: List[str] | None = None)[source]
Overview:

R2D2的策略类采用Transformer架构GTrXL作为骨干。

Config:

ID

符号

类型

默认值

描述

其他(形状)

1

type

字符串

r2d2_gtrxl

RL policy register name, refer to
registry POLICY_REGISTRY
This arg is optional,
a placeholder

2

cuda

布尔

Whether to use cuda for network
This arg can be diff-
erent from modes

3

on_policy

布尔

Whether the RL algorithm is on-policy
or off-policy

4

priority

布尔

Whether use priority(PER)
Priority sample,
update priority

5

priority_IS
_weight

布尔

Whether use Importance Sampling Weight
to correct biased update. If True,
priority must be True.

6

discount_
factor

浮点数

0.99, [0.95, 0.999]

Reward’s future discount factor, aka.
gamma
May be 1 when sparse
reward env

7

nstep

整数

5, [3, 5]

N-step reward discount sum for target
q_value estimation

8

burnin_step

整数

1

The timestep of burnin operation,
which is designed to warm-up GTrXL
memory difference caused by off-policy

9

learn.update
per_collect

整数

1

How many updates(iterations) to train
after collector’s one collection. Only
valid in serial training
This args can be vary
from envs. Bigger val
means more off-policy

10

learn.batch_
size

整数

64

The number of samples of an iteration

11

learn.learning
_rate

浮点数

0.001

Gradient step length of an iteration.

12

learn.value_
rescale

布尔

Whether use value_rescale function for
predicted value

13

learn.target_
update_freq

整数

100

Frequence of target network update.
Hard(assign) update

14

learn.ignore_
done

布尔

Whether ignore done for target value
calculation.
Enable it for some
fake termination env

15

collect.n_sample

整数

[8, 128]

The number of training samples of a
call of collector.
It varies from
different envs

16

collect.seq
_len

整数

20

Training sequence length
unroll_len>=seq_len>1

17

learn.init_
memory

字符串

‘zero’ or ‘old’, how to initialize the
memory before each training iteration.


_data_preprocess_learn(data: List[Dict[str, Any]]) dict[source]
Overview:

预处理数据以适应学习所需的格式

Arguments:
  • 数据 (List[Dict[str, Any]]): 从收集函数中收集的数据

Returns:
  • data (Dict[str, Any]): the processed data, including at least

    [‘main_obs’, ‘target_obs’, ‘burnin_obs’, ‘action’, ‘reward’, ‘done’, ‘weight’]

  • data_info (dict): 数据信息,例如 replay_buffer_idx, replay_unique_id

_forward_learn(data: dict) Dict[str, Any][来源]
Overview:

学习模式的前向和后向函数。 获取数据,计算损失并优化学习器模型。

Arguments:
  • data (dict): Dict type data, including at least

    [‘main_obs’, ‘target_obs’, ‘burnin_obs’, ‘action’, ‘reward’, ‘done’, ‘weight’]

Returns:
  • info_dict (Dict[str, Any]): Including cur_lr and total_loss
    • cur_lr (float): 当前学习率

    • total_loss (float): 计算出的损失

_init_learn() None[source]
Overview:

初始化GTrXLR2D2Policy的学习者模型。目标模型有2个包装器:'target'用于权重更新,'transformer_segment'用于将轨迹分割成段。学习模型有2个包装器:'argmax'用于选择最佳动作,'transformer_segment'。

Arguments:
  • learning_rate (float): 优化器的学习率

  • gamma (float): 折扣因子

  • nstep (int): n步返回的数量

  • value_rescale (bool): 是否在算法中使用值重新缩放的损失

  • burnin_step (int): 预热步骤的数量

  • seq_len (int): 训练序列长度

  • init_memory (str): ‘zero’ 或 ‘old’,每次训练迭代前如何初始化内存。

注意

_init_learn 方法从配置文件中的 self._cfg.learn 获取参数

基准测试

环境

最佳平均奖励

评估结果

配置链接

比较

乒乓球

(PongNoFrameskip-v4)

20

../_images/pong_gtrxl_r2d2.png

config_link_p

附注:

  1. 上述结果是通过在五个不同的随机种子(0, 1, 2, 3, 4)上运行相同的配置获得的

  2. 对于像DQN这样的离散动作空间算法,通常使用Atari环境集进行测试(包括子环境Pong),并且Atari环境通常通过最高平均奖励训练10M env_step来评估。有关Atari的更多详细信息,请参阅Atari环境教程

参考