ray.rllib.algorithms.algorithm_config.AlgorithmConfig.框架#

AlgorithmConfig.framework(framework: str | None = <ray.rllib.utils.from_config._NotProvided object>, *, eager_tracing: bool | None = <ray.rllib.utils.from_config._NotProvided object>, eager_max_retraces: int | None = <ray.rllib.utils.from_config._NotProvided object>, tf_session_args: ~typing.Dict[str, ~typing.Any] | None = <ray.rllib.utils.from_config._NotProvided object>, local_tf_session_args: ~typing.Dict[str, ~typing.Any] | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_learner: bool | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_learner_what_to_compile: str | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_learner_dynamo_mode: str | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_learner_dynamo_backend: str | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_worker: bool | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_worker_dynamo_backend: str | None = <ray.rllib.utils.from_config._NotProvided object>, torch_compile_worker_dynamo_mode: str | None = <ray.rllib.utils.from_config._NotProvided object>, torch_ddp_kwargs: ~typing.Dict[str, ~typing.Any] | None = <ray.rllib.utils.from_config._NotProvided object>, torch_skip_nan_gradients: bool | None = <ray.rllib.utils.from_config._NotProvided object>) AlgorithmConfig[源代码]#

设置配置的深度学习框架设置。

参数:
  • framework – torch: PyTorch; tf2: TensorFlow 2.x (eager execution 或 eager_tracing=True 时 traced); tf: TensorFlow (静态图);

  • eager_tracing – 在急切模式下启用跟踪。这大大提高了性能(速度提升约2倍),但由于在初始急切传递后不会评估Python代码,因此调试起来稍微困难一些。仅在框架=tf2时可能。

  • eager_max_retraces – 在引发运行时错误之前,tf.function 重新追踪的最大次数。这是为了防止 ..._eager_traced 策略中方法的未被注意的重新追踪,这可能会使执行速度降低4倍,而用户可能不会注意到导致这种减速的根本原因。仅在 framework=tf2 时需要。设置为 None 以忽略重新追踪计数,并且永远不会抛出错误。

  • tf_session_args – 默认配置TF以进行单进程操作。

  • local_tf_session_args – 在本地工作节点上覆盖以下 tf 会话参数

  • torch_compile_learner – 如果为 True,学习器上的 TorchRLModules 的 forward_train 方法将被编译。如果未指定,默认情况下会在学习器上编译 forward_train。

  • torch_compile_learner_what_to_compile – 一个指定在 torch_compile_learner 为 True 时,学习者端需要编译内容的 TorchCompileWhatToCompile 模式。详情及使用建议请参阅 TorchCompileWhatToCompile。

  • torch_compile_learner_dynamo_backend – 在学习者上使用的 torch dynamo 后端。

  • torch_compile_learner_dynamo_mode – 在学习者上使用的 torch dynamo 模式。

  • torch_compile_worker – 如果为 True,则在工作者上的 TorchRLModules 上的前向探索和推理方法将被编译。如果未指定,默认情况下不会在工作者上编译前向方法,因为重新追踪可能会很昂贵。

  • torch_compile_worker_dynamo_backend – 在工作者上使用的 torch dynamo 后端。

  • torch_compile_worker_dynamo_mode – 在工作者上使用的 torch dynamo 模式。

  • torch_ddp_kwargs – 在使用 num_learners > 1 时传递给 torch.nn.parallel.DistributedDataParallel 的 kwargs。这在搜索未使用的参数时特别有用,这些参数在反向传播过程中未被使用。这可以为自定义模型中的错误提供提示,其中某些参数虽然在反向传播中应该被触及,但实际上并未被触及。

  • torch_skip_nan_gradients – 如果带有 nan 梯度的更新应该完全跳过。如果它们包含任何 nan 梯度,这将完全跳过优化器中的更新。这有助于避免偏移基于移动平均的优化器 - 如 Adam。这可以在训练阶段帮助避免策略更新高度不稳定的情况,例如在训练的早期阶段或使用高度探索性策略时。在这些阶段,许多梯度可能会变为 nan,将它们设置为零可能会破坏优化器的内部状态。默认值为 False,并将 nan 梯度设置为零。如果遇到许多 nan 梯度,请考虑 (a) 通过将 AlgorithmConfig 中的 log_gradients 设置为 True 来监控梯度,(b) 通过 AlgorithmConfig.rl_module 中的 model_config_dict 使用适当的权重初始化(例如 Xavier、Kaiming),和/或 (c) 通过 AlgorithmConfig.training 中的 grad_clip 进行梯度裁剪。

返回:

这个更新的 AlgorithmConfig 对象。