2022年10月20日

交易的多类别分类

,

在本笔记本中,我们将对一个公开交易数据集进行分类,将其归入我们预先定义的多个类别。这些方法应可复用于任何多类别分类场景,特别是当我们试图将交易数据匹配到预定义类别时。通过完成本教程,您将掌握处理带标签和无标签数据集的几种方法。

我们将在本笔记本中采用的不同方法有:

  • 零样本分类: 首先我们将进行零样本分类,仅通过提示指导将交易归入五个已命名的类别之一
  • 使用嵌入进行分类: 接下来我们将在带标签的数据集上创建嵌入,然后使用传统分类模型来测试它们在识别我们类别方面的有效性
  • 微调分类: 最后我们将基于标注数据集训练一个微调模型,看看与零样本和小样本分类方法相比效果如何
%load_ext autoreload
%autoreload
%pip install openai 'openai[datalib]' 'openai[embeddings]' transformers scikit-learn matplotlib plotly pandas scipy
import openai
import pandas as pd
import numpy as np
import json
import os

COMPLETIONS_MODEL = "gpt-4"
os.environ["OPENAI_API_KEY"] = "<your-api-key>"
client = openai.OpenAI()
transactions = pd.read_csv('./data/25000_spend_dataset_current.csv', encoding= 'unicode_escape')
print(f"Number of transactions: {len(transactions)}")
print(transactions.head())
Number of transactions: 359
         Date                      Supplier                 Description  \
0  21/04/2016          M & J Ballantyne Ltd       George IV Bridge Work   
1  26/04/2016                  Private Sale   Literary & Archival Items   
2  30/04/2016     City Of Edinburgh Council         Non Domestic Rates    
3  09/05/2016              Computacenter Uk                 Kelvin Hall   
4  09/05/2016  John Graham Construction Ltd  Causewayside Refurbishment   

   Transaction value (£)  
0                35098.0  
1                30000.0  
2                40800.0  
3                72835.0  
4                64361.0  

零样本分类

我们将首先评估基础模型在使用简单提示对这些交易进行分类时的性能。我们会为模型提供5个类别,并为无法分类的交易设置一个"无法分类"的兜底选项。

zero_shot_prompt = '''You are a data expert working for the National Library of Scotland.
You are analysing all transactions over £25,000 in value and classifying them into one of five categories.
The five categories are Building Improvement, Literature & Archive, Utility Bills, Professional Services and Software/IT.
If you can't tell what it is, say Could not classify

Transaction:

Supplier: {}
Description: {}
Value: {}

The classification is:'''

def format_prompt(transaction):
    return zero_shot_prompt.format(transaction['Supplier'], transaction['Description'], transaction['Transaction value (£)'])

def classify_transaction(transaction):

    
    prompt = format_prompt(transaction)
    messages = [
        {"role": "system", "content": prompt},
    ]
    completion_response = openai.chat.completions.create(
                            messages=messages,
                            temperature=0,
                            max_tokens=5,
                            top_p=1,
                            frequency_penalty=0,
                            presence_penalty=0,
                            model=COMPLETIONS_MODEL)
    label = completion_response.choices[0].message.content.replace('\n','')
    return label
# Get a test transaction
transaction = transactions.iloc[0]
# Use our completion function to return a prediction
print(f"Transaction: {transaction['Supplier']} {transaction['Description']} {transaction['Transaction value (£)']}")
print(f"Classification: {classify_transaction(transaction)}")
Transaction: M & J Ballantyne Ltd George IV Bridge Work 35098.0
Classification: Building Improvement

我们的第一次尝试是正确的,M & J Ballantyne有限公司是一家房屋建筑商,他们从事的工作确实是建筑改进。

让我们将样本量扩大到25,看看效果如何,这次仍然只用一个简单的提示来引导它

test_transactions = transactions.iloc[:25]
test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x),axis=1)
/var/folders/3n/79rgh27s6l7_l91b9shw0_nr0000gp/T/ipykernel_81921/2775604370.py:2: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_transactions['Classification'] = test_transactions.apply(lambda x: classify_transaction(x),axis=1)
test_transactions['Classification'].value_counts()
Classification
Building Improvement    17
Literature & Archive     3
Software/IT              2
Could not classify       2
Utility Bills            1
Name: count, dtype: int64
test_transactions.head(25)
日期 供应商 描述 交易金额(£) 分类
0 2016年4月21日 M & J Ballantyne有限公司 乔治四世大桥工程 35098.0 建筑改造
1 2016年4月26日 私人拍卖 文学与档案物品 30000.0 文学与档案
2 2016年4月30日 爱丁堡市议会 非住宅房产税 40800.0 公用事业账单
3 2016/05/09 Computacenter 英国 凯尔文大厅 72835.0 软件/IT
4 2016年9月5日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 64361.0 建筑改善
5 2016年5月9日 A McGillivray Causewayside翻新工程 53690.0 建筑改善
6 2016年5月16日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 365344.0 建筑改善
7 2016年5月23日 Computacenter英国 凯尔文大厅 26506.0 软件/IT
8 2016年5月23日 ECG设施服务 设施管理费 32777.0 建筑改善
9 2016年5月23日 心电图设施服务 设施管理费 32777.0 建筑改善
10 2016年5月30日 ALDL ALDL费用 32317.0 无法分类
11 2016/10/06 Wavetek有限公司 凯尔文大厅 87589.0 建筑改造
12 2016年10月6日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 381803.0 建筑改善
13 2016/06/28 心电图设施服务 设施管理费 32832.0 建筑改善
14 2016年6月30日 格拉斯哥市议会 凯尔文大厅 1700000.0 建筑改善工程
15 2016/07/11 Wavetek有限公司 凯尔文大厅 65692.0 建筑改造
16 2016年11月7日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 139845.0 建筑改善
17 2016年7月15日 苏富比拍卖行 文学与档案物品 28500.0 文学与档案
18 2016年7月18日 佳士得 文学与档案物品 33800.0 文学与档案
19 2016年7月25日 A McGillivray Causewayside翻新工程 30113.0 建筑改善
20 2016/07/31 ALDL ALDL费用 32317.0 无法分类
21 2016年8月8日 心电图设施服务 设施管理费 32795.0 建筑改善
22 2016年8月15日 创意视频制作有限公司 凯尔文大厅 26866.0 建筑改善
23 2016年8月15日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 196807.0 建筑改善
24 2016年8月24日 心电图设施服务 设施管理费 32795.0 建筑改善

即使没有任何标注示例,初步结果已经相当不错!那些无法分类的案例属于线索较少、主题较难判断的棘手情况,但如果我们能清理标注数据集并提供更多示例,或许可以提升性能。

df = pd.read_csv('./data/labelled_transactions.csv')
df.head()
日期 供应商 描述 交易金额(£) 分类
0 2016年8月15日 创意视频制作有限公司 凯尔文大厅 26866 其他
1 2017年5月29日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 74806 建筑改善
2 2017年5月29日 Morris & Spottiswood有限公司 乔治四世大桥工程 56448 建筑改造
3 2017年5月31日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 164691 建筑改善
4 2017年7月24日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 27926 建筑改善
df['combined'] = "Supplier: " + df['Supplier'].str.strip() + "; Description: " + df['Description'].str.strip() + "; Value: " + str(df['Transaction value (£)']).strip()
df.head(2)
日期 供应商 描述 交易金额(£) 分类 合并项
0 2016年8月15日 创意视频制作有限公司 凯尔文大厅 26866 其他 供应商:创意视频制作有限公司;描述...
1 2017年5月29日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 74806 建筑改善 供应商:约翰·格雷厄姆建筑有限公司;描述...
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

df['n_tokens'] = df.combined.apply(lambda x: len(tokenizer.encode(x)))
len(df)
101
embedding_path = './data/transactions_with_embeddings_100.csv'
from utils.embeddings_utils import get_embedding
df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x))
df['babbage_search'] = df.combined.apply(lambda x: get_embedding(x))
df.to_csv(embedding_path)
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from ast import literal_eval

fs_df = pd.read_csv(embedding_path)
fs_df["babbage_similarity"] = fs_df.babbage_similarity.apply(literal_eval).apply(np.array)
fs_df.head()
Unnamed: 0 日期 供应商 描述 交易金额(£) 分类 combined n_tokens babbage_similarity babbage_search
0 0 2016年8月15日 创意视频制作有限公司 凯尔文大厅 26866 其他 供应商:创意视频制作有限公司;描述... 136 [-0.02898375503718853, -0.02881557121872902, 0... [-0.02879939414560795, -0.02867320366203785, 0...
1 1 2017年5月29日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 74806 建筑改善 供应商:约翰·格雷厄姆建筑有限公司;描述... 140 [-0.024112487211823463, -0.02881261520087719, ... [-0.024112487211823463, -0.02881261520087719, ...
2 2 2017年5月29日 Morris & Spottiswood有限公司 乔治四世大桥工程 56448 建筑改善 供应商:Morris & Spottiswood有限公司;描述... 141 [0.013581369072198868, -0.003978211898356676, ... [0.013593776151537895, -0.0037341134157031775,...
3 3 2017年5月31日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 164691 建筑改善 供应商:约翰·格雷厄姆建筑有限公司;描述... 140 [-0.024112487211823463, -0.02881261520087719, ... [-0.024112487211823463, -0.02881261520087719, ...
4 4 2017年7月24日 John Graham建筑有限公司 Causewayside翻新工程 27926 建筑改善 供应商:John Graham建筑有限公司;描述... 140 [-0.02408558875322342, -0.02881370671093464, 0... [-0.024109570309519768, -0.02880912832915783, ...
X_train, X_test, y_train, y_test = train_test_split(
    list(fs_df.babbage_similarity.values), fs_df.Classification, test_size=0.2, random_state=42
)

clf = RandomForestClassifier(n_estimators=100)
clf.fit(X_train, y_train)
preds = clf.predict(X_test)
probas = clf.predict_proba(X_test)

report = classification_report(y_test, preds)
print(report)
                      precision    recall  f1-score   support

Building Improvement       0.92      1.00      0.96        11
Literature & Archive       1.00      1.00      1.00         3
               Other       0.00      0.00      0.00         1
         Software/IT       1.00      1.00      1.00         1
       Utility Bills       1.00      1.00      1.00         5

            accuracy                           0.95        21
           macro avg       0.78      0.80      0.79        21
        weighted avg       0.91      0.95      0.93        21

/Users/vishnu/code/openai-cookbook/cookbook_env/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
/Users/vishnu/code/openai-cookbook/cookbook_env/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
/Users/vishnu/code/openai-cookbook/cookbook_env/lib/python3.11/site-packages/sklearn/metrics/_classification.py:1565: UndefinedMetricWarning: Precision is ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))

该模型的性能相当强大,因此创建嵌入向量并使用更简单的分类器似乎也是一种有效的方法,而零样本分类器可以帮助我们对未标记数据集进行初始分类。

让我们更进一步,看看在相同标注数据集上训练的微调模型是否能给出可比的结果

微调交易分类

对于这个用例,我们将尝试通过在同一组标记的101笔交易上训练一个微调模型,并将该微调模型应用于一组未见过的交易,从而改进上述的少样本分类。

构建微调分类器

我们首先需要进行一些数据准备工作以使数据就绪。这将包括以下步骤:

  • 为了准备我们的训练集和验证集,我们将创建一组消息序列。每条序列的第一条消息将是包含交易详情的用户提示格式化内容,最后一条消息则是模型预期的分类响应。
  • 我们的测试集将包含每笔交易的初始用户提示,以及相应的预期类别标签。然后我们将使用微调后的模型为每笔交易生成实际的分类。
ft_prep_df = fs_df.copy()
len(ft_prep_df)
101
ft_prep_df.head()
Unnamed: 0 日期 供应商 描述 交易金额(£) 分类 combined n_tokens babbage_similarity babbage_search
0 0 2016年8月15日 创意视频制作有限公司 凯尔文大厅 26866 其他 供应商:创意视频制作有限公司;描述... 136 [-0.028885245323181152, -0.028660893440246582,... [-0.02879939414560795, -0.02867320366203785, 0...
1 1 2017年5月29日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 74806 建筑改善 供应商:约翰·格雷厄姆建筑有限公司;描述... 140 [-0.024112487211823463, -0.02881261520087719, ... [-0.02414606139063835, -0.02883070334792137, 0...
2 2 2017年5月29日 Morris & Spottiswood有限公司 乔治四世大桥工程 56448 建筑改善 供应商:Morris & Spottiswood有限公司;描述... 141 [0.013593776151537895, -0.0037341134157031775,... [0.013561442494392395, -0.004199974238872528, ...
3 3 2017年5月31日 约翰·格雷厄姆建筑有限公司 堤道边翻新工程 164691 建筑改善 供应商:约翰·格雷厄姆建筑有限公司;描述... 140 [-0.024112487211823463, -0.02881261520087719, ... [-0.024112487211823463, -0.02881261520087719, ...
4 4 2017年7月24日 约翰·格雷厄姆建筑有限公司 Causewayside翻新工程 27926 建筑改善 供应商:约翰·格雷厄姆建筑有限公司;描述... 140 [-0.024112487211823463, -0.02881261520087719, ... [-0.024112487211823463, -0.02881261520087719, ...
classes = list(set(ft_prep_df['Classification']))
class_df = pd.DataFrame(classes).reset_index()
class_df.columns = ['class_id','class']
class_df  , len(class_df)
(   class_id                 class
 0         0                 Other
 1         1  Literature & Archive
 2         2           Software/IT
 3         3         Utility Bills
 4         4  Building Improvement,
 5)
ft_df_with_class = ft_prep_df.merge(class_df,left_on='Classification',right_on='class',how='inner')

# Creating a list of messages for the fine-tuning job. The user message is the prompt, and the assistant message is the response from the model
ft_df_with_class['messages'] = ft_df_with_class.apply(lambda x: [{"role": "user", "content": format_prompt(x)}, {"role": "assistant", "content": x['class']}],axis=1)
ft_df_with_class[['messages', 'class']].head()
消息 类别
0 [{'role': 'user', 'content': '你是一个数据...' 其他
1 [{'role': 'user', 'content': '你是一个数据...'}] 建筑改进
2 [{'role': 'user', 'content': '你是一个数据... 建筑改进
3 [{'role': 'user', 'content': '你是一个数据... 建筑改进
4 [{'role': 'user', 'content': '你是一个数据...'}] 建筑改进
# Create train/validation split
samples = ft_df_with_class["messages"].tolist()
train_df, valid_df = train_test_split(samples, test_size=0.2, random_state=42)

def write_to_jsonl(list_of_messages, filename):
    with open(filename, "w+") as f:
        for messages in list_of_messages:
            object = {  
                "messages": messages
            }
            f.write(json.dumps(object) + "\n")
# Write the train/validation split to jsonl files
train_file_name, valid_file_name = "transactions_grouped_train.jsonl", "transactions_grouped_valid.jsonl"
write_to_jsonl(train_df, train_file_name)
write_to_jsonl(valid_df, valid_file_name)
# Upload the files to OpenAI
train_file = client.files.create(file=open(train_file_name, "rb"), purpose="fine-tune")
valid_file = client.files.create(file=open(valid_file_name, "rb"), purpose="fine-tune")
# Create the fine-tuning job
fine_tuning_job = client.fine_tuning.jobs.create(training_file=train_file.id, validation_file=valid_file.id, model="gpt-4o-2024-08-06")
# Get the fine-tuning job status and model name
status = client.fine_tuning.jobs.retrieve(fine_tuning_job.id)
# Once the fine-tuning job is complete, you can retrieve the model name from the job status
fine_tuned_model = client.fine_tuning.jobs.retrieve(fine_tuning_job.id).fine_tuned_model
print(f"Fine tuned model id: {fine_tuned_model}")
Fine tuned model id: ft:gpt-4o-2024-08-06:openai::BKr3Xy8U

应用微调分类器

现在我们将应用分类器来评估其性能。我们的训练集中仅有31个独特观测值,验证集中有8个,让我们看看表现如何

# Create a test set with the expected class labels
test_set = pd.read_json(valid_file_name, lines=True)
test_set['expected_class'] = test_set.apply(lambda x: x['messages'][-1]['content'], axis=1)
test_set.head()
消息 预期类别
0 [{'role': 'user', 'content': '你是一个数据...' 水电费账单
1 [{'role': 'user', 'content': '你是一个数据... 文学与档案
2 [{'role': 'user', 'content': '你是一个数据...'}] 文学与档案
3 [{'role': 'user', 'content': '你是一个数据...' 文学与档案
4 [{'role': 'user', 'content': '你是一个数据...'}] 建筑改进
# Apply the fine-tuned model to the test set
test_set['response'] = test_set.apply(lambda x: openai.chat.completions.create(model=fine_tuned_model, messages=x['messages'][:-1], temperature=0),axis=1)
test_set['predicted_class'] = test_set.apply(lambda x: x['response'].choices[0].message.content, axis=1)

test_set.head()
消息 预期类别 响应 预测类别
0 [{'role': 'user', 'content': '你是一个数据...'} 水电账单 ChatCompletion(id='chatcmpl-BKrC0S1wQSfM9ZQfcC... 水电账单
1 [{'role': 'user', 'content': '你是一个数据...' 文学与档案 ChatCompletion(id='chatcmpl-BKrC1BTr0DagbDkC2s... 文学与档案
2 [{'role': 'user', 'content': '你是一个数据...'}] 文学与档案 ChatCompletion(id='chatcmpl-BKrC1H3ZeIW5cz2Owr... 文学与档案
3 [{'role': 'user', 'content': '你是一个数据...' 文学与档案 ChatCompletion(id='chatcmpl-BKrC1wdhaMP0Q7YmYx... 文学与档案
4 [{'role': 'user', 'content': '你是一个数据...'} 建筑改进 ChatCompletion(id='chatcmpl-BKrC20c5pkpngy1xDu... 建筑改进
# Calculate the accuracy of the predictions
from sklearn.metrics import f1_score
test_set['result'] = test_set.apply(lambda x: str(x['predicted_class']).strip() == str(x['expected_class']).strip(), axis = 1)
test_set['result'].value_counts()

print(test_set['result'].value_counts())

print("F1 Score: ", f1_score(test_set['expected_class'], test_set['predicted_class'], average="weighted"))
print("Raw Accuracy: ", test_set['result'].value_counts()[True] / len(test_set))
result
True     20
False     1
Name: count, dtype: int64
F1 Score:  0.9296066252587991
Raw Accuracy:  0.9523809523809523