plot_terminator_improvement

optuna.visualization.matplotlib.plot_terminator_improvement(study, plot_error=False, improvement_evaluator=None, error_evaluator=None, min_n_trials=20)

Plot the potentials for future objective improvement.

This function visualizes the objective improvement potentials, evaluated with improvement_evaluator. It helps to determine whether we should continue the optimization or not. You can also plot the error evaluated with error_evaluator if the plot_error argument is set to True. Note that this function may take some time to compute the improvement potentials.
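One way to read the resulting plot is to compare the improvement curve with the error curve, trial by trial. The sketch below is not part of Optuna: first_crossing is a hypothetical helper, the numbers are made up for illustration, and the rule "further optimization is unlikely to pay off once the improvement potential drops below the error" is an assumption about how the two quantities are meant to be compared.

```python
from typing import Optional, Sequence


def first_crossing(
    improvements: Sequence[float], errors: Sequence[float]
) -> Optional[int]:
    """Return the first trial index where improvement < error, or None.

    Hypothetical helper: both sequences are assumed to hold one value
    per trial, in trial order.
    """
    for index, (improvement, error) in enumerate(zip(improvements, errors)):
        if improvement < error:
            return index
    return None


# Illustrative values only, not measured from a real study.
improvements = [0.30, 0.12, 0.04, 0.01]
errors = [0.02, 0.02, 0.02, 0.02]
crossing = first_crossing(improvements, errors)
```

If crossing is None, the improvement potential never dropped below the error in the trials seen so far, which under the assumption above suggests continuing the optimization.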

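Parameters:
  • study (Study) – A Study object whose trials are plotted for their improvement.

  • plot_error (bool) – A flag to show the error. If set to True, errors evaluated by error_evaluator are also plotted as a line graph. Defaults to False.

  • improvement_evaluator (BaseImprovementEvaluator | None) – An object that evaluates the improvement of the objective function. Defaults to RegretBoundEvaluator.

  • error_evaluator (BaseErrorEvaluator | None) – An object that evaluates the error inherent in the objective function. Defaults to CrossValidationErrorEvaluator.

  • min_n_trials (int) – The minimum number of trials before termination is considered. Terminator improvements for trials below this value are shown in a lighter color. Defaults to 20.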
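The effect of min_n_trials on the plot can be pictured as splitting the per-trial improvement values into an early segment and a later segment. The following is a minimal sketch of that idea, not Optuna's implementation: split_by_min_n_trials is a hypothetical helper, and it assumes that "below this value" refers to the first min_n_trials points. The exact boundary handling inside Optuna may differ.

```python
from typing import List, Sequence, Tuple


def split_by_min_n_trials(
    improvements: Sequence[float], min_n_trials: int = 20
) -> Tuple[List[float], List[float]]:
    """Split per-trial values into (early, later) segments.

    Hypothetical helper: the early segment corresponds to the points
    that would be drawn in a lighter color, the later segment to the
    points drawn normally.
    """
    early = list(improvements[:min_n_trials])
    later = list(improvements[min_n_trials:])
    return early, later


# Illustrative values only: 30 decreasing improvement potentials.
values = [1.0 / (i + 1) for i in range(30)]
early, later = split_by_min_n_trials(values, min_n_trials=20)
```

With fewer trials than min_n_trials, the later segment is empty, so under this reading every point would appear in the lighter color.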

Returns:

A matplotlib.axes.Axes object.

Return type:

Axes

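Note

Added in v3.2.0 as an experimental feature. The interface may change in newer versions without prior notice. See https://github.com/optuna/optuna/releases/tag/v3.2.0.

The following code snippet shows how to plot improvement potentials, together with cross-validation errors.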

Terminator Improvement Plot

<Axes: title={'center': 'Terminator Improvement Plot'}, xlabel='Trial', ylabel='Terminator Improvement'>

from lightgbm import LGBMClassifier
from sklearn.datasets import load_wine
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
import optuna
from optuna.terminator import report_cross_validation_scores
from optuna.visualization.matplotlib import plot_terminator_improvement


def objective(trial):
    X, y = load_wine(return_X_y=True)
    clf = LGBMClassifier(
        reg_alpha=trial.suggest_float("reg_alpha", 1e-8, 10.0, log=True),
        reg_lambda=trial.suggest_float("reg_lambda", 1e-8, 10.0, log=True),
        num_leaves=trial.suggest_int("num_leaves", 2, 256),
        colsample_bytree=trial.suggest_float("colsample_bytree", 0.4, 1.0),
        subsample=trial.suggest_float("subsample", 0.4, 1.0),
        subsample_freq=trial.suggest_int("subsample_freq", 1, 7),
        min_child_samples=trial.suggest_int("min_child_samples", 5, 100),
    )
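    # Evaluate with shuffled 5-fold cross-validation.
    scores = cross_val_score(clf, X, y, cv=KFold(n_splits=5, shuffle=True))
    # Attach the per-fold scores to the trial so that the error evaluator can use them.
    report_cross_validation_scores(trial, scores)
    return scores.mean()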


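# The objective returns a mean cross-validation score (higher is better), so maximize it.
study = optuna.create_study(direction="maximize")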
study.optimize(objective, n_trials=30)

plot_terminator_improvement(study, plot_error=True)

Total running time of the script: (0 minutes 5.679 seconds)