paddlespeech.t2s.training.extensions.snapshot 模块

class paddlespeech.t2s.training.extensions.snapshot.Snapshot(max_size: int = 5, snapshot_on_error: bool = False)[来源]

基础: Extension

一个扩展,用于在训练器内部对更新器对象进行快照。 通过调用更新器的 save 方法来实现。

默认情况下,Updater 保存其 state_dict,它包含更新器的状态(即轮次和迭代)以及所有模型参数和优化器状态。如果训练器中的更新器是 StandardUpdater 的子类,一切都可以正常进行。

Arsg:

checkpoint_dir (Union[str, Path]): 保存检查点的目录。

Attributes:
name

方法

__call__(trainer)

扩展的主要操作。

finalize(trainer)

训练完成时执行的操作。

full()

它跟踪的快照数量是否大于 max_size。

initialize(trainer)

设置此扩展。

on_error(trainer, exc, tb)

处理训练过程中在最终化之前引发的错误。

save_checkpoint_and_update(trainer)

保存新的快照,并在需要时删除最旧的快照。

default_name = 'snapshot'
full()[来源]

它跟踪的快照数量是否大于 max_size。

initialize(trainer: 训练者)[来源]

设置此扩展。

on_error(trainer, exc, tb)[来源]

处理在最终确认之前训练期间引发的错误。

priority = -100
save_checkpoint_and_update(trainer: 训练师)[来源]

保存新快照,并在需要时移除最旧的快照。

trigger = (1, 'epoch')
paddlespeech.t2s.training.extensions.snapshot.load_records(records_fp)[来源]

加载记录文件(json 行。)