.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/mixture/plot_gmm_covariances.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code, or to run this example in your
        browser via Binder.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_mixture_plot_gmm_covariances.py:


===============
GMM covariances
===============

Demonstration of several covariance types for Gaussian mixture models.

See :ref:`gmm` for more information on the estimator.

Although GMMs are often used for clustering, we can compare the obtained
clusters with the actual classes in the dataset. To make this comparison
valid, we initialize the means of the Gaussians with the class means computed
on the training set.

We plot the predicted labels on both the training data and the held-out test
data, using a variety of GMM covariance types on the iris dataset. We compare
GMMs with spherical, diagonal, full, and tied covariance matrices in
increasing order of performance. Although full covariance is generally
expected to perform best, it is prone to overfitting on small datasets and
does not generalize well to held-out test data.

In the plots, training data is shown as dots and test data as crosses. The
iris dataset is four-dimensional. Only the first two dimensions are shown
here, so some points that appear to overlap are separated in the other
dimensions.

.. GENERATED FROM PYTHON SOURCE LINES 17-125

.. image-sg:: /auto_examples/mixture/images/sphx_glr_plot_gmm_covariances_001.png
   :alt: spherical, diag, tied, full
   :srcset: /auto_examples/mixture/images/sphx_glr_plot_gmm_covariances_001.png
   :class: sphx-glr-single-img


.. code-block:: Python


    # Authors: The scikit-learn developers
    # SPDX-License-Identifier: BSD-3-Clause

    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import numpy as np

    from sklearn import datasets
    from sklearn.mixture import GaussianMixture
    from sklearn.model_selection import StratifiedKFold

    colors = ["navy", "turquoise", "darkorange"]


    def make_ellipses(gmm, ax):
        for n, color in enumerate(colors):
            if gmm.covariance_type == "full":
                covariances = gmm.covariances_[n][:2, :2]
            elif gmm.covariance_type == "tied":
                covariances = gmm.covariances_[:2, :2]
            elif gmm.covariance_type == "diag":
                covariances = np.diag(gmm.covariances_[n][:2])
            elif gmm.covariance_type == "spherical":
                covariances = np.eye(gmm.means_.shape[1]) * gmm.covariances_[n]
            v, w = np.linalg.eigh(covariances)
            u = w[0] / np.linalg.norm(w[0])
            angle = np.arctan2(u[1], u[0])
            angle = 180 * angle / np.pi  # convert to degrees
            v = 2.0 * np.sqrt(2.0) * np.sqrt(v)
            ell = mpl.patches.Ellipse(
                gmm.means_[n, :2], v[0], v[1], angle=180 + angle, color=color
            )
            ell.set_clip_box(ax.bbox)
            ell.set_alpha(0.5)
            ax.add_artist(ell)
            ax.set_aspect("equal", "datalim")


    iris = datasets.load_iris()

    # Break up the dataset into non-overlapping training (75%) and
    # testing (25%) sets.
    skf = StratifiedKFold(n_splits=4)
    # Only take the first fold.
    train_index, test_index = next(iter(skf.split(iris.data, iris.target)))


    X_train = iris.data[train_index]
    y_train = iris.target[train_index]
    X_test = iris.data[test_index]
    y_test = iris.target[test_index]

    n_classes = len(np.unique(y_train))

    # Try GMMs using different types of covariance matrices.
    estimators = {
        cov_type: GaussianMixture(
            n_components=n_classes, covariance_type=cov_type, max_iter=20, random_state=0
        )
        for cov_type in ["spherical", "diag", "tied", "full"]
    }

    n_estimators = len(estimators)

    plt.figure(figsize=(3 * n_estimators // 2, 6))
    plt.subplots_adjust(
        bottom=0.01, top=0.95, hspace=0.15, wspace=0.05, left=0.01, right=0.99
    )


    for index, (name, estimator) in enumerate(estimators.items()):
        # Since we have class labels for the training data, we can
        # initialize the GMM parameters in a supervised manner.
        estimator.means_init = np.array(
            [X_train[y_train == i].mean(axis=0) for i in range(n_classes)]
        )

        # Train the other parameters using the EM algorithm.
        estimator.fit(X_train)

        h = plt.subplot(2, n_estimators // 2, index + 1)
        make_ellipses(estimator, h)

        for n, color in enumerate(colors):
            data = iris.data[iris.target == n]
            plt.scatter(
                data[:, 0], data[:, 1], s=0.8, color=color, label=iris.target_names[n]
            )
        # Plot the test data with crosses.
        for n, color in enumerate(colors):
            data = X_test[y_test == n]
            plt.scatter(data[:, 0], data[:, 1], marker="x", color=color)

        y_train_pred = estimator.predict(X_train)
        train_accuracy = np.mean(y_train_pred.ravel() == y_train.ravel()) * 100
        plt.text(0.05, 0.9, "Train accuracy: %.1f" % train_accuracy, transform=h.transAxes)

        y_test_pred = estimator.predict(X_test)
        test_accuracy = np.mean(y_test_pred.ravel() == y_test.ravel()) * 100
        plt.text(0.05, 0.8, "Test accuracy: %.1f" % test_accuracy, transform=h.transAxes)

        plt.xticks(())
        plt.yticks(())
        plt.title(name)

    plt.legend(scatterpoints=1, loc="lower right", prop=dict(size=12))


    plt.show()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.117 seconds)


.. _sphx_glr_download_auto_examples_mixture_plot_gmm_covariances.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: binder-badge

      .. image:: images/binder_badge_logo.svg
        :target: https://mybinder.org/v2/gh/scikit-learn/scikit-learn/main?urlpath=lab/tree/notebooks/auto_examples/mixture/plot_gmm_covariances.ipynb
        :alt: Launch binder
        :width: 150 px

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_gmm_covariances.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_gmm_covariances.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: plot_gmm_covariances.zip `


.. include:: plot_gmm_covariances.recommendations


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
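
The remark that full covariance tends to overfit on small datasets can be made
concrete by counting the free covariance parameters each type has to estimate.
The sketch below is not part of the generated example: the helper name
``covariance_parameter_count`` is hypothetical, and the counts follow from the
standard definitions of the four covariance types (one variance per component,
one variance per feature and component, one shared symmetric matrix, one
symmetric matrix per component).

.. code-block:: Python

    def covariance_parameter_count(covariance_type, n_components, n_features):
        """Count free covariance parameters (hypothetical helper, for illustration)."""
        # A symmetric d x d matrix has d * (d + 1) / 2 free entries.
        per_full_matrix = n_features * (n_features + 1) // 2
        counts = {
            "spherical": n_components,  # one variance per component
            "diag": n_components * n_features,  # one variance per feature
            "tied": per_full_matrix,  # a single matrix shared by all components
            "full": n_components * per_full_matrix,  # one matrix per component
        }
        return counts[covariance_type]


    # Iris setting used above: 3 components (classes) and 4 features.
    for cov_type in ["spherical", "diag", "tied", "full"]:
        print(cov_type, covariance_parameter_count(cov_type, 3, 4))

For the iris setting this gives 3, 12, 10, and 30 parameters respectively, so
the full model has by far the most covariance parameters to fit from the same
training fold.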