paddlespeech.s2t.training.updaters.updater模块

class paddlespeech.s2t.training.updaters.updater.UpdaterBase(init_state=None)[来源]

基础: object

更新器是给定数据加载器和优化器的模型训练的抽象。 update_core 方法是在训练循环中的一个步骤,仅包含必要的操作(获取一个批次、前向和反向传播、更新参数)。 其他内容都是扩展。 可视化、保存、加载和定期验证与评估在这里不予考虑。 但即使在如此简单的情况下,事情也并非那么简单。 有一个试图标准化这个过程的尝试,只需要模型和数据集,并自动完成所有工作。 但这可能会损害灵活性。 如果我们假设从数据加载器输出的批次只是模型的输入,我们会发现某些模型需要更多的参数,或者只是一些关键字参数。 但这阻止我们过于简单化。 从另一个角度来看,批次可能不仅包括输入,还包括目标。 但模型的前向方法可能只需要输入。 我们可以将一个字典或一个超长元组传递给模型,让它选择它真正需要的内容。 但这是一种懒惰接口的滥用。 毕竟,我们关心的是模型是如何训练的。 但只是模型是如何用于推断的。 我们希望控制模型的训练方式。 我们只是不想被其他辅助代码弄得一团糟。 所以最佳实践是定义一个模型并为其定义一个更新器。

方法

加载

保存

set_state_dict

状态字典

更新

load(path)[来源]
save(path)[来源]
set_state_dict(state_dict)[来源]
state_dict()[来源]
update(batch)[来源]
class paddlespeech.s2t.training.updaters.updater.UpdaterState(iteration: int = 0, epoch: int = 0)[来源]

基础: object

epoch: int = 0
iteration: int = 0