speechbrain.utils.hpopt 模块
用于超参数优化的实用工具。 此包装器对Oríon有可选的依赖。
https://orion.readthedocs.io/en/stable/ https://github.com/Epistimio/orion
- Authors
阿尔乔姆·普洛日尼科夫 2021
摘要
类:
一个通用的超参数拟合报告器,将结果以JSON格式输出到任意数据流,可以被第三方工具读取 |
|
一个方便的上下文管理器,使得有条件地为配方启用超参数优化成为可能。 |
|
超参数拟合报告器的基类 |
|
基于Orion的结果报告器实现 |
函数:
尝试获取由模式指定的报告器,如果不可用则回退到通用报告器 |
|
返回当前超参数优化试验的ID,主要用于实验文件夹的命名。 |
|
用于为超参数优化模式注册报告器实现的装饰器 |
|
初始化超参数优化上下文 |
|
如果可用,使用当前报告器报告结果。 |
参考
- speechbrain.utils.hpopt.hpopt_mode(mode)[source]
用于为超参数优化模式注册报告器实现的装饰器
- Parameters:
mode (str) – 要注册的模式
- Returns:
f – 一个可调用的函数,用于注册并返回报告器类
- Return type:
可调用的
Example
>>> @hpopt_mode("raw") ... class RawHyperparameterOptimizationReporter(HyperparameterOptimizationReporter): ... def __init__(self, *args, **kwargs): ... super().__init__( *args, **kwargs) ... def report_objective(self, result): ... objective = result[self.objective_key] ... print(f"Objective: {objective}")
>>> reporter = get_reporter("raw", objective_key="error") >>> result = {"error": 1.2, "train_loss": 7.2} >>> reporter.report_objective(result) Objective: 1.2
- class speechbrain.utils.hpopt.HyperparameterOptimizationReporter(objective_key)[source]
基础类:
object超参数拟合报告器的基类
- Parameters:
objective_key (str) – 用作目标的结果字典中的键
- property is_available
确定此报告器是否可用
- property trial_id
此试验的唯一ID(用于文件夹命名)
- class speechbrain.utils.hpopt.GenericHyperparameterOptimizationReporter(reference_date=None, output=None, *args, **kwargs)[source]
基础类:
HyperparameterOptimizationReporter一个通用的超参数拟合报告器,将结果以JSON格式输出到任意数据流中,可以被第三方工具读取
- Parameters:
reference_date (datetime.datetime) – 用于创建试验ID的日期
输出 (流) – 用于报告结果的流
*args (tuple) – 要传递给父类的参数
**kwargs (dict) – 要传递给父类的参数
- report_objective(result)[source]
报告超参数优化的目标。
- Parameters:
结果 (dict) – 包含运行结果的字典。
Example
>>> reporter = GenericHyperparameterOptimizationReporter( ... objective_key="error" ... ) >>> result = {"error": 1.2, "train_loss": 7.2} >>> reporter.report_objective(result) {"error": 1.2, "train_loss": 7.2, "objective": 1.2}
- property trial_id
此试验的唯一ID(主要用于文件夹命名)
Example
>>> import datetime >>> reporter = GenericHyperparameterOptimizationReporter( ... objective_key="error", ... reference_date=datetime.datetime(2021, 1, 3) ... ) >>> print(reporter.trial_id) 20210103000000000000
- class speechbrain.utils.hpopt.OrionHyperparameterOptimizationReporter(objective_key)[source]
基础类:
HyperparameterOptimizationReporter基于Orion的结果报告器实现
- Parameters:
objective_key (str) – 用作目标的结果字典中的键
- property trial_id
此试验的唯一ID(主要用于文件夹命名)
- property is_available
确定Orion是否可用。为了使其可用,需要安装库,并且至少需要设置ORION_EXPERIMENT_NAME、ORION_EXPERIMENT_VERSION、ORION_TRIAL_ID中的一个。
- speechbrain.utils.hpopt.get_reporter(mode, *args, **kwargs)[source]
尝试获取由模式指定的报告器,如果不可用则回退到通用报告器
- Parameters:
- Returns:
reporter – 一个报告实例
- Return type:
Example
>>> reporter = get_reporter("generic", objective_key="error") >>> result = {"error": 3.4, "train_loss": 1.2} >>> reporter.report_objective(result) {"error": 3.4, "train_loss": 1.2, "objective": 3.4}
- class speechbrain.utils.hpopt.HyperparameterOptimizationContext(reporter_args=None, reporter_kwargs=None)[source]
基础类:
object一个方便的上下文管理器,使得有条件地为配方启用超参数优化成为可能。
Example
>>> ctx = HyperparameterOptimizationContext( ... reporter_args=[], ... reporter_kwargs={"objective_key": "error"} ... )
- parse_arguments(arg_list, pass_hpopt_args=None, pass_trial_id=True)[source]
一个增强版的speechbrain.parse_arguments,用于超参数优化。
如果提供了名为‘hpopt’的参数,将启用超参数优化和报告。
如果参数值对应于一个文件名,它将被读取为一个hyperpyyaml文件,并且内容将被添加到“overrides”中。这对于在超参数优化期间与完整训练期间某些超参数值不同的情况非常有用(例如,训练轮数、保存文件等)。
- Parameters:
- Returns:
param_file (str) – 参数文件的位置。
run_opts (dict) – 运行选项,例如分布式、设备等。
overrides (dict) – 传递给
load_hyperpyyaml的覆盖项。
Example
>>> ctx = HyperparameterOptimizationContext() >>> arg_list = ["hparams.yaml", "--x", "1", "--y", "2"] >>> hparams_file, run_opts, overrides = ctx.parse_arguments(arg_list) >>> print(f"File: {hparams_file}, Overrides: {overrides}") File: hparams.yaml, Overrides: {'x': 1, 'y': 2}
- speechbrain.utils.hpopt.hyperparameter_optimization(*args, **kwargs)[source]
初始化超参数优化上下文
- Parameters:
- Return type:
Example
>>> import sys >>> with hyperparameter_optimization(objective_key="error", output=sys.stdout) as hp_ctx: ... result = {"error": 3.5, "train_loss": 2.1} ... report_result(result) ... {"error": 3.5, "train_loss": 2.1, "objective": 3.5}