ray.rllib.core.learner.learner.Learner.postprocess_gradients#

Learner.postprocess_gradients(gradients_dict: Dict[Hashable, torch.Tensor | tf.Variable]) Dict[Hashable, torch.Tensor | tf.Variable][源代码]#

对梯度应用潜在的后处理操作。

此方法在计算梯度后调用,并在优化器将梯度应用于相应的模块之前修改它们。这可能包括按值、范数或全局范数进行梯度裁剪,或其他特定算法的梯度后处理步骤。

这个默认实现调用了 self.postprocess_gradients_for_module() 在我们的 MultiRLModule 中的每个子模块:self.module,并返回累积的梯度字典。

参数:

gradients_dict – 一个与 self._params 相同(扁平)格式的梯度字典。请注意,诸如模块ID等顶层结构将不再存在于这个字典中。它仅仅将梯度张量引用映射到梯度张量。

返回:

一个包含更新梯度的字典,其结构与传入的 gradients_dict 参数完全相同(扁平结构)。