Efficient Optimization Algorithms

Optuna enables efficient hyperparameter optimization by adopting state-of-the-art algorithms for sampling hyperparameters and efficiently pruning unpromising trials.

Sampling Algorithms

Samplers basically continually narrow down the search space using the records of suggested parameter values and evaluated objective values, leading to an optimal search space which gives parameters leading to better objective values. More detailed explanation of how samplers suggest parameters is in BaseSampler.
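As a toy illustration (not from the original tutorial), the sketch below minimizes a simple quadratic; each trial.suggest_* call delegates to the study's sampler, which conditions its suggestion on the history of past trials.

import optuna


def objective(trial):
    # Each suggest_* call asks the study's sampler for a value, conditioned
    # on the records of previously suggested parameters and objective values.
    x = trial.suggest_float("x", -10, 10)
    return (x - 2) ** 2


study = optuna.create_study()
study.optimize(objective, n_trials=10)
print(study.best_params)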

Optuna provides the following sampling algorithms:

- Grid Search implemented in GridSampler
- Random Search implemented in RandomSampler
- Tree-structured Parzen Estimator algorithm implemented in TPESampler
- CMA-ES based algorithm implemented in CmaEsSampler
- Nondominated Sorting Genetic Algorithm II implemented in NSGAIISampler

The default sampler is TPESampler.

Switching Samplers

import optuna

By default, Optuna uses TPESampler as follows.

study = optuna.create_study()
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is TPESampler

If you want to use different samplers, for example RandomSampler and CmaEsSampler:

study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")

study = optuna.create_study(sampler=optuna.samplers.CmaEsSampler())
print(f"Sampler is {study.sampler.__class__.__name__}")
Sampler is RandomSampler
Sampler is CmaEsSampler
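Samplers that involve randomness also accept a seed argument, which is useful for reproducible runs (a minimal sketch; the seed value is arbitrary):

# Fix the sampler's random seed so repeated runs suggest the same values.
study = optuna.create_study(sampler=optuna.samplers.TPESampler(seed=10))
print(f"Sampler is {study.sampler.__class__.__name__}")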

Pruning Algorithms

Pruners automatically stop unpromising trials at the early stages of training (a.k.a. automated early stopping). Currently the pruners module is expected to be used only for single-objective optimization.

Optuna provides the following pruning algorithms:

- Median pruning algorithm implemented in MedianPruner
- Non-pruning algorithm implemented in NopPruner
- Algorithm to prune a specified percentile of trials implemented in PercentilePruner
- Asynchronous Successive Halving algorithm implemented in SuccessiveHalvingPruner
- Hyperband algorithm implemented in HyperbandPruner
- Threshold pruning algorithm implemented in ThresholdPruner

We use MedianPruner in most examples, though basically it is outperformed by SuccessiveHalvingPruner and HyperbandPruner, as in this benchmark result.
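Pruners are configurable at construction time. For example, MedianPruner can be told to leave the first trials and the first steps of each trial untouched (a minimal sketch; the numbers are arbitrary):

# Do not prune during the first 5 trials or the first 30 steps of any trial;
# afterwards, prune a trial whose intermediate value is worse than the median
# of the values reported by earlier trials at the same step.
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=30)
study = optuna.create_study(pruner=pruner)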

Activating Pruners

To turn on the pruning feature, you need to call report() and should_prune() after each step of the iterative training. report() periodically monitors the intermediate objective values. should_prune() decides whether to terminate a trial that does not meet a predefined condition.

We would recommend using the integration modules for major machine learning frameworks. The full list is available at integration, and use cases can be found in optuna-examples.

import logging
import sys

import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection


def objective(trial):
    iris = sklearn.datasets.load_iris()
    classes = list(set(iris.target))
    train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
        iris.data, iris.target, test_size=0.25, random_state=0
    )

    alpha = trial.suggest_float("alpha", 1e-5, 1e-1, log=True)
    clf = sklearn.linear_model.SGDClassifier(alpha=alpha)

    for step in range(100):
        clf.partial_fit(train_x, train_y, classes=classes)

        # Report intermediate objective value.
        intermediate_value = 1.0 - clf.score(valid_x, valid_y)
        trial.report(intermediate_value, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.TrialPruned()

    return 1.0 - clf.score(valid_x, valid_y)

Set up the median stopping rule as the pruning condition.

# Add stream handler of stdout to show the messages
optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
A new study created in memory with name: no-name-d9669a40-92e8-48b3-8f88-10a3c8e64d9d
Trial 0 finished with value: 0.07894736842105265 and parameters: {'alpha': 0.0551745944061346}. Best is trial 0 with value: 0.07894736842105265.
Trial 1 finished with value: 0.3157894736842105 and parameters: {'alpha': 0.07529961817949117}. Best is trial 0 with value: 0.07894736842105265.
Trial 2 finished with value: 0.10526315789473684 and parameters: {'alpha': 0.00042287098414869526}. Best is trial 0 with value: 0.07894736842105265.
Trial 3 finished with value: 0.1842105263157895 and parameters: {'alpha': 0.015878143093143834}. Best is trial 0 with value: 0.07894736842105265.
Trial 4 finished with value: 0.052631578947368474 and parameters: {'alpha': 0.0011147141800536124}. Best is trial 4 with value: 0.052631578947368474.
Trial 5 pruned.
Trial 6 finished with value: 0.23684210526315785 and parameters: {'alpha': 0.00028368188817897057}. Best is trial 4 with value: 0.052631578947368474.
Trial 7 pruned.
Trial 8 pruned.
Trial 9 pruned.
Trial 10 pruned.
Trial 11 finished with value: 0.42105263157894735 and parameters: {'alpha': 3.1208111475219473e-05}. Best is trial 4 with value: 0.052631578947368474.
Trial 12 finished with value: 0.13157894736842102 and parameters: {'alpha': 0.0028248230488248194}. Best is trial 4 with value: 0.052631578947368474.
Trial 13 pruned.
Trial 14 pruned.
Trial 15 pruned.
Trial 16 finished with value: 0.052631578947368474 and parameters: {'alpha': 0.0010632358706935596}. Best is trial 4 with value: 0.052631578947368474.
Trial 17 finished with value: 0.13157894736842102 and parameters: {'alpha': 0.0009099655410245386}. Best is trial 4 with value: 0.052631578947368474.
Trial 18 pruned.
Trial 19 pruned.

As you can see, several trials were pruned (stopped) before they finished all of the iterations. The format of the message is "Trial <Trial Number> pruned.".
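Trial states can also be inspected programmatically rather than from the log, e.g. via optuna.trial.TrialState (a minimal sketch):

from optuna.trial import TrialState

# Count how many trials were pruned and how many ran to completion.
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
print(f"Pruned: {len(pruned_trials)}, Complete: {len(complete_trials)}")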

Which Sampler and Pruner Should be Used?

optuna/optuna - wiki “Benchmarks with Kurobako” 提供的基准测试结果来看,至少对于非深度学习任务,我们可以说

However, note that the benchmark is not deep learning. For deep learning tasks, consult the table below, which comes from Ozaki et al., Hyperparameter Optimization Methods: Overview and Characteristics, in IEICE Trans, Vol.J103-D No.9 pp.615-631, 2020 (the paper is written in Japanese).

Parallel Compute Resource | Categorical/Conditional Hyperparameters | Recommended Algorithms
------------------------- | --------------------------------------- | ----------------------
Limited                   | No                                      | TPE. GP-EI if search space is low-dimensional and continuous.
Limited                   | Yes                                     | TPE. GP-EI if search space is low-dimensional and continuous.
Sufficient                | No                                      | CMA-ES, Random Search
Sufficient                | Yes                                     | Random Search or Genetic Algorithm
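Whichever combination you choose, the sampler and pruner are passed together when the study is created; a minimal sketch following the non-deep-learning recommendation above:

# Pair TPESampler with HyperbandPruner, per the Kurobako benchmark above.
study = optuna.create_study(
    sampler=optuna.samplers.TPESampler(),
    pruner=optuna.pruners.HyperbandPruner(),
)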

Integration Modules for Pruning

To implement the pruning mechanism in much simpler forms, Optuna provides integration modules for the following libraries:

- XGBoost
- LightGBM
- Chainer
- Keras
- TensorFlow
- tf.keras
- MXNet
- PyTorch Ignite
- PyTorch Lightning
- FastAI

For the complete list of Optuna's integration modules, see integration.

For example, LightGBMPruningCallback introduces pruning without directly changing the logic of the training iteration. (See also the example for the entire script.)

import optuna.integration

# A pruning callback that watches the 'validation-error' metric; this
# fragment assumes `trial`, `param`, `dtrain`, `dvalid`, and `lgb`
# (lightgbm) are defined in the surrounding objective function.
pruning_callback = optuna.integration.LightGBMPruningCallback(trial, 'validation-error')
gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback])
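For context, here is a self-contained sketch of how this callback is typically wired into an objective, modeled loosely on the LightGBM script in optuna-examples; the dataset, search space, and metric name below are illustrative assumptions, not part of this page:

import lightgbm as lgb
import numpy as np
import optuna
import optuna.integration
import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection


def objective(trial):
    data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
    train_x, valid_x, train_y, valid_y = sklearn.model_selection.train_test_split(
        data, target, test_size=0.25
    )
    dtrain = lgb.Dataset(train_x, label=train_y)
    dvalid = lgb.Dataset(valid_x, label=valid_y, reference=dtrain)

    param = {
        "objective": "binary",
        "metric": "binary_error",
        "verbosity": -1,
        "lambda_l1": trial.suggest_float("lambda_l1", 1e-8, 10.0, log=True),
        "lambda_l2": trial.suggest_float("lambda_l2", 1e-8, 10.0, log=True),
    }

    # The callback reads "binary_error" on the validation set after every
    # boosting round and raises optuna.TrialPruned when the trial should stop.
    pruning_callback = optuna.integration.LightGBMPruningCallback(trial, "binary_error")
    gbm = lgb.train(param, dtrain, valid_sets=[dvalid], callbacks=[pruning_callback])

    preds = gbm.predict(valid_x)
    pred_labels = np.rint(preds)
    return 1.0 - sklearn.metrics.accuracy_score(valid_y, pred_labels)


study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)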

Total running time of the script: (0 minutes 2.676 seconds)
