ray.rllib.core.learner.learner.Learner.postprocess_gradients_for_module#
- Learner.postprocess_gradients_for_module(*, module_id: str, config: AlgorithmConfig | None = None, module_gradients_dict: Dict[Hashable, torch.Tensor | tf.Variable]) Dict[Hashable, torch.Tensor | tf.Variable][源代码]#
对给定模块的梯度应用后处理操作。
- 参数:
module_id – 我们将对计算出的梯度进行后处理的模块ID。请注意,
module_gradients_dict已经只包含属于此module_id的梯度张量。其他module_id的梯度在此调用中不可用。config – 特定于给定
module_id的 AlgorithmConfig。module_gradients_dict – 一个与 self._params 相同(扁平)格式的梯度字典,将梯度引用映射到梯度张量,这些张量将进行后处理。您可以直接更改这些张量,或者创建新的张量并在新的字典中返回它们。
- 返回:
一个包含更新梯度的字典,其结构与传入的
module_gradients_dict参数完全相同(扁平结构)。