paddlespeech.t2s.training.experiment 模块

class paddlespeech.t2s.training.experiment.ExperimentBase(config, args)[来源]

基础: object

一个实验模板,用于构建训练代码,并处理保存、加载、日志记录和可视化等事务。它旨在灵活和简单。

因此,它仅以标准方式处理输出目录(为输出创建目录,创建检查点目录,转储正在使用的配置并创建可视化程序和记录器),而不强加任何输入输出协议给模型和数据加载器。它留给用户主要部分来实现他们自己的(设置模型、标准、优化器,定义训练步骤,定义验证函数并自定义所有文本和视觉日志)。

它并没有保存太多的样板代码。用户仍然需要手动编写 前向/后向/更新,但他们可以自由地添加 非标准行为(如果需要的话)。

我们有一些规范需要遵循。 1. 实验应该具有 model, optimizer, train_loadervalid_loader, configargs 属性。 2. config 应该有一个 training 字段,其中包含 valid_interval, save_intervalmax_iteration 键。它被 用作触发验证、检查点和停止实验的条件。 3. 有四个方法,分别是 train_batch, valid, setup_modelsetup_dataloader,应当被实现。

如有需要,请随意添加/覆盖其他方法和独立函数。

Args:
config (yacs.config.CfgNode):

用于实验的配置。

args (argparse.Namespace):

解析的命令行参数。

Examples:
>>> def main_sp(config, args):
>>>     exp = Experiment(config, args)
>>>     exp.setup()
>>>     exe.resume_or_load()
>>>     exp.run()
>>>
>>> config = get_cfg_defaults()
>>> parser = default_argument_parser()
>>> args = parser.parse_args()
>>> if args.config:
>>>     config.merge_from_file(args.config)
>>> if args.opts:
>>>     config.merge_from_list(args.opts)
>>> config.freeze()
>>>
>>> if args.ngpu > 1:
>>>     dist.spawn(main_sp, args=(config, args), nprocs=args.ngpu)
>>> else:
>>>     main_sp(config, args)
Attributes:
parallel

一个标志,指示实验是否应该使用多进程运行。

方法

close()

关闭可视化以避免训练后停滞

dump_config()

保存此次实验使用的配置。

init_parallel()

初始化多进程训练的环境。

new_epoch()

重置训练加载器并递增epoch

read_batch()

从 train_loader 中读取一批数据。

resume_or_load()

从输出目录中的最新检查点恢复或加载指定的检查点。

run()

实验设置后的常规流程。

save()

保存检查点(模型参数和优化器状态)。

setup()

设置实验。

setup_checkpointer()

创建一个用于保存检查点的目录。

setup_dataloader()

设置训练数据加载器和验证数据加载器。

setup_logger()

初始化一个文本日志记录器以记录实验。

setup_model()

设置模型、标准和优化器等。

setup_output_dir()

创建一个用于输出的目录。

setup_visualizer()

初始化一个可视化工具以记录实验。

train()

训练过程。

train_batch()

训练循环。

valid()

验证。

close()[来源]

关闭可视化工具以避免训练后出现挂起

dump_config()[来源]

保存用于此实验的配置。

它被保存在实验开始时的输出目录中的 config.yaml 文件里。

init_parallel()[来源]

初始化多进程训练的环境。

new_epoch()[来源]

重置训练加载器并增加 epoch

property parallel

一个标志,用于指示实验是否应该使用多进程运行。

read_batch()[来源]

从 train_loader 读取一批数据。

Returns:
List[Tensor]

一批。

resume_or_load()[来源]

从输出目录中的最新检查点恢复或加载指定的检查点。

如果 args.checkpoint_path 不是 None,则加载检查点,否则继续训练。

run()[来源]

实验设置后的常规操作。该方法旨在供用户使用。

save()[来源]

保存检查点(模型参数和优化器状态)。

setup()[来源]

设置实验。

setup_checkpointer()[来源]

创建一个用于保存检查点的目录。

它是“checkpoints”在输出目录中。

setup_dataloader()[来源]

设置训练数据加载器和验证数据加载器。子类应实现此方法。

setup_logger()[来源]

初始化一个文本记录器以记录实验。

每个进程都有自己的文本日志记录器。日志消息被写入标准输出和一个名为 worker_n.log 的文本文件,该文件位于输出目录中,其中 n 表示进程的排名。

setup_model()[来源]

设置模型、标准和优化器等。子类应该实现此方法。

setup_output_dir()[来源]

创建一个用于输出的目录。

setup_visualizer()[来源]

初始化一个可视化工具来记录实验。

可视化日志保存在输出目录中。

备注

只有主进程有一个可视化工具。 在多进程中使用多个可视化工具写入同一个日志文件可能会导致意外行为。

train()[来源]

训练过程。

它包括前向/后向/更新以及周期性验证和保存。

train_batch()[来源]

训练循环。子类应实现此方法。

valid()[来源]

验证。子类应该实现此方法。