网格搜索与交叉验证的自定义重拟合策略#

本示例展示了如何通过交叉验证优化分类器, 这是使用 GridSearchCV 对象 在仅包含一半可用标记数据的开发集上完成的。

然后在专门的评估集上测量所选超参数和训练模型的性能, 该评估集在模型选择步骤中未被使用。

有关模型选择工具的更多详细信息,请参见 交叉验证:评估估计器性能调整估计器的超参数 部分。

The dataset#

我们将使用 digits 数据集。目标是对手写数字图像进行分类。 为了便于理解,我们将问题转化为二分类问题:目标是识别一个数字是否为 8

from sklearn import datasets

digits = datasets.load_digits()

为了在图像上训练分类器,我们需要将它们展平为向量。 每个8x8像素的图像需要转换为64像素的向量。 因此,我们将得到一个形状为 (n_images, n_pixels) 的最终数据数组。

n_samples = len(digits.images)
X = digits.images.reshape((n_samples, -1))
y = digits.target == 8
print(
    f"The number of images is {X.shape[0]} and each image contains {X.shape[1]} pixels"
)
The number of images is 1797 and each image contains 64 pixels

正如介绍中所述,数据将被分成大小相等的训练集和测试集。

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

定义我们的网格搜索策略#

我们将通过在训练集的折叠上搜索最佳超参数来选择分类器。为此,我们需要定义选择最佳候选者的评分标准。

scores = ["precision", "recall"]

我们还可以定义一个函数传递给 GridSearchCV 实例的 refit 参数。该函数将实现自定义策略,从 GridSearchCVcv_results_ 属性中选择最佳候选者。一旦候选者被选中,它将自动由 GridSearchCV 实例重新拟合。

在这里,策略是筛选出在精度和召回率方面表现最好的模型。从选定的模型中,我们最终选择预测速度最快的模型。请注意,这些自定义选择是完全任意的。

import pandas as pd


def print_dataframe(filtered_cv_results):
    """过滤后的数据框美化打印"""
    for mean_precision, std_precision, mean_recall, std_recall, params in zip(
        filtered_cv_results["mean_test_precision"],
        filtered_cv_results["std_test_precision"],
        filtered_cv_results["mean_test_recall"],
        filtered_cv_results["std_test_recall"],
        filtered_cv_results["params"],
    ):
        print(
            f"precision: {mean_precision:0.3f}{std_precision:0.03f}),"
            f" recall: {mean_recall:0.3f}{std_recall:0.03f}),"
            f" for {params}"
        )
    print()


def refit_strategy(cv_results):
    """定义选择最佳估计器的策略。

这里定义的策略是过滤掉所有低于0.98精度阈值的结果,按召回率对剩余结果进行排序,并保留召回率在最佳值一个标准差范围内的所有模型。一旦选定这些模型,我们可以选择预测最快的模型。

Parameters
----------
cv_results : dict of numpy (masked) ndarrays
    `GridSearchCV` 返回的交叉验证结果。

返回
-------
best_index : int
    最佳估计器在 `cv_results` 中的索引。
"""
    # 打印不同评分的网格搜索信息
    precision_threshold = 0.98

    cv_results_ = pd.DataFrame(cv_results)
    print("All grid-search results:")
    print_dataframe(cv_results_)

    # 过滤掉所有低于阈值的结果
    high_precision_cv_results = cv_results_[
        cv_results_["mean_test_precision"] > precision_threshold
    ]

    print(f"Models with a precision higher than {precision_threshold}:")
    print_dataframe(high_precision_cv_results)

    high_precision_cv_results = high_precision_cv_results[
        [
            "mean_score_time",
            "mean_test_recall",
            "std_test_recall",
            "mean_test_precision",
            "std_test_precision",
            "rank_test_recall",
            "rank_test_precision",
            "params",
        ]
    ]

    # 选择召回率表现最好的模型(在最佳模型的1个标准差范围内)
    best_recall_std = high_precision_cv_results["mean_test_recall"].std()
    best_recall = high_precision_cv_results["mean_test_recall"].max()
    best_recall_threshold = best_recall - best_recall_std

    high_recall_cv_results = high_precision_cv_results[
        high_precision_cv_results["mean_test_recall"] > best_recall_threshold
    ]
    print(
        "Out of the previously selected high precision models, we keep all the\n"
        "the models within one standard deviation of the highest recall model:"
    )
    print_dataframe(high_recall_cv_results)

    # 从最佳候选者中选择最快的模型进行预测
    fastest_top_recall_high_precision_index = high_recall_cv_results[
        "mean_score_time"
    ].idxmin()

    print(
        "\nThe selected final model is the fastest to predict out of the previously\n"
        "selected subset of best models based on precision and recall.\n"
        "Its scoring time is:\n\n"
        f"{high_recall_cv_results.loc[fastest_top_recall_high_precision_index]}"
    )

    return fastest_top_recall_high_precision_index

调整超参数#

一旦我们定义了选择最佳模型的策略,我们就定义超参数的值并创建网格搜索实例:

from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

tuned_parameters = [
    {"kernel": ["rbf"], "gamma": [1e-3, 1e-4], "C": [1, 10, 100, 1000]},
    {"kernel": ["linear"], "C": [1, 10, 100, 1000]},
]

grid_search = GridSearchCV(
    SVC(), tuned_parameters, scoring=scores, refit=refit_strategy
)
grid_search.fit(X_train, y_train)
All grid-search results:
precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.257 (±0.061), for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 0.968 (±0.039), recall: 0.780 (±0.083), for {'C': 10, 'gamma': 0.0001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 0.905 (±0.058), recall: 0.889 (±0.074), for {'C': 100, 'gamma': 0.0001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 0.904 (±0.058), recall: 0.890 (±0.073), for {'C': 1000, 'gamma': 0.0001, 'kernel': 'rbf'}
precision: 0.695 (±0.073), recall: 0.743 (±0.065), for {'C': 1, 'kernel': 'linear'}
precision: 0.643 (±0.066), recall: 0.757 (±0.066), for {'C': 10, 'kernel': 'linear'}
precision: 0.611 (±0.028), recall: 0.744 (±0.044), for {'C': 100, 'kernel': 'linear'}
precision: 0.618 (±0.039), recall: 0.744 (±0.044), for {'C': 1000, 'kernel': 'linear'}

Models with a precision higher than 0.98:
precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.257 (±0.061), for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}

Out of the previously selected high precision models, we keep all the
the models within one standard deviation of the highest recall model:
precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}


The selected final model is the fastest to predict out of the previously
selected subset of best models based on precision and recall.
Its scoring time is:

mean_score_time                                            0.002186
mean_test_recall                                           0.877206
std_test_recall                                            0.069196
mean_test_precision                                             1.0
std_test_precision                                              0.0
rank_test_recall                                                  3
rank_test_precision                                               1
params                 {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
Name: 6, dtype: object
GridSearchCV(estimator=SVC(),
             param_grid=[{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001],
                          'kernel': ['rbf']},
                         {'C': [1, 10, 100, 1000], 'kernel': ['linear']}],
             refit=<function refit_strategy at 0xffff4efd5d00>,
             scoring=['precision', 'recall'])
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.


通过我们的自定义策略,网格搜索选择的参数是:

grid_search.best_params_
{'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}

最后,我们在留出的评估集上评估微调后的模型: grid_search 对象 已自动重新拟合 到完整的训练集,并使用我们自定义的重新拟合策略选择的参数。

我们可以使用分类报告来计算留出集上的标准分类指标:

from sklearn.metrics import classification_report

y_pred = grid_search.predict(X_test)
print(classification_report(y_test, y_pred))
              precision    recall  f1-score   support

       False       0.99      1.00      0.99       807
        True       1.00      0.87      0.93        92

    accuracy                           0.99       899
   macro avg       0.99      0.93      0.96       899
weighted avg       0.99      0.99      0.99       899

问题太简单了:超参数平台过于平坦,输出模型在精度和召回率方面相同,质量上没有差异。

Total running time of the script: (0 minutes 6.010 seconds)

Related examples

平衡模型复杂性和交叉验证得分

平衡模型复杂性和交叉验证得分

随机搜索与网格搜索在超参数估计中的比较

随机搜索与网格搜索在超参数估计中的比较

高斯混合模型选择

高斯混合模型选择

识别手写数字

识别手写数字

Gallery generated by Sphinx-Gallery