paddlespeech.t2s.training.updater模块
- class paddlespeech.t2s.training.updater.UpdaterBase(init_state=None)[来源]
基础:
object更新器是给定数据加载器和优化器时,模型训练的抽象。
update_core 方法是在训练循环中的一步,仅进行必要的操作(获取一个批次,前向和后向,更新参数)。
其他内容是由扩展构成的。可视化、保存、加载以及定期验证和评估在这里没有考虑。
但是即使在这样简单的案例中,事情也没有那么简单。有人试图标准化这个过程,只需要模型和数据集,并且自动完成所有操作。但这可能会影响灵活性。
如果我们假设来自数据加载器的批量产出只是模型的输入,我们会发现一些模型需要更多的参数,或者只是一些关键字参数。但这阻止了我们过于简化它。
从另一个角度来看,批次可能不仅包括输入,还包括目标。但模型的前向方法可能只需要输入。我们可以将一个字典或一个超长元组传递给模型,让它选择它真正需要的内容。但这是一种懒惰接口的滥用。
毕竟,我们关心的是模型是如何训练的。但模型是如何用于推理的呢?我们想要控制模型的训练方式。我们只是不想与其他辅助代码混在一起。
因此,最佳实践是定义一个模型并为其定义一个更新器。
方法
加载
保存
set_state_dict
状态字典
更新