
Reinforcement Learning (PPO) with TorchRL Tutorial

Created On: Mar 15, 2023 | Last Updated: May 16, 2024 | Last Verified: Nov 05, 2024

Author: Vincent Moens

This tutorial demonstrates how to use PyTorch and torchrl to train a parametric policy network to solve the Inverted Pendulum task from the OpenAI-Gym/Farama-Gymnasium control library.

Inverted pendulum

Key learnings:

  • How to create an environment in TorchRL, transform its outputs, and collect data from this environment;

  • How to make your classes talk to each other using TensorDict;

  • The basics of building your training loop with TorchRL:

    • How to compute the advantage signal for policy gradient methods;

    • How to create a stochastic policy using a probabilistic neural network;

    • How to create a dynamic replay buffer and sample from it without repetition.

We will cover six crucial components of TorchRL: environments, transforms, models (policy and value function), loss modules, data collectors and replay buffers.

If you are running this in Google Colab, make sure you install the following dependencies:

!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm

Proximal Policy Optimization (PPO) is a policy-gradient algorithm where a batch of data is collected and directly consumed to train the policy to maximise the expected return given some proximality constraints. You can think of it as a sophisticated version of REINFORCE, the foundational policy-optimization algorithm. For more information, see the Proximal Policy Optimization Algorithms paper.

PPO is usually regarded as a fast and efficient method for online, on-policy reinforcement learning. TorchRL provides a loss module that does all the work for you, so that you can rely on this implementation and focus on solving your problem rather than re-inventing the wheel every time you want to train a policy.

For completeness, here is a brief overview of what the loss computes, even though this is taken care of by our ClipPPOLoss module. The algorithm works as follows:

  1. we will sample a batch of data by playing the policy in the environment for a given number of steps;

  2. then, we will perform a given number of optimization steps with random sub-samples of this batch using a clipped version of the REINFORCE loss;

  3. the clipping will put a pessimistic bound on our loss: lower return estimates will be favored compared to higher ones.

The precise formula of the loss is:

\[L(s,a,\theta_k,\theta) = \min\left( \frac{\pi_{\theta}(a|s)}{\pi_{\theta_k}(a|s)} A^{\pi_{\theta_k}}(s,a), \;\; g(\epsilon, A^{\pi_{\theta_k}}(s,a)) \right),\]

There are two components in that loss: in the first part of the minimum operator, we simply compute an importance-weighted version of the REINFORCE loss (that is, a REINFORCE loss that we have corrected for the fact that the current policy configuration lags behind the one that was used for the data collection). The second part of that minimum operator is a similar loss where we have clipped the ratios when they exceed or fall below a given pair of thresholds.

This loss ensures that whether the advantage is positive or negative, policy updates that would produce significant shifts from the previous configuration are discouraged.
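
For reference, the clipping function g above is the one from the PPO paper; this expansion is not spelled out in the surrounding text but follows directly from the paper:

\[g(\epsilon, A) = \begin{cases} (1+\epsilon)A & \text{if } A \geq 0, \\ (1-\epsilon)A & \text{otherwise.} \end{cases}\]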

This tutorial is structured as follows:

  1. First, we will define a set of hyperparameters we will be using for training.

  2. Next, we will focus on creating our environment, or simulator, using TorchRL's wrappers and transforms.

  3. Next, we will design the policy network and the value model, which is indispensable to the loss function. These modules will be used to configure our loss module.

  4. Next, we will create the replay buffer and data loader.

  5. Finally, we will run our training loop and analyze the results.

Throughout this tutorial, we will be using the tensordict library. TensorDict is the lingua franca of TorchRL: it helps us abstract what a module reads and writes, and care less about the specific data description and more about the algorithm itself.
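
As a tiny, hedged illustration of the idea (not part of the original script; the key names below are arbitrary), a TensorDict simply carries named tensors that share a batch size:

from tensordict import TensorDict
import torch

td = TensorDict({"observation": torch.randn(3, 11)}, batch_size=[3])
td["action"] = torch.zeros(3, 1)  # new entries must share the leading batch dimension
print(td)  # prints the keys, shapes and device at a glance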

import warnings
warnings.filterwarnings("ignore")
from torch import multiprocessing


from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (Compose, DoubleToFloat, ObservationNorm, StepCounter,
                          TransformedEnv)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm

Define Hyperparameters

We set the hyperparameters for our algorithm. Depending on the resources available, one may choose to execute the policy on GPU or on another device. The frame_skip will control how many frames a single action is being executed for. The rest of the arguments that count frames must be corrected for this value (since one environment step will actually return frame_skip frames).

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0

Data collection parameters

When collecting data, we will be able to choose how big each batch will be by defining a frames_per_batch parameter. We will also define how many frames (that is, the number of interactions with the simulator) we will allow ourselves to use. In general, the goal of an RL algorithm is to learn to solve the task as fast as it can in terms of environment interactions: the lower the total_frames the better.

frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000

PPO parameters

At each data collection (or batch collection) we will run the optimization over a certain number of epochs, each time consuming the entire data we just acquired in a nested training loop. Here, sub_batch_size is different from the frames_per_batch above: recall that we are working with a "batch of data" coming from our collector, whose size is defined by frames_per_batch, and that we will further split into smaller sub-batches during the inner training loop. The size of these sub-batches is controlled by sub_batch_size.

sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

Define an environment

In RL, an environment usually refers to a simulator or a control system. Various libraries provide simulation environments for reinforcement learning, including Gymnasium (previously OpenAI Gym), the DeepMind Control Suite, and many others. As a general-purpose library, TorchRL's goal is to provide an interchangeable interface to a large panel of RL simulators, allowing you to easily swap one environment for another. For example, creating a wrapped gym environment can be achieved with a few lines of code:

base_env = GymEnv("InvertedDoublePendulum-v4", device=device)

There are a few things to notice in this code: first, we created the environment by calling the GymEnv wrapper. If extra keyword arguments are passed, they will be transmitted to the gym.make method, hence covering the most common environment construction commands. Alternatively, one could also directly create a gym environment using gym.make(env_name, **kwargs) and wrap it in a GymWrapper class, as sketched below.
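
A minimal sketch of that alternative, assuming the gymnasium package is installed (the variable name base_env_alt is ours, not part of the tutorial):

import gymnasium as gym
from torchrl.envs.libs.gym import GymWrapper

# build the simulator with gym.make and hand it to TorchRL's wrapper
base_env_alt = GymWrapper(gym.make("InvertedDoublePendulum-v4"), device=device)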

Also note the device argument: for gym, this only controls the device where input actions and observed states will be stored, but the execution will always be done on CPU. The reason for this is simply that gym does not support on-device execution, unless specified otherwise. For other libraries, we have control over the execution device and, as much as we can, we try to stay consistent in terms of storing and execution backends.

Transforms

We will append some transforms to our environment to prepare the data for the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different approach, more similar to other pytorch domain libraries, through the use of transforms. To add transforms to an environment, one should simply wrap it in a TransformedEnv instance and append the sequence of transforms to it. The transformed environment will inherit the device and meta-data of the wrapped environment, and transform these depending on the sequence of transforms it contains.

Normalization

The first to encode is a normalization transform. As a rule of thumb, it is preferable to have data that loosely matches a unit Gaussian distribution: to obtain this, we will run a certain number of random steps in the environment and compute the summary statistics of these observations.

We will append two other transforms: the DoubleToFloat transform will convert double-precision entries to single-precision numbers, ready to be read by the policy. The StepCounter transform will be used to count the steps before the environment is terminated. We will use this measure as a supplementary measure of performance.

As we will see later, many of TorchRL's classes rely on TensorDict to communicate. You could think of it as a python dictionary with some extra tensor features. In practice, this means that many modules we will be working with need to be told what key to read (in_keys) and what key to write (out_keys) in the tensordict they will receive. Usually, if out_keys is omitted, it is assumed that the in_keys entries will be updated in-place. For our transforms, the only entry we are interested in is referred to as "observation" and our transform layers will be told to modify this entry and this entry only:

env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)

As you may have noticed, we have created a normalization layer but we did not set its normalization parameters. To do this, ObservationNorm can automatically gather the summary statistics of our environment:

env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

The ObservationNorm transform has now been populated with a location and a scale that will be used to normalize the data.

Let us do a little sanity check for the shape of our summary stats:

print("normalization constant shape:", env.transform[0].loc.shape)
normalization constant shape: torch.Size([11])

An environment is not only defined by its simulator and transforms, but also by a series of metadata that describe what can be expected during its execution. For efficiency purposes, TorchRL is quite stringent when it comes to environment specs, but you can easily check that your environment specs are adequate. In our example, the GymWrapper and the GymEnv class that inherits from it already take care of setting the proper specs for your environment, so you should not have to worry about this.

Nevertheless, let's see a concrete example using our transformed environment by looking at its specs. There are three specs to look at: observation_spec, which defines what is to be expected when executing a step in the environment; reward_spec, which indicates the reward domain; and finally the input_spec (which contains the action_spec), which represents everything an environment requires to execute a single step.

print("observation_spec:", env.observation_spec)
print("reward_spec:", env.reward_spec)
print("input_spec:", env.input_spec)
print("action_spec (as defined by input_spec):", env.action_spec)
observation_spec: CompositeSpec(
    observation: UnboundedContinuousTensorSpec(
        shape=torch.Size([11]),
        space=None,
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    step_count: BoundedTensorSpec(
        shape=torch.Size([1]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
        device=cpu,
        dtype=torch.int64,
        domain=continuous),
    device=cpu,
    shape=torch.Size([]))
reward_spec: UnboundedContinuousTensorSpec(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)
input_spec: CompositeSpec(
    full_state_spec: CompositeSpec(
        step_count: BoundedTensorSpec(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),
            device=cpu,
            dtype=torch.int64,
            domain=continuous),
        device=cpu,
        shape=torch.Size([])),
    full_action_spec: CompositeSpec(
        action: BoundedTensorSpec(
            shape=torch.Size([1]),
            space=ContinuousBox(
                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
            device=cpu,
            dtype=torch.float32,
            domain=continuous),
        device=cpu,
        shape=torch.Size([])),
    device=cpu,
    shape=torch.Size([]))
action_spec (as defined by input_spec): BoundedTensorSpec(
    shape=torch.Size([1]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)

The check_env_specs() function runs a small rollout and compares its output against the environment specs. If no error is raised, we can be confident that the specs are properly defined:

check_env_specs(env)

For fun, let's see what a simple random rollout looks like. You can call env.rollout(n_steps) and get an overview of what the environment inputs and outputs look like. Actions will automatically be drawn from the action spec domain, so you don't need to care about designing a random sampler.

Typically, at each step, an RL environment receives an action as input and outputs an observation, a reward and a done state. The observation may be composite, meaning that it could be composed of more than one tensor. This is not a problem for TorchRL, since the whole set of observations is automatically packed in the output TensorDict. After executing a rollout (for example, a sequence of environment steps and random action generations) over a given number of steps, we will retrieve a TensorDict instance with a shape that matches this trajectory length:

rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)
rollout of three steps: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([3]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([3]),
    device=cpu,
    is_shared=False)
Shape of the rollout TensorDict: torch.Size([3])

Our rollout data has a shape of torch.Size([3]), which matches the number of steps we ran it for. The "next" entry points to the data coming after the current step. In most cases, the "next" data at time t matches the data at t+1, but this may not be the case if we are using some specific transformations (for example, multi-step).
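
As a quick, hedged sanity check (not part of the original script, and valid only when no reset happens inside the rollout), one can verify this alignment directly:

check_rollout = env.rollout(3)
# in a plain rollout, the "next" observation at step t equals the root observation at step t+1
assert (
    check_rollout["next", "observation"][:-1] == check_rollout["observation"][1:]
).all()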

Policy

PPO utilizes a stochastic policy to handle exploration. This means that our neural network will have to output the parameters of a distribution, rather than a single value corresponding to the action taken.

As the data is continuous, we use a Tanh-Normal distribution to respect the action space boundaries. TorchRL provides such a distribution, and the only thing we need to care about is to build a neural network that outputs the right number of parameters for the policy to work with (a location, or mean, and a scale):

\[f_{\theta}(\text{observation}) = \mu_{\theta}(\text{observation}), \sigma^{+}_{\theta}(\text{observation})\]

The only extra difficulty that is brought up here is to split our output in two equal parts and map the second to a strictly positive space.

We design the policy in three steps:

  1. Define a neural network D_obs -> 2 * D_action. Indeed, our loc (mu) and scale (sigma) both have dimension D_action.

  2. Append a NormalParamExtractor to extract a location and a scale (for example, split the input in two equal parts and apply a positive transformation to the scale parameter).

  3. Create a probabilistic TensorDictModule that can generate this distribution and sample from it.
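
The actor_net used below is not defined anywhere in this text; the following is a sketch reconstructed from the three steps above (the layer sizes mirror the value network defined later and the output size is read from env.action_spec):

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    # 2 * D_action outputs: one loc and one scale per action dimension
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    # splits the output in two halves and maps the scale half to a positive value
    NormalParamExtractor(),
)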

For the policy to "talk" with the environment through the tensordict data carrier, we wrap the nn.Module in a TensorDictModule. This class will simply ready the in_keys it is provided with and write the outputs in-place at the registered out_keys.

policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)

We now need to build a distribution out of the location and scale of our normal distribution. To do so, we instruct the ProbabilisticActor class to build a TanhNormal out of the location and scale parameters. We also provide the minimum and maximum values of this distribution, which we gather from the environment specs.

The name of the in_keys (and hence the name of the out_keys of the TensorDictModule above) cannot be set to just any value, because the TanhNormal distribution constructor expects the loc and scale keyword arguments. That being said, ProbabilisticActor also accepts Dict[str, str] typed in_keys, where the key-value pairs indicate which in_key string should be used for each keyword argument.

policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "min": env.action_spec.space.low,
        "max": env.action_spec.space.high,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)

Value network

The value network is a crucial component of the PPO algorithm, even though it won't be used at inference time. This module will read the observations and return an estimation of the discounted return for the following trajectory. This allows us to amortize learning by relying on a utility estimation that is learned on-the-fly during training. Our value network shares the same structure as the policy, but for simplicity we assign it its own set of parameters.

value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)

value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)

Let's try our policy and value modules. As we said earlier, the usage of TensorDictModule makes it possible to directly read the output of the environment to run these modules, as they know what information to read and where to write it:

print("Running policy:", policy_module(env.reset()))
print("Running value:", value_module(env.reset()))
Running policy: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
Running value: TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

Data collector

TorchRL provides a set of DataCollector classes. Briefly, these classes execute three operations: reset an environment, compute an action given the latest observation, execute a step in the environment, and repeat the last two steps until the environment signals a stop (or reaches a done state).

They allow you to control how many frames to collect at each iteration (via the frames_per_batch parameter), when to reset the environment (via the max_frames_per_traj argument), on which device the policy should be executed, etc. They are also designed to work efficiently with batched and multiprocessed environments.

The simplest data collector is the SyncDataCollector: it is an iterator that you can use to get batches of data of a given length, and that will stop once a total number of frames (total_frames) has been collected. Other data collectors (MultiSyncDataCollector and MultiaSyncDataCollector) will execute the same operations in a synchronous and asynchronous manner over a set of multiprocessed workers.

As for the policy and environment before, the data collector will return TensorDict instances with a total number of elements that matches frames_per_batch. Using TensorDict to pass data to the training loop allows you to write data-loading pipelines that are completely oblivious to the actual specificities of the rollout content.

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

Replay buffer

Replay buffers are a common building piece of off-policy RL algorithms. In on-policy contexts, a replay buffer is refilled every time a batch of data is collected, and its data is repeatedly consumed for a certain number of epochs.

TorchRL's replay buffers are built using a common container, ReplayBuffer, which takes as arguments the components of the buffer: a storage, a writer, a sampler and possibly some transforms. Only the storage (which indicates the replay buffer capacity) is mandatory. We also specify a sampler without repetition to avoid sampling the same item multiple times in one epoch. Using a replay buffer for PPO is not mandatory and we could simply sample the sub-batches from the collected batch, but using these classes makes it easy for us to build the inner training loop in a reproducible way.

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

Loss function

For convenience, the PPO loss can be directly imported from TorchRL using the ClipPPOLoss class. This is the easiest way of utilizing PPO: it hides away the mathematical operations of PPO and the control flow that goes with it.

PPO requires some "advantage estimation" to be computed. In short, an advantage is a value that reflects an expectancy over the return value while dealing with the bias/variance trade-off. To compute the advantage, one just needs to (1) build the advantage module, which utilizes our value operator, and (2) pass each batch of data through it before each epoch. The GAE module will update the input tensordict with new "advantage" and "value_target" entries. The "value_target" is a gradient-free tensor that represents the empirical value that the value network should output for the input observation. Both of these will be used by ClipPPOLoss to return the policy and value losses.

advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
)

loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

Training loop

We now have all the pieces needed to code our training loop. The steps include:

  • Collect data

    • Compute advantage

      • Loop over the collected data to compute loss values

      • Back propagate

      • Optimize

      • Repeat

    • Repeat

  • Repeat

logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for a given
        # number of steps (1000, which is our ``env`` horizon).
        # The ``rollout`` method of the ``env`` can take a policy as argument:
        # it will then execute this policy at each step.
        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()
  0%|          | 0/50000 [00:00<?, ?it/s]
  2%|2         | 1000/50000 [00:03<03:10, 256.79it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.0998 (init= 9.0998), step count (max): 16, lr policy:  0.0003:   2%|2         | 1000/50000 [00:03<03:10, 256.79it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.0998 (init= 9.0998), step count (max): 16, lr policy:  0.0003:   4%|4         | 2000/50000 [00:07<03:04, 259.51it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1175 (init= 9.0998), step count (max): 14, lr policy:  0.0003:   4%|4         | 2000/50000 [00:07<03:04, 259.51it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1175 (init= 9.0998), step count (max): 14, lr policy:  0.0003:   6%|6         | 3000/50000 [00:11<02:59, 261.82it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1509 (init= 9.0998), step count (max): 14, lr policy:  0.0003:   6%|6         | 3000/50000 [00:11<02:59, 261.82it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1509 (init= 9.0998), step count (max): 14, lr policy:  0.0003:   8%|8         | 4000/50000 [00:15<02:54, 263.16it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1931 (init= 9.0998), step count (max): 22, lr policy:  0.0003:   8%|8         | 4000/50000 [00:15<02:54, 263.16it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.1931 (init= 9.0998), step count (max): 22, lr policy:  0.0003:  10%|#         | 5000/50000 [00:19<02:50, 264.70it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2155 (init= 9.0998), step count (max): 27, lr policy:  0.0003:  10%|#         | 5000/50000 [00:19<02:50, 264.70it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2155 (init= 9.0998), step count (max): 27, lr policy:  0.0003:  12%|#2        | 6000/50000 [00:22<02:45, 265.97it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2189 (init= 9.0998), step count (max): 25, lr policy:  0.0003:  12%|#2        | 6000/50000 [00:22<02:45, 265.97it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2189 (init= 9.0998), step count (max): 25, lr policy:  0.0003:  14%|#4        | 7000/50000 [00:26<02:41, 266.81it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2371 (init= 9.0998), step count (max): 47, lr policy:  0.0003:  14%|#4        | 7000/50000 [00:26<02:41, 266.81it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2371 (init= 9.0998), step count (max): 47, lr policy:  0.0003:  16%|#6        | 8000/50000 [00:30<02:37, 267.35it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2277 (init= 9.0998), step count (max): 36, lr policy:  0.0003:  16%|#6        | 8000/50000 [00:30<02:37, 267.35it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2277 (init= 9.0998), step count (max): 36, lr policy:  0.0003:  18%|#8        | 9000/50000 [00:33<02:32, 268.43it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2517 (init= 9.0998), step count (max): 41, lr policy:  0.0003:  18%|#8        | 9000/50000 [00:33<02:32, 268.43it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2517 (init= 9.0998), step count (max): 41, lr policy:  0.0003:  20%|##        | 10000/50000 [00:37<02:32, 262.98it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2600 (init= 9.0998), step count (max): 41, lr policy:  0.0003:  20%|##        | 10000/50000 [00:37<02:32, 262.98it/s]
eval cumulative reward:  101.6955 (init:  101.6955), eval step-count: 10, average reward= 9.2600 (init= 9.0998), step count (max): 41, lr policy:  0.0003:  22%|##2       | 11000/50000 [00:41<02:26, 265.59it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2575 (init= 9.0998), step count (max): 38, lr policy:  0.0003:  22%|##2       | 11000/50000 [00:41<02:26, 265.59it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2575 (init= 9.0998), step count (max): 38, lr policy:  0.0003:  24%|##4       | 12000/50000 [00:45<02:23, 265.18it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2730 (init= 9.0998), step count (max): 56, lr policy:  0.0003:  24%|##4       | 12000/50000 [00:45<02:23, 265.18it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2730 (init= 9.0998), step count (max): 56, lr policy:  0.0003:  26%|##6       | 13000/50000 [00:48<02:18, 267.12it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2719 (init= 9.0998), step count (max): 55, lr policy:  0.0003:  26%|##6       | 13000/50000 [00:48<02:18, 267.12it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2719 (init= 9.0998), step count (max): 55, lr policy:  0.0003:  28%|##8       | 14000/50000 [00:52<02:14, 268.62it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2725 (init= 9.0998), step count (max): 102, lr policy:  0.0003:  28%|##8       | 14000/50000 [00:52<02:14, 268.62it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2725 (init= 9.0998), step count (max): 102, lr policy:  0.0003:  30%|###       | 15000/50000 [00:56<02:09, 269.74it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2774 (init= 9.0998), step count (max): 95, lr policy:  0.0002:  30%|###       | 15000/50000 [00:56<02:09, 269.74it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2774 (init= 9.0998), step count (max): 95, lr policy:  0.0002:  32%|###2      | 16000/50000 [01:00<02:05, 270.42it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2724 (init= 9.0998), step count (max): 59, lr policy:  0.0002:  32%|###2      | 16000/50000 [01:00<02:05, 270.42it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2724 (init= 9.0998), step count (max): 59, lr policy:  0.0002:  34%|###4      | 17000/50000 [01:03<02:01, 271.23it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2809 (init= 9.0998), step count (max): 89, lr policy:  0.0002:  34%|###4      | 17000/50000 [01:03<02:01, 271.23it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2809 (init= 9.0998), step count (max): 89, lr policy:  0.0002:  36%|###6      | 18000/50000 [01:07<01:57, 271.33it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2828 (init= 9.0998), step count (max): 83, lr policy:  0.0002:  36%|###6      | 18000/50000 [01:07<01:57, 271.33it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2828 (init= 9.0998), step count (max): 83, lr policy:  0.0002:  38%|###8      | 19000/50000 [01:11<01:54, 271.37it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2828 (init= 9.0998), step count (max): 69, lr policy:  0.0002:  38%|###8      | 19000/50000 [01:11<01:54, 271.37it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2828 (init= 9.0998), step count (max): 69, lr policy:  0.0002:  40%|####      | 20000/50000 [01:14<01:50, 270.62it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2765 (init= 9.0998), step count (max): 66, lr policy:  0.0002:  40%|####      | 20000/50000 [01:14<01:50, 270.62it/s]
eval cumulative reward:  437.4563 (init:  101.6955), eval step-count: 46, average reward= 9.2765 (init= 9.0998), step count (max): 66, lr policy:  0.0002:  42%|####2     | 21000/50000 [01:18<01:46, 271.33it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2970 (init= 9.0998), step count (max): 121, lr policy:  0.0002:  42%|####2     | 21000/50000 [01:18<01:46, 271.33it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2970 (init= 9.0998), step count (max): 121, lr policy:  0.0002:  44%|####4     | 22000/50000 [01:22<01:44, 267.08it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.3032 (init= 9.0998), step count (max): 125, lr policy:  0.0002:  44%|####4     | 22000/50000 [01:22<01:44, 267.08it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.3032 (init= 9.0998), step count (max): 125, lr policy:  0.0002:  46%|####6     | 23000/50000 [01:26<01:42, 262.83it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2970 (init= 9.0998), step count (max): 78, lr policy:  0.0002:  46%|####6     | 23000/50000 [01:26<01:42, 262.83it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2970 (init= 9.0998), step count (max): 78, lr policy:  0.0002:  48%|####8     | 24000/50000 [01:29<01:37, 265.77it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2985 (init= 9.0998), step count (max): 113, lr policy:  0.0002:  48%|####8     | 24000/50000 [01:29<01:37, 265.77it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2985 (init= 9.0998), step count (max): 113, lr policy:  0.0002:  50%|#####     | 25000/50000 [01:33<01:33, 267.84it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.3044 (init= 9.0998), step count (max): 102, lr policy:  0.0002:  50%|#####     | 25000/50000 [01:33<01:33, 267.84it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.3044 (init= 9.0998), step count (max): 102, lr policy:  0.0002:  52%|#####2    | 26000/50000 [01:37<01:29, 269.15it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2937 (init= 9.0998), step count (max): 87, lr policy:  0.0001:  52%|#####2    | 26000/50000 [01:37<01:29, 269.15it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2937 (init= 9.0998), step count (max): 87, lr policy:  0.0001:  54%|#####4    | 27000/50000 [01:41<01:25, 268.28it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2961 (init= 9.0998), step count (max): 70, lr policy:  0.0001:  54%|#####4    | 27000/50000 [01:41<01:25, 268.28it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2961 (init= 9.0998), step count (max): 70, lr policy:  0.0001:  56%|#####6    | 28000/50000 [01:44<01:21, 268.42it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2842 (init= 9.0998), step count (max): 60, lr policy:  0.0001:  56%|#####6    | 28000/50000 [01:44<01:21, 268.42it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2842 (init= 9.0998), step count (max): 60, lr policy:  0.0001:  58%|#####8    | 29000/50000 [01:48<01:17, 269.30it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2952 (init= 9.0998), step count (max): 67, lr policy:  0.0001:  58%|#####8    | 29000/50000 [01:48<01:17, 269.30it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2952 (init= 9.0998), step count (max): 67, lr policy:  0.0001:  60%|######    | 30000/50000 [01:52<01:14, 270.23it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2988 (init= 9.0998), step count (max): 75, lr policy:  0.0001:  60%|######    | 30000/50000 [01:52<01:14, 270.23it/s]
eval cumulative reward:  867.6711 (init:  101.6955), eval step-count: 92, average reward= 9.2988 (init= 9.0998), step count (max): 75, lr policy:  0.0001:  62%|######2   | 31000/50000 [01:55<01:10, 270.84it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.2974 (init= 9.0998), step count (max): 77, lr policy:  0.0001:  62%|######2   | 31000/50000 [01:55<01:10, 270.84it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.2974 (init= 9.0998), step count (max): 77, lr policy:  0.0001:  64%|######4   | 32000/50000 [01:59<01:07, 267.85it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3021 (init= 9.0998), step count (max): 100, lr policy:  0.0001:  64%|######4   | 32000/50000 [01:59<01:07, 267.85it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3021 (init= 9.0998), step count (max): 100, lr policy:  0.0001:  66%|######6   | 33000/50000 [02:03<01:03, 268.88it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3097 (init= 9.0998), step count (max): 175, lr policy:  0.0001:  66%|######6   | 33000/50000 [02:03<01:03, 268.88it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3097 (init= 9.0998), step count (max): 175, lr policy:  0.0001:  68%|######8   | 34000/50000 [02:06<00:59, 270.19it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3168 (init= 9.0998), step count (max): 140, lr policy:  0.0001:  68%|######8   | 34000/50000 [02:06<00:59, 270.19it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3168 (init= 9.0998), step count (max): 140, lr policy:  0.0001:  70%|#######   | 35000/50000 [02:10<00:56, 264.93it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3159 (init= 9.0998), step count (max): 117, lr policy:  0.0001:  70%|#######   | 35000/50000 [02:10<00:56, 264.93it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3159 (init= 9.0998), step count (max): 117, lr policy:  0.0001:  72%|#######2  | 36000/50000 [02:14<00:52, 267.19it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3155 (init= 9.0998), step count (max): 132, lr policy:  0.0001:  72%|#######2  | 36000/50000 [02:14<00:52, 267.19it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3155 (init= 9.0998), step count (max): 132, lr policy:  0.0001:  74%|#######4  | 37000/50000 [02:18<00:48, 268.67it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3185 (init= 9.0998), step count (max): 118, lr policy:  0.0001:  74%|#######4  | 37000/50000 [02:18<00:48, 268.67it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3185 (init= 9.0998), step count (max): 118, lr policy:  0.0001:  76%|#######6  | 38000/50000 [02:21<00:44, 270.04it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3231 (init= 9.0998), step count (max): 147, lr policy:  0.0000:  76%|#######6  | 38000/50000 [02:21<00:44, 270.04it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3231 (init= 9.0998), step count (max): 147, lr policy:  0.0000:  78%|#######8  | 39000/50000 [02:25<00:40, 270.74it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3233 (init= 9.0998), step count (max): 173, lr policy:  0.0000:  78%|#######8  | 39000/50000 [02:25<00:40, 270.74it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3233 (init= 9.0998), step count (max): 173, lr policy:  0.0000:  80%|########  | 40000/50000 [02:29<00:36, 271.60it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3168 (init= 9.0998), step count (max): 135, lr policy:  0.0000:  80%|########  | 40000/50000 [02:29<00:36, 271.60it/s]
eval cumulative reward:  662.4586 (init:  101.6955), eval step-count: 70, average reward= 9.3168 (init= 9.0998), step count (max): 135, lr policy:  0.0000:  82%|########2 | 41000/50000 [02:32<00:33, 272.21it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3168 (init= 9.0998), step count (max): 135, lr policy:  0.0000:  82%|########2 | 41000/50000 [02:32<00:33, 272.21it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3168 (init= 9.0998), step count (max): 135, lr policy:  0.0000:  84%|########4 | 42000/50000 [02:36<00:29, 270.49it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3261 (init= 9.0998), step count (max): 166, lr policy:  0.0000:  84%|########4 | 42000/50000 [02:36<00:29, 270.49it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3261 (init= 9.0998), step count (max): 166, lr policy:  0.0000:  86%|########6 | 43000/50000 [02:40<00:25, 271.36it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3296 (init= 9.0998), step count (max): 193, lr policy:  0.0000:  86%|########6 | 43000/50000 [02:40<00:25, 271.36it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3296 (init= 9.0998), step count (max): 193, lr policy:  0.0000:  88%|########8 | 44000/50000 [02:43<00:22, 271.95it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3385 (init= 9.0998), step count (max): 182, lr policy:  0.0000:  88%|########8 | 44000/50000 [02:43<00:22, 271.95it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3385 (init= 9.0998), step count (max): 182, lr policy:  0.0000:  90%|######### | 45000/50000 [02:47<00:18, 272.30it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3294 (init= 9.0998), step count (max): 189, lr policy:  0.0000:  90%|######### | 45000/50000 [02:47<00:18, 272.30it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3294 (init= 9.0998), step count (max): 189, lr policy:  0.0000:  92%|#########2| 46000/50000 [02:51<00:15, 266.35it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3320 (init= 9.0998), step count (max): 197, lr policy:  0.0000:  92%|#########2| 46000/50000 [02:51<00:15, 266.35it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3320 (init= 9.0998), step count (max): 197, lr policy:  0.0000:  94%|#########3| 47000/50000 [02:55<00:11, 268.64it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3278 (init= 9.0998), step count (max): 160, lr policy:  0.0000:  94%|#########3| 47000/50000 [02:55<00:11, 268.64it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3278 (init= 9.0998), step count (max): 160, lr policy:  0.0000:  96%|#########6| 48000/50000 [02:58<00:07, 270.15it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3257 (init= 9.0998), step count (max): 162, lr policy:  0.0000:  96%|#########6| 48000/50000 [02:58<00:07, 270.15it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3257 (init= 9.0998), step count (max): 162, lr policy:  0.0000:  98%|#########8| 49000/50000 [03:02<00:03, 271.35it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3230 (init= 9.0998), step count (max): 118, lr policy:  0.0000:  98%|#########8| 49000/50000 [03:02<00:03, 271.35it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3230 (init= 9.0998), step count (max): 118, lr policy:  0.0000: 100%|##########| 50000/50000 [03:06<00:00, 272.31it/s]
eval cumulative reward:  344.6715 (init:  101.6955), eval step-count: 36, average reward= 9.3355 (init= 9.0998), step count (max): 348, lr policy:  0.0000: 100%|##########| 50000/50000 [03:06<00:00, 272.31it/s]

Results

Before the 1M step cap is reached, the algorithm should have reached a maximum step count of 1000 steps, which is the maximum number of steps before the trajectory is truncated.

plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()
[Figure: training rewards (average), Max step count (training), Return (test), Max step count (test)]

Conclusion and next steps

In this tutorial, we have learned:

  1. How to create and customize an environment with torchrl;

  2. How to write a model and a loss function;

  3. How to set up a typical training loop.

If you want to experiment with this tutorial a bit more, you can apply the following modifications:

  • From an efficiency perspective, we could run several simulations in parallel to speed up data collection. Check ParallelEnv for further information; a hedged sketch follows this list.

  • From a logging perspective, one could add a torchrl.record.VideoRecorder transform to the environment after asking for rendering to get a visual rendering of the inverted pendulum in action. Check torchrl.record to know more.
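
As a minimal sketch of the parallel-simulation suggestion above (the worker count and helper name are our own assumptions, not part of the tutorial):

from torchrl.envs import ParallelEnv

def make_base_env():
    # each worker builds its own copy of the simulator
    return GymEnv("InvertedDoublePendulum-v4", device=device)

# 4 workers chosen arbitrarily; the transforms defined earlier would then be
# appended to this batched environment via TransformedEnv as before
parallel_base_env = ParallelEnv(4, make_base_env)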

Total running time of the script: (3 minutes 7.796 seconds)

