paddlespeech.s2t.training.trainer 模块
- class paddlespeech.s2t.training.trainer.Trainer(config, args)[来源]
基础:
object一个实验模板,用于构建训练代码,并处理保存、加载、日志记录和可视化等事务。它旨在灵活和简单。
所以它只以标准方式处理输出目录(为输出创建目录,创建检查点目录,转储正在使用的配置并创建可视化器和记录器),而无需强制执行任何输入-输出协议给模型和数据加载器。它将主要部分留给用户自行实现(设置模型、标准、优化器,定义训练步骤,定义验证函数并自定义所有文本和可视化日志)。它不保存太多的样板代码。用户仍然需要手动编写前向/反向/更新,但如有需要,他们可以自由添加非标准行为。我们有一些约定需要遵循。
1. 实验应具有model、optimizer、train_loader和valid_loader、config和args属性。
2. 配置应具有一个training字段,其中包含valid_interval、save_interval和max_iteration键。它作为触发器,用于调用验证、检查点和停止实验。
3. 有四个方法,即train_batch、valid、setup_model和setup_dataloader应该被实现。如有需要,自由添加/重写其他方法和独立函数。- Parameters:
- config: yacs.config.CfgNode
用于实验的配置。
- args: argparse.Namespace
解析的命令行参数。
- Examples
- --------
- >>> def main_sp(config, args):
- >>> exp = Trainer(config, args)
- >>> exp.setup()
- >>> exp.run()
- >>>
- >>> config = get_cfg_defaults()
- >>> parser = default_argument_parser()
- >>> args = parser.parse_args()
- >>> if args.config:
- >>> config.merge_from_file(args.config)
- >>> if args.opts:
- >>> config.merge_from_list(args.opts)
- >>> config.freeze()
- >>>
- >>> if args.ngpu > 1:
- >>> dist.spawn(main_sp, args=(config, args), nprocs=args.ngpu)
- >>> else:
- >>> main_sp(config, args)
- Attributes:
parallel一个标志,指示实验是否应该使用多进程运行。
- train
方法
align()对齐。
destory()关闭可视化工具以避免训练后卡住
do_train()训练过程由 эпох 控制。
保存此次实验使用的配置。
export()测试。
按世代的 batch_sampler 种子
重置训练加载器的种子并增加 epoch。
restore()从输出目录中的最新检查点恢复或加载指定的检查点。
从输出目录中的最新检查点恢复或加载指定的检查点。
run()实验设置后的常规流程。
进行CTC对齐
执行模型导出
run_test()执行测试/解码
save([tag, infos])保存检查点(模型参数和优化器状态)。
setup()设置实验。
设置训练数据加载器和验证数据加载器。
设置模型、标准和优化器等。
创建一个用于输出的目录。
初始化一个可视化工具以记录实验。
test()测试。
训练循环。
valid()验证。
训练后批次
训练前
评估
- property parallel
一个标志,用于指示实验是否应该使用多进程运行。
- save(tag=None, infos: dict = None)[来源]
保存检查点(模型参数和优化器状态)。
- Args:
标签 (int 或 str,可选): None 表示步长,否则使用标签,例如 epoch。默认值为 None。
infos (字典,可选): 要保存的元数据。默认值为 None。
- setup_visualizer()[来源]
初始化一个可视化工具来记录实验。
可视化日志保存在输出目录中。
备注
只有主进程有一个可视化工具。 在多进程中使用多个可视化工具写入同一个日志文件可能会导致意外行为。
- property train