备注
前往结尾 下载完整示例代码。
高效的优化算法
Optuna 通过采用最先进的算法来采样超参数并高效地修剪无希望的试验,从而实现高效的超参数优化。
采样算法
采样器基本上通过使用建议的参数值和评估的目标值的记录不断缩小搜索空间,从而得到一个最优的搜索空间,该空间给出的参数能够导致更好的目标值。关于采样器如何建议参数的更详细解释,请参见 BaseSampler。
Optuna 提供了以下采样算法:
在
GridSampler中实现的网格搜索在
RandomSampler中实现的随机搜索在
TPESampler中实现的树结构Parzen估计器算法基于CMA-ES的算法在
CmaEsSampler中实现基于高斯过程的算法在
GPSampler中实现在
PartialFixedSampler中实现的启用部分固定参数的算法在
NSGAIISampler中实现的非支配排序遗传算法 II在
QMCSampler中实现的准蒙特卡罗采样算法
默认的采样器是 TPESampler。
切换采样器
import optuna
默认情况下,Optuna 使用 TPESampler 如下。
study = optuna.create_study()
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is TPESampler
如果你想使用不同的采样器,例如 RandomSampler 和 CmaEsSampler,
study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is RandomSampler
Sampler is CmaEsSampler
剪枝算法
Pruners 自动在训练的早期阶段停止没有希望的试验(也称为,自动提前停止)。目前 pruners 模块仅用于单目标优化。
Optuna 提供了以下剪枝算法:
在中实现的中间修剪算法
MedianPruner在
NopPruner中实现的非剪枝算法在
PatientPruner中实现的带容差的修剪器操作算法在
PercentilePruner中实现的修剪指定百分位试验的算法在
SuccessiveHalvingPruner中实现的异步连续减半算法在
HyperbandPruner中实现的 Hyperband 算法在
ThresholdPruner中实现的阈值剪枝算法基于 Wilcoxon 符号秩检验 的剪枝算法,在
WilcoxonPruner中实现
我们在大多数示例中使用 MedianPruner ,尽管基本上它在 此基准测试结果 中被 SuccessiveHalvingPruner 和 HyperbandPruner 超越。
激活修剪器
要开启剪枝功能,你需要在每次迭代训练后调用 report() 和 should_prune()。report() 定期监控中间目标值。should_prune() 决定是否终止不满足预定义条件的试验。
我们建议使用主要机器学习框架的集成模块。专属列表是 integration ,用例可在 optuna-examples 中找到。
import logging
import sys
import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection
def objective(trial):
iris = sklearn.datasets.load_iris()
classes = list(set(iris.target))
train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
iris.data, iris.target, test_size=0.25, random_state=0
)
alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
clf = sklearn.linear_model.SGDClassifier(alpha=alpha)
for step in range(100):
clf.partial_fit(train_x, train_y, classes=classes)
# Report intermediate objective value.
intermediate_value = 1.0 - clf.score(valid_x, valid_y)
trial.report(intermediate_value, step)
# Handle pruning based on the intermediate value.
if trial.should_prune():
raise optuna.TrialPruned()
return 1.0 - clf.score(valid_x, valid_y)
将中位数停止规则设置为剪枝条件。
# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
A new study created in memory with name: no-name-d9669a40-92e8-48b3-8f88-10a3c8e64d9d
Trial 0 finished with value: 0.07894736842105265 and parameters: {'alpha': 0.0551745944061346}. Best is trial 0 with value: 0.07894736842105265.
Trial 1 finished with value: 0.3157894736842105 and parameters: {'alpha': 0.07529961817949117}. Best is trial 0 with value: 0.07894736842105265.
Trial 2 finished with value: 0.10526315789473684 and parameters: {'alpha': 0.00042287098414869526}. Best is trial 0 with value: 0.07894736842105265.
Trial 3 finished with value: 0.1842105263157895 and parameters: {'alpha': 0.015878143093143834}. Best is trial 0 with value: 0.07894736842105265.
Trial 4 finished with value: 0.052631578947368474 and parameters: {'alpha': 0.0011147141800536124}. Best is trial 4 with value: 0.052631578947368474.
Trial 5 pruned.
Trial 6 finished with value: 0.23684210526315785 and parameters: {'alpha': 0.00028368188817897057}. Best is trial 4 with value: 0.052631578947368474.
Trial 7 pruned.
Trial 8 pruned.
Trial 9 pruned.
Trial 10 pruned.
Trial 11 finished with value: 0.42105263157894735 and parameters: {'alpha': 3.1208111475219473e-05}. Best is trial 4 with value: 0.052631578947368474.
Trial 12 finished with value: 0.13157894736842102 and parameters: {'alpha': 0.0028248230488248194}. Best is trial 4 with value: 0.052631578947368474.
Trial 13 pruned.
Trial 14 pruned.
Trial 15 pruned.
Trial 16 finished with value: 0.052631578947368474 and parameters: {'alpha': 0.0010632358706935596}. Best is trial 4 with value: 0.052631578947368474.
Trial 17 finished with value: 0.13157894736842102 and parameters: {'alpha': 0.0009099655410245386}. Best is trial 4 with value: 0.052631578947368474.
Trial 18 pruned.
Trial 19 pruned.
如你所见,几个试验在完成所有迭代之前就被修剪(停止)了。消息的格式是 "Trial <试验编号> pruned."。
应该使用哪种采样器和修剪器?
从 optuna/optuna - wiki “Benchmarks with Kurobako” 提供的基准测试结果来看,至少对于非深度学习任务,我们可以说
对于
RandomSampler,MedianPruner是最好的。对于
TPESampler,HyperbandPruner是最好的。
然而,请注意该基准测试不是深度学习。对于深度学习任务,请参考下表。此表来自 Ozaki 等人,超参数优化方法:概述与特性,IEICE Trans,Vol.J103-D No.9 pp.615-631, 2020 论文,该论文是用日语撰写的。
并行计算资源 |
分类/条件超参数 |
推荐算法 |
|---|---|---|
有限 |
不 |
TPE。如果搜索空间是低维且连续的,则使用 GP-EI。 |
是 |
TPE。如果搜索空间是低维且连续的,则使用GP-EI。 |
|
充分的 |
不 |
CMA-ES, 随机搜索 |
是 |
随机搜索或遗传算法 |
剪枝的集成模块
为了以更简单的方式实现剪枝机制,Optuna 提供了以下库的集成模块。
有关 Optuna 集成模块的完整列表,请参阅 integration。
例如,LightGBMPruningCallback 引入了无需直接改变训练迭代逻辑的剪枝功能。(另见 示例 以获取完整脚本。)
import optuna.integration
pruning_callback = optuna.integration.LightGBMPruningCallback(trial, 'validation-error')
gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback])
脚本总运行时间: (0 分钟 2.676 秒)