paddlespeech.s2t.training.trainer 模块

class paddlespeech.s2t.training.trainer.Trainer(config, args)[来源]

基础: object

一个实验模板,用于构建训练代码,并处理保存、加载、日志记录和可视化等事务。它旨在灵活和简单。

所以它只以标准方式处理输出目录(为输出创建目录,创建检查点目录,转储正在使用的配置并创建可视化器和记录器),而无需强制执行任何输入-输出协议给模型和数据加载器。它将主要部分留给用户自行实现(设置模型、标准、优化器,定义训练步骤,定义验证函数并自定义所有文本和可视化日志)。它不保存太多的样板代码。用户仍然需要手动编写前向/反向/更新,但如有需要,他们可以自由添加非标准行为。我们有一些约定需要遵循。
1. 实验应具有 modeloptimizertrain_loadervalid_loaderconfigargs 属性。
2. 配置应具有一个 training 字段,其中包含 valid_intervalsave_intervalmax_iteration 键。它作为触发器,用于调用验证、检查点和停止实验。
3. 有四个方法,即 train_batchvalidsetup_modelsetup_dataloader 应该被实现。如有需要,自由添加/重写其他方法和独立函数。

Parameters:
config: yacs.config.CfgNode

用于实验的配置。

args: argparse.Namespace

解析的命令行参数。

Examples
--------
>>> def main_sp(config, args):
>>> exp = Trainer(config, args)
>>> exp.setup()
>>> exp.run()
>>>
>>> config = get_cfg_defaults()
>>> parser = default_argument_parser()
>>> args = parser.parse_args()
>>> if args.config:
>>> config.merge_from_file(args.config)
>>> if args.opts:
>>> config.merge_from_list(args.opts)
>>> config.freeze()
>>>
>>> if args.ngpu > 1:
>>> dist.spawn(main_sp, args=(config, args), nprocs=args.ngpu)
>>> else:
>>> main_sp(config, args)
Attributes:
parallel

一个标志,指示实验是否应该使用多进程运行。

train

方法

align()

对齐。

destory()

关闭可视化工具以避免训练后卡住

do_train()

训练过程由 эпох 控制。

dump_config()

保存此次实验使用的配置。

export()

测试。

maybe_batch_sampler_step()

按世代的 batch_sampler 种子

new_epoch()

重置训练加载器的种子并增加 epoch

restore()

从输出目录中的最新检查点恢复或加载指定的检查点。

resume_or_scratch()

从输出目录中的最新检查点恢复或加载指定的检查点。

run()

实验设置后的常规流程。

run_align()

进行CTC对齐

run_export()

执行模型导出

run_test()

执行测试/解码

save([tag, infos])

保存检查点(模型参数和优化器状态)。

setup()

设置实验。

setup_dataloader()

设置训练数据加载器和验证数据加载器。

setup_model()

设置模型、标准和优化器等。

setup_output_dir()

创建一个用于输出的目录。

setup_visualizer()

初始化一个可视化工具以记录实验。

test()

测试。

train_batch()

训练循环。

valid()

验证。

训练后批次

训练前

评估

after_train_batch()[来源]
align()[来源]

对齐。子类应该在Tester中实现这个方法。

before_train()[来源]
destory()[来源]

关闭可视化工具以避免训练后出现挂起

do_train()[来源]

训练过程由周期控制。

dump_config()[来源]

保存用于此实验的配置。

它被保存在实验开始时的输出目录中的 config.yaml 文件里。

eval()[来源]
export()[来源]

测试。子类应该在Tester中实现这个方法。

maybe_batch_sampler_step()[来源]

按周期的批采样种子

new_epoch()[来源]

重置训练加载器种子并增加 epoch

property parallel

一个标志,用于指示实验是否应该使用多进程运行。

restore()[来源]

从输出目录中的最新检查点恢复,或加载指定的检查点。

如果 args.checkpoint_path 不是 None,则加载检查点,否则继续训练。

resume_or_scratch()[来源]

从输出目录中的最新检查点恢复或加载指定的检查点。

如果 args.checkpoint_path 不是 None,则加载检查点,否则继续训练。

run()[来源]

实验设置后的常规操作。该方法旨在供用户使用。

run_align()[来源]

进行CTC对齐

run_export()[来源]

导出模型

run_test()[来源]

进行测试/解码

save(tag=None, infos: dict = None)[来源]

保存检查点(模型参数和优化器状态)。

Args:

标签 (int 或 str,可选): None 表示步长,否则使用标签,例如 epoch。默认值为 None。
infos (字典,可选): 要保存的元数据。默认值为 None。

setup()[来源]

设置实验。

setup_dataloader()[来源]

设置训练数据加载器和验证数据加载器。子类应实现此方法。

setup_model()[来源]

设置模型、标准和优化器等。子类应该实现此方法。

setup_output_dir()[来源]

创建一个用于输出的目录。

setup_visualizer()[来源]

初始化一个可视化工具来记录实验。

可视化日志保存在输出目录中。

备注

只有主进程有一个可视化工具。 在多进程中使用多个可视化工具写入同一个日志文件可能会导致意外行为。

test()[来源]

测试。子类应该在Tester中实现此方法。

property train
train_batch()[来源]

训练循环。子类应实现此方法。

valid()[来源]

验证。子类应该实现此方法。