.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/model_selection/plot_grid_search_digits.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_model_selection_plot_grid_search_digits.py>`
        to download the full example code or to run this example in your browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_model_selection_plot_grid_search_digits.py:

=============================================================
Custom refit strategy of a grid search with cross-validation
=============================================================

This example shows how a classifier is optimized by cross-validation, using a
:class:`~sklearn.model_selection.GridSearchCV` object on a development set
that comprises only half of the available labeled data.

The performance of the selected hyper-parameters and trained model is then
measured on a dedicated evaluation set that was not used during the model
selection step.

More details on the tools available for model selection can be found in the
sections on :ref:`cross_validation` and :ref:`grid_search`.

.. GENERATED FROM PYTHON SOURCE LINES 18-23

The dataset
-----------

We will work with the `digits` dataset. The goal is to classify handwritten
digit images. For easier understanding, we turn the problem into a binary
classification task: the goal is to identify whether a digit is an `8` or not.

.. GENERATED FROM PYTHON SOURCE LINES 23-27

.. code-block:: Python

    from sklearn import datasets

    digits = datasets.load_digits()

.. GENERATED FROM PYTHON SOURCE LINES 28-31

To train a classifier on images, we need to flatten them into vectors. Each
image of 8 by 8 pixels needs to be transformed into a vector of 64 pixels.
Thus, we obtain a final data array of shape `(n_images, n_pixels)`.

.. GENERATED FROM PYTHON SOURCE LINES 31-39

.. code-block:: Python

    n_samples = len(digits.images)
    X = digits.images.reshape((n_samples, -1))
    y = digits.target == 8
    print(
        f"The number of images is {X.shape[0]} and each image contains {X.shape[1]} pixels"
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    The number of images is 1797 and each image contains 64 pixels

.. GENERATED FROM PYTHON SOURCE LINES 40-41

As presented in the introduction, the data will be split into a training and a
testing set of equal size.

.. GENERATED FROM PYTHON SOURCE LINES 41-46

.. code-block:: Python

    from sklearn.model_selection import train_test_split

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

.. GENERATED FROM PYTHON SOURCE LINES 47-51

Defining our grid-search strategy
---------------------------------

We will select a classifier by searching for the best hyper-parameters on
folds of the training set. To do this, we need to define the scores used to
select the best candidate.

.. GENERATED FROM PYTHON SOURCE LINES 51-54

.. code-block:: Python

    scores = ["precision", "recall"]

.. GENERATED FROM PYTHON SOURCE LINES 55-58

We can also define a function to be passed to the `refit` parameter of the
:class:`~sklearn.model_selection.GridSearchCV` instance. It implements the
custom strategy that selects the best candidate from the `cv_results_`
attribute of the :class:`~sklearn.model_selection.GridSearchCV`. Once the
candidate is selected, it is automatically refitted by the
:class:`~sklearn.model_selection.GridSearchCV` instance.

Here, the strategy is to short-list the models which are the best in terms of
precision and recall. From the selected models, we finally pick the fastest
model at predicting. Notice that these custom choices are completely
arbitrary.
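A `refit` callable receives the `cv_results_` dictionary and must return the
integer index of the candidate to refit on the whole training set. Before
looking at the full strategy used in this example, here is a minimal sketch of
that signature; the helper name `select_highest_precision` is hypothetical and
only illustrates the contract, assuming `"precision"` is among the scores:

.. code-block:: Python

    import numpy as np


    def select_highest_precision(cv_results):
        """Minimal refit callable: pick the candidate ranked first on precision.

        This is only a sketch of the required signature (dict in, integer
        index out), not the strategy used in this example.
        """
        # ``rank_test_precision`` equals 1 for the best candidate, so the
        # position of its minimum is the index of that candidate.
        return int(np.argmin(cv_results["rank_test_precision"]))

Passing ``refit=select_highest_precision`` would refit that single candidate
automatically, just as the richer `refit_strategy` defined below does.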
.. GENERATED FROM PYTHON SOURCE LINES 58-151

.. code-block:: Python

    import pandas as pd


    def print_dataframe(filtered_cv_results):
        """Pretty print for the filtered dataframe."""
        for mean_precision, std_precision, mean_recall, std_recall, params in zip(
            filtered_cv_results["mean_test_precision"],
            filtered_cv_results["std_test_precision"],
            filtered_cv_results["mean_test_recall"],
            filtered_cv_results["std_test_recall"],
            filtered_cv_results["params"],
        ):
            print(
                f"precision: {mean_precision:0.3f} (±{std_precision:0.03f}),"
                f" recall: {mean_recall:0.3f} (±{std_recall:0.03f}),"
                f" for {params}"
            )
        print()


    def refit_strategy(cv_results):
        """Define the strategy to select the best estimator.

        The strategy defined here is to filter out all results below a precision
        threshold of 0.98, rank the remaining results by recall, and keep all
        models whose recall is within one standard deviation of the best. Once
        these models are selected, we pick the fastest model to predict.

        Parameters
        ----------
        cv_results : dict of numpy (masked) ndarrays
            Cross-validation results as returned by `GridSearchCV`.

        Returns
        -------
        best_index : int
            The index of the best estimator in `cv_results`.
        """
        # print the info about the grid-search for the different scores
        precision_threshold = 0.98

        cv_results_ = pd.DataFrame(cv_results)
        print("All grid-search results:")
        print_dataframe(cv_results_)

        # Filter out all results below the threshold
        high_precision_cv_results = cv_results_[
            cv_results_["mean_test_precision"] > precision_threshold
        ]

        print(f"Models with a precision higher than {precision_threshold}:")
        print_dataframe(high_precision_cv_results)

        high_precision_cv_results = high_precision_cv_results[
            [
                "mean_score_time",
                "mean_test_recall",
                "std_test_recall",
                "mean_test_precision",
                "std_test_precision",
                "rank_test_recall",
                "rank_test_precision",
                "params",
            ]
        ]

        # Select the most performant models in terms of recall
        # (within 1 sigma from the best)
        best_recall_std = high_precision_cv_results["mean_test_recall"].std()
        best_recall = high_precision_cv_results["mean_test_recall"].max()
        best_recall_threshold = best_recall - best_recall_std

        high_recall_cv_results = high_precision_cv_results[
            high_precision_cv_results["mean_test_recall"] > best_recall_threshold
        ]
        print(
            "Out of the previously selected high precision models, we keep all\n"
            "the models within one standard deviation of the highest recall model:"
        )
        print_dataframe(high_recall_cv_results)

        # From the best candidates, select the fastest model to predict
        fastest_top_recall_high_precision_index = high_recall_cv_results[
            "mean_score_time"
        ].idxmin()

        print(
            "\nThe selected final model is the fastest to predict out of the previously\n"
            "selected subset of best models based on precision and recall.\n"
            "Its scoring time is:\n\n"
            f"{high_recall_cv_results.loc[fastest_top_recall_high_precision_index]}"
        )

        return fastest_top_recall_high_precision_index

.. GENERATED FROM PYTHON SOURCE LINES 152-156

Tuning hyper-parameters
-----------------------

Once we have defined our strategy to select the best model, we define the
values of the hyper-parameters and create the grid-search instance:

.. GENERATED FROM PYTHON SOURCE LINES 157-170

.. code-block:: Python

    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC

    tuned_parameters = [
        {"kernel": ["rbf"], "gamma": [1e-3, 1e-4], "C": [1, 10, 100, 1000]},
        {"kernel": ["linear"], "C": [1, 10, 100, 1000]},
    ]

    grid_search = GridSearchCV(
        SVC(), tuned_parameters, scoring=scores, refit=refit_strategy
    )
    grid_search.fit(X_train, y_train)

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    All grid-search results:
    precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.257 (±0.061), for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 0.968 (±0.039), recall: 0.780 (±0.083), for {'C': 10, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 0.905 (±0.058), recall: 0.889 (±0.074), for {'C': 100, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 0.904 (±0.058), recall: 0.890 (±0.073), for {'C': 1000, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 0.695 (±0.073), recall: 0.743 (±0.065), for {'C': 1, 'kernel': 'linear'}
    precision: 0.643 (±0.066), recall: 0.757 (±0.066), for {'C': 10, 'kernel': 'linear'}
    precision: 0.611 (±0.028), recall: 0.744 (±0.044), for {'C': 100, 'kernel': 'linear'}
    precision: 0.618 (±0.039), recall: 0.744 (±0.044), for {'C': 1000, 'kernel': 'linear'}

    Models with a precision higher than 0.98:
    precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.257 (±0.061), for {'C': 1, 'gamma': 0.0001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}

    Out of the previously selected high precision models, we keep all
    the models within one standard deviation of the highest recall model:
    precision: 1.000 (±0.000), recall: 0.854 (±0.063), for {'C': 1, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 10, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 100, 'gamma': 0.001, 'kernel': 'rbf'}
    precision: 1.000 (±0.000), recall: 0.877 (±0.069), for {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}

    The selected final model is the fastest to predict out of the previously
    selected subset of best models based on precision and recall.
    Its scoring time is:

    mean_score_time                                            0.002186
    mean_test_recall                                           0.877206
    std_test_recall                                            0.069196
    mean_test_precision                                             1.0
    std_test_precision                                              0.0
    rank_test_recall                                                  3
    rank_test_precision                                               1
    params                 {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}
    Name: 6, dtype: object

    GridSearchCV(estimator=SVC(),
                 param_grid=[{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001],
                              'kernel': ['rbf']},
                             {'C': [1, 10, 100, 1000], 'kernel': ['linear']}],
                 refit=<function refit_strategy at 0xffff4efd5d00>,
                 scoring=['precision', 'recall'])


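Beyond the information printed by `refit_strategy`, the complete search
results remain available after fitting through the `cv_results_` attribute.
As a small sketch (assuming the `grid_search` object fitted above; the column
selection is arbitrary), they can be loaded into a pandas DataFrame for
further inspection:

.. code-block:: Python

    import pandas as pd

    # Load the complete cross-validation results into a DataFrame.
    cv_results = pd.DataFrame(grid_search.cv_results_)

    # An arbitrary selection of columns, sorted by mean recall; adapt as needed.
    columns = [
        "params",
        "mean_test_precision",
        "std_test_precision",
        "mean_test_recall",
        "std_test_recall",
        "mean_score_time",
    ]
    print(cv_results[columns].sort_values("mean_test_recall", ascending=False))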
.. GENERATED FROM PYTHON SOURCE LINES 171-172

The parameters selected by the grid-search with our custom strategy are:

.. GENERATED FROM PYTHON SOURCE LINES 173-175

.. code-block:: Python

    grid_search.best_params_

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    {'C': 1000, 'gamma': 0.001, 'kernel': 'rbf'}

.. GENERATED FROM PYTHON SOURCE LINES 176-179

Finally, we evaluate the fine-tuned model on the left-out evaluation set: the
`grid_search` object **has automatically been refit** on the full training set
with the parameters selected by our custom refit strategy.

We can use the classification report to compute standard classification
metrics on the left-out set:

.. GENERATED FROM PYTHON SOURCE LINES 180-185

.. code-block:: Python

    from sklearn.metrics import classification_report

    y_pred = grid_search.predict(X_test)
    print(classification_report(y_test, y_pred))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

                  precision    recall  f1-score   support

           False       0.99      1.00      0.99       807
            True       1.00      0.87      0.93        92

        accuracy                           0.99       899
       macro avg       0.99      0.93      0.96       899
    weighted avg       0.99      0.99      0.99       899

.. GENERATED FROM PYTHON SOURCE LINES 186-188

.. note::
   The problem is too easy: the hyper-parameter plateau is too flat and the
   output model is the same for precision and recall, with ties in quality.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 6.010 seconds)


.. _sphx_glr_download_auto_examples_model_selection_plot_grid_search_digits.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/model_selection/plot_grid_search_digits.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_grid_search_digits.ipynb <plot_grid_search_digits.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_grid_search_digits.py <plot_grid_search_digits.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_grid_search_digits.zip <plot_grid_search_digits.zip>`

.. include:: plot_grid_search_digits.recommendations

.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_