paddlespeech.t2s.training.updater模块

class paddlespeech.t2s.training.updater.UpdaterBase(init_state=None)[来源]

基础: object

更新器是给定数据加载器和优化器时,模型训练的抽象。

update_core 方法是在训练循环中的一步,仅进行必要的操作(获取一个批次,前向和后向,更新参数)。

其他内容是由扩展构成的。可视化、保存、加载以及定期验证和评估在这里没有考虑。

但是即使在这样简单的案例中,事情也没有那么简单。有人试图标准化这个过程,只需要模型和数据集,并且自动完成所有操作。但这可能会影响灵活性。

如果我们假设来自数据加载器的批量产出只是模型的输入,我们会发现一些模型需要更多的参数,或者只是一些关键字参数。但这阻止了我们过于简化它。

从另一个角度来看,批次可能不仅包括输入,还包括目标。但模型的前向方法可能只需要输入。我们可以将一个字典或一个超长元组传递给模型,让它选择它真正需要的内容。但这是一种懒惰接口的滥用。

毕竟,我们关心的是模型是如何训练的。但模型是如何用于推理的呢?我们想要控制模型的训练方式。我们只是不想与其他辅助代码混在一起。

因此,最佳实践是定义一个模型并为其定义一个更新器。

方法

加载

保存

set_state_dict

状态字典

更新

load(path)[来源]
save(path)[来源]
set_state_dict(state_dict)[来源]
state_dict()[来源]
update(batch)[来源]
class paddlespeech.t2s.training.updater.UpdaterState(iteration: int = 0, epoch: int = 0)[来源]

基础: object

epoch: int = 0
iteration: int = 0