ray.rllib.core.learner.learner.Learner.postprocess_gradients_for_module#

Learner.postprocess_gradients_for_module(*, module_id: str, config: AlgorithmConfig | None = None, module_gradients_dict: Dict[Hashable, torch.Tensor | tf.Variable]) Dict[Hashable, torch.Tensor | tf.Variable][源代码]#

对给定模块的梯度应用后处理操作。

参数:
  • module_id – 我们将对计算出的梯度进行后处理的模块ID。请注意,module_gradients_dict 已经只包含属于此 module_id 的梯度张量。其他 module_id 的梯度在此调用中不可用。

  • config – 特定于给定 module_id 的 AlgorithmConfig。

  • module_gradients_dict – 一个与 self._params 相同(扁平)格式的梯度字典,将梯度引用映射到梯度张量,这些张量将进行后处理。您可以直接更改这些张量,或者创建新的张量并在新的字典中返回它们。

返回:

一个包含更新梯度的字典,其结构与传入的 module_gradients_dict 参数完全相同(扁平结构)。