Note
Go to the end to download the full example code.
Compare ensemble classifiers using resampling#
Ensemble classifiers have been shown to improve classification performance compared to a single learner. However, they are still affected by class imbalance. This example shows the benefit of balancing the training set before fitting the base learners. Each method is compared with its non-balanced counterpart.
We make the comparison using the balanced accuracy and the geometric mean, two metrics widely used in the literature to evaluate models trained on imbalanced datasets.
# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: MIT
print(__doc__)
Load an imbalanced dataset#
We will load the UCI SatImage dataset, which has an imbalance ratio of 9.3:1 (number of majority-class samples over number of minority-class samples). The data are then split into training and testing sets.
from sklearn.model_selection import train_test_split
from imblearn.datasets import fetch_datasets
satimage = fetch_datasets()["satimage"]
X, y = satimage.data, satimage.target
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)
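As a quick sanity check of the imbalance ratio quoted above, we can count the samples of each class (the -1/1 labels are those returned by fetch_datasets):
from collections import Counter

# The ratio of the two counts should be close to the 9.3:1 ratio mentioned above.
print(sorted(Counter(y).items()))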
Classification using a single decision tree#
We train a decision tree classifier that will serve as the baseline for the rest of this example.
The results are reported in terms of balanced accuracy and geometric mean, two metrics widely used in the literature to validate models trained on imbalanced datasets.
from sklearn.tree import DecisionTreeClassifier
tree = DecisionTreeClassifier()
tree.fit(X_train, y_train)
y_pred_tree = tree.predict(X_test)
from sklearn.metrics import balanced_accuracy_score
from imblearn.metrics import geometric_mean_score
print("Decision tree classifier performance:")
print(
f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_tree):.2f} - "
f"Geometric mean {geometric_mean_score(y_test, y_pred_tree):.2f}"
)
Decision tree classifier performance:
Balanced accuracy: 0.75 - Geometric mean 0.72
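As a side note (a minimal sketch for this binary problem): the balanced accuracy is the average of the per-class recalls, and the geometric mean is the square root of their product. Recomputing them by hand from the confusion matrix should therefore match the scores printed above.
from sklearn.metrics import confusion_matrix

# With the -1/1 labels of this dataset, ravel() yields tn, fp, fn, tp,
# where the minority class (1) is treated as the positive class.
tn, fp, fn, tp = confusion_matrix(y_test, y_pred_tree).ravel()
sensitivity = tp / (tp + fn)  # recall of the minority class
specificity = tn / (tn + fp)  # recall of the majority class
print(f"(sensitivity + specificity) / 2 = {(sensitivity + specificity) / 2:.2f}")
print(f"sqrt(sensitivity * specificity) = {(sensitivity * specificity) ** 0.5:.2f}")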
import seaborn as sns
from sklearn.metrics import ConfusionMatrixDisplay
sns.set_context("poster")
disp = ConfusionMatrixDisplay.from_estimator(tree, X_test, y_test, colorbar=False)
_ = disp.ax_.set_title("Decision tree")

Classification using a bagging classifier with and without sampling#
Instead of a single tree, we check whether an ensemble of decision trees can alleviate the issues caused by the class imbalance. First, we use a standard bagging classifier and its balanced counterpart, which internally applies random under-sampling to balance each bootstrap sample.
from sklearn.ensemble import BaggingClassifier
from imblearn.ensemble import BalancedBaggingClassifier
bagging = BaggingClassifier(n_estimators=50, random_state=0)
balanced_bagging = BalancedBaggingClassifier(n_estimators=50, random_state=0)
bagging.fit(X_train, y_train)
balanced_bagging.fit(X_train, y_train)
y_pred_bc = bagging.predict(X_test)
y_pred_bbc = balanced_bagging.predict(X_test)
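As a rough mental model (a sketch, not the exact internals of imbalanced-learn), the balanced variant behaves like a regular BaggingClassifier whose base learner is a pipeline chaining a RandomUnderSampler with a decision tree; the snippet assumes scikit-learn >= 1.2, where the parameter is named estimator:
from imblearn.pipeline import make_pipeline
from imblearn.under_sampling import RandomUnderSampler

# Each bootstrap sample drawn by the bagging procedure is under-sampled to a
# balanced subset before a decision tree is fitted on it.
bagging_with_sampler = BaggingClassifier(
    estimator=make_pipeline(
        RandomUnderSampler(random_state=0), DecisionTreeClassifier()
    ),
    n_estimators=50,
    random_state=0,
)
The example keeps using the dedicated BalancedBaggingClassifier class, which handles this kind of resampling for us.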
Balancing each bootstrap sample significantly improves both the balanced accuracy and the geometric mean.
print("Bagging classifier performance:")
print(
f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_bc):.2f} - "
f"Geometric mean {geometric_mean_score(y_test, y_pred_bc):.2f}"
)
print("Balanced Bagging classifier performance:")
print(
f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_bbc):.2f} - "
f"Geometric mean {geometric_mean_score(y_test, y_pred_bbc):.2f}"
)
Bagging classifier performance:
Balanced accuracy: 0.73 - Geometric mean 0.68
Balanced Bagging classifier performance:
Balanced accuracy: 0.86 - Geometric mean 0.86
import matplotlib.pyplot as plt
fig, axs = plt.subplots(ncols=2, figsize=(10, 5))
ConfusionMatrixDisplay.from_estimator(
bagging, X_test, y_test, ax=axs[0], colorbar=False
)
axs[0].set_title("Bagging")
ConfusionMatrixDisplay.from_estimator(
balanced_bagging, X_test, y_test, ax=axs[1], colorbar=False
)
axs[1].set_title("Balanced Bagging")
fig.tight_layout()

Classification using a random forest classifier with and without sampling#
Random forest is another popular ensemble method that usually outperforms bagging. Here, we compare a vanilla random forest with its balanced counterpart, in which each bootstrap sample is balanced.
from sklearn.ensemble import RandomForestClassifier
from imblearn.ensemble import BalancedRandomForestClassifier
rf = RandomForestClassifier(n_estimators=50, random_state=0)
brf = BalancedRandomForestClassifier(
n_estimators=50,
sampling_strategy="all",
replacement=True,
bootstrap=False,
random_state=0,
)
rf.fit(X_train, y_train)
brf.fit(X_train, y_train)
y_pred_rf = rf.predict(X_test)
y_pred_brf = brf.predict(X_test)
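With sampling_strategy="all" and replacement=True, the training sample of every tree has each class drawn, with replacement, down to the size of the minority class, while bootstrap=False disables an additional bootstrap on top of that resampling. Below is a minimal sketch of the class counts produced by such a resampling (RandomUnderSampler is used here purely for illustration):
from collections import Counter
from imblearn.under_sampling import RandomUnderSampler

# Illustration only: the balanced forest applies an equivalent resampling
# internally for each tree it grows.
sampler = RandomUnderSampler(sampling_strategy="all", replacement=True, random_state=0)
X_resampled, y_resampled = sampler.fit_resample(X_train, y_train)
print(sorted(Counter(y_resampled).items()))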
As in the previous experiment, the balanced classifier outperforms the classifier trained on imbalanced bootstrap samples. In addition, the random forests outperform the bagging classifiers.
print("Random Forest classifier performance:")
print(
f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_rf):.2f} - "
f"Geometric mean {geometric_mean_score(y_test, y_pred_rf):.2f}"
)
print("Balanced Random Forest classifier performance:")
print(
f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_brf):.2f} - "
f"Geometric mean {geometric_mean_score(y_test, y_pred_brf):.2f}"
)
Random Forest classifier performance:
Balanced accuracy: 0.73 - Geometric mean 0.68
Balanced Random Forest classifier performance:
Balanced accuracy: 0.87 - Geometric mean 0.87
fig, axs = plt.subplots(ncols=2, figsize=(10, 5))
ConfusionMatrixDisplay.from_estimator(rf, X_test, y_test, ax=axs[0], colorbar=False)
axs[0].set_title("Random forest")
ConfusionMatrixDisplay.from_estimator(brf, X_test, y_test, ax=axs[1], colorbar=False)
axs[1].set_title("Balanced random forest")
fig.tight_layout()

Boosting classifier#
In the same manner, the easy-ensemble classifier is a bag of balanced AdaBoost classifiers. It is, however, slower to train than a random forest and achieves worse performance.
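Conceptually (a sketch only, not the exact implementation of EasyEnsembleClassifier), this amounts to fitting several AdaBoost learners, each on a different randomly under-sampled, and therefore balanced, subset, and averaging their predicted probabilities:
import numpy as np
from sklearn.ensemble import AdaBoostClassifier
from imblearn.under_sampling import RandomUnderSampler

probabilities = []
for seed in range(10):
    # Draw a balanced subset and fit one AdaBoost learner on it.
    X_res, y_res = RandomUnderSampler(random_state=seed).fit_resample(X_train, y_train)
    booster = AdaBoostClassifier(n_estimators=10, random_state=seed).fit(X_res, y_res)
    probabilities.append(booster.predict_proba(X_test))

# Average the probabilities across learners and pick the most likely class.
y_pred_sketch = booster.classes_[np.mean(probabilities, axis=0).argmax(axis=1)]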
from sklearn.ensemble import AdaBoostClassifier
from imblearn.ensemble import EasyEnsembleClassifier, RUSBoostClassifier
estimator = AdaBoostClassifier(n_estimators=10, algorithm="SAMME")
eec = EasyEnsembleClassifier(n_estimators=10, estimator=estimator)
eec.fit(X_train, y_train)
y_pred_eec = eec.predict(X_test)
rusboost = RUSBoostClassifier(n_estimators=10, estimator=estimator)
rusboost.fit(X_train, y_train)
y_pred_rusboost = rusboost.predict(X_test)
/home/circleci/project/.pixi/envs/docs/lib/python3.12/site-packages/sklearn/ensemble/_weight_boosting.py:514: FutureWarning: The parameter 'algorithm' is deprecated in 1.6 and has no effect. It will be removed in version 1.8.
warnings.warn(
print("Easy ensemble classifier performance:")
print(
f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_eec):.2f} - "
f"Geometric mean {geometric_mean_score(y_test, y_pred_eec):.2f}"
)
print("RUSBoost classifier performance:")
print(
f"Balanced accuracy: {balanced_accuracy_score(y_test, y_pred_rusboost):.2f} - "
f"Geometric mean {geometric_mean_score(y_test, y_pred_rusboost):.2f}"
)
Easy ensemble classifier performance:
Balanced accuracy: 0.83 - Geometric mean 0.83
RUSBoost classifier performance:
Balanced accuracy: 0.50 - Geometric mean 0.00
fig, axs = plt.subplots(ncols=2, figsize=(10, 5))
ConfusionMatrixDisplay.from_estimator(eec, X_test, y_test, ax=axs[0], colorbar=False)
axs[0].set_title("Easy Ensemble")
ConfusionMatrixDisplay.from_estimator(
rusboost, X_test, y_test, ax=axs[1], colorbar=False
)
axs[1].set_title("RUSBoost classifier")
fig.tight_layout()
plt.show()

Total running time of the script: (0 minutes 11.758 seconds)
Estimated memory usage: 424 MB