注意
Go to the end 下载完整示例代码。
文本文档中的主题分类示例#
此示例展示了在训练分类器之前如何平衡文本数据。
请注意,对于这个例子,数据稍微不平衡,但对于某些数据集,不平衡比例可能更为显著。
# Authors: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: MIT
print(__doc__)
设置数据集#
我们通过加载4个主题来使用20个新闻组数据集的一部分。使用scikit-learn加载器,数据被分为训练集和测试集。
注意,类别#3是少数类别,其样本数量几乎比多数类别少一半。
from sklearn.datasets import fetch_20newsgroups
categories = [
"alt.atheism",
"talk.religion.misc",
"comp.graphics",
"sci.space",
]
newsgroups_train = fetch_20newsgroups(subset="train", categories=categories)
newsgroups_test = fetch_20newsgroups(subset="test", categories=categories)
X_train = newsgroups_train.data
X_test = newsgroups_test.data
y_train = newsgroups_train.target
y_test = newsgroups_test.target
Training class distributions summary: Counter({2: 593, 1: 584, 0: 480, 3: 377})
Test class distributions summary: Counter({2: 394, 1: 389, 0: 319, 3: 251})
常见的scikit-learn流程#
你通常可能会通过结合TF-IDF向量器来使用scikit-learn管道,以供给多项式朴素贝叶斯分类器。分类报告总结了测试集上的结果。
正如预期的那样,类别#3的召回率较低,主要是由于类别不平衡。
from imblearn.metrics import classification_report_imbalanced
print(classification_report_imbalanced(y_test, y_pred))
pre rec spe f1 geo iba sup
0 0.67 0.94 0.86 0.79 0.90 0.82 319
1 0.96 0.92 0.99 0.94 0.95 0.90 389
2 0.87 0.98 0.94 0.92 0.96 0.92 394
3 0.97 0.36 1.00 0.52 0.60 0.33 251
avg / total 0.87 0.84 0.94 0.82 0.88 0.78 1353
分类前的类别平衡#
为了提高对类别#3的预测,在训练朴素贝叶斯分类器之前应用平衡可能会很有趣。因此,我们将在训练前使用RandomUnderSampler
来均衡所有类别的样本数量。
同样重要的是要注意,我们正在使用
make_pipeline
函数,该函数在
imbalanced-learn 中实现,以正确处理采样器。
from imblearn.pipeline import make_pipeline as make_pipeline_imb
from imblearn.under_sampling import RandomUnderSampler
model = make_pipeline_imb(TfidfVectorizer(), RandomUnderSampler(), MultinomialNB())
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
尽管结果几乎相同,但可以看出,重采样允许以降低其他类别的其他指标为代价来纠正类别#3的低召回率。然而,总体结果略好一些。
print(classification_report_imbalanced(y_test, y_pred))
pre rec spe f1 geo iba sup
0 0.69 0.91 0.87 0.78 0.89 0.79 319
1 0.97 0.85 0.99 0.91 0.92 0.83 389
2 0.96 0.88 0.98 0.92 0.93 0.85 394
3 0.80 0.73 0.96 0.76 0.84 0.69 251
avg / total 0.87 0.85 0.95 0.85 0.90 0.80 1353
脚本的总运行时间: (0 分钟 11.735 秒)
预计内存使用量: 306 MB