梯度范数裁剪训练回调
- class GradientNormClippingTrainingCallback(max_norm: float, norm_type: float | None = None)[源代码]
基础类:
TrainingCallback在优化器步进之前进行梯度裁剪的回调函数,使用
torch.nn.utils.clip_grad_norm_()。初始化回调。
- Parameters:
max_norm (float) – 用于梯度裁剪的最大梯度范数。
norm_type (float | None) – 用于最大梯度范数的梯度范数类型,参见
torch.nn.utils.clip_grad_norm_()
方法总结
pre_step(**kwargs)在优化器的步骤之前调用。
方法文档