ray.rllib.core.learner.learner.Learner.postprocess_gradients#
- Learner.postprocess_gradients(gradients_dict: Dict[Hashable, torch.Tensor | tf.Variable]) Dict[Hashable, torch.Tensor | tf.Variable][源代码]#
对梯度应用潜在的后处理操作。
此方法在计算梯度后调用,并在优化器将梯度应用于相应的模块之前修改它们。这可能包括按值、范数或全局范数进行梯度裁剪,或其他特定算法的梯度后处理步骤。
这个默认实现调用了
self.postprocess_gradients_for_module()在我们的 MultiRLModule 中的每个子模块:self.module,并返回累积的梯度字典。- 参数:
gradients_dict – 一个与 self._params 相同(扁平)格式的梯度字典。请注意,诸如模块ID等顶层结构将不再存在于这个字典中。它仅仅将梯度张量引用映射到梯度张量。
- 返回:
一个包含更新梯度的字典,其结构与传入的
gradients_dict参数完全相同(扁平结构)。