paddlespeech.t2s.training.extensions.snapshot 模块
- class paddlespeech.t2s.training.extensions.snapshot.Snapshot(max_size: int = 5, snapshot_on_error: bool = False)[来源]
基础:
Extension一个扩展,用于在训练器内部对更新器对象进行快照。 通过调用更新器的 save 方法来实现。
默认情况下,Updater 保存其 state_dict,它包含更新器的状态(即轮次和迭代)以及所有模型参数和优化器状态。如果训练器中的更新器是 StandardUpdater 的子类,一切都可以正常进行。
- Arsg:
checkpoint_dir (Union[str, Path]): 保存检查点的目录。
- Attributes:
- name
方法
__call__(trainer)扩展的主要操作。
finalize(trainer)训练完成时执行的操作。
full()它跟踪的快照数量是否大于 max_size。
initialize(trainer)设置此扩展。
on_error(trainer, exc, tb)处理训练过程中在最终化之前引发的错误。
save_checkpoint_and_update(trainer)保存新的快照,并在需要时删除最旧的快照。
- default_name = 'snapshot'
- priority = -100
- trigger = (1, 'epoch')