2022年7月11日

使用嵌入进行分类

,

文本分类有多种方法。本笔记本展示了使用嵌入进行文本分类的示例。对于许多文本分类任务,我们发现微调模型的表现优于嵌入方法。关于微调模型用于分类的示例,请参阅Fine-tuned_classification.ipynb。我们还建议样本数量应多于嵌入维度,但在此示例中我们并未完全实现这一点。

在这个文本分类任务中,我们根据评论文本的嵌入向量来预测食品评论的评分(1到5分)。为了真实评估模型在未见数据上的表现,我们将数据集划分为训练集和测试集用于后续所有任务。该数据集是在Get_embeddings_from_dataset Notebook中创建的。

import pandas as pd
import numpy as np
from ast import literal_eval

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score

datafile_path = "data/fine_food_reviews_with_embeddings_1k.csv"

df = pd.read_csv(datafile_path)
df["embedding"] = df.embedding.apply(literal_eval).apply(np.array)  # convert string to array

# split data into train and test
X_train, X_test, y_train, y_test = train_test_split(
    list(df.embedding.values), df.Score, test_size=0.2, random_state=42
)

# train random forest classifier
clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
probas = clf.predict_proba(X_test)

report = classification_report(y_test, preds)
print(report)
              precision    recall  f1-score   support

           1       0.90      0.45      0.60        20
           2       1.00      0.38      0.55         8
           3       1.00      0.18      0.31        11
           4       0.88      0.26      0.40        27
           5       0.76      1.00      0.86       134

    accuracy                           0.78       200
   macro avg       0.91      0.45      0.54       200
weighted avg       0.81      0.78      0.73       200

我们可以看到,模型已经学会了合理区分不同类别。5星评价整体表现最佳,这并不令人意外,因为它们在数据集中最为常见。

from utils.embeddings_utils import plot_multiclass_precision_recall

plot_multiclass_precision_recall(probas, y_test, [1, 2, 3, 4, 5], clf)
RandomForestClassifier() - Average precision score over all classes: 0.90
image generated by notebook

不出所料,5星和1星评价似乎更容易预测。也许有了更多数据后,2-4星之间的细微差别可以更好地预测,但人们在使用中间评分时可能也存在更多主观性。