Tutorial: Classification Fine-tuning
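Let's walk through a quick example of fine-tuning the LM weights within a DSPy program. We'll apply it to a simple 77-way classification task.
Our fine-tuned program will use a tiny Llama-3.2-1B language model, hosted locally on your GPU. To make this more interesting, we'll assume that (i) we don't have any training labels but (ii) we have 500 unlabeled training examples.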
Install dependencies and download data
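Install the latest DSPy via pip install -U dspy and follow along (or use uv pip, if you prefer). This tutorial requires DSPy >= 2.6.0. You also need to run pip install datasets.
This tutorial currently requires a local GPU for inference, though we plan to support serving fine-tuned models through ollama as well.
You will also need the following dependencies:
- Inference: We use SGLang to run local inference servers. You can install the latest version by following the instructions here: https://docs.sglang.ai/start/install.html
Shared below are the latest install commands as of 2025-04-02, but we recommend following the instructions at the installation link for the most up-to-date version.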
This ensures that the fine-tuning packages and the sglang package stay in sync.
> pip install --upgrade pip
> pip install uv
> uv pip install "sglang[all]>=0.4.4.post3" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python
- Fine-tuning: We use the following packages. Note that we pin the version of the transformers package as a temporary fix for a recent issue: https://github.com/huggingface/trl/issues/2338
> uv pip install -U torch transformers==4.48.3 accelerate trl peft
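We recommend using the uv package manager to speed up the installation.
Recommended: Set up MLflow Tracing to understand what's happening under the hood.
MLflow DSPy Integration
MLflow is an LLMOps tool that natively integrates with DSPy and offers explainability and experiment tracking. In this tutorial, you can use MLflow to visualize prompts and optimization progress as traces to understand DSPy's behavior better. You can set up MLflow easily by following the four steps below.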

- Install MLflow
%pip install "mlflow>=2.20"
- Start the MLflow UI in a separate terminal
mlflow ui --port 5000
- Connect the notebook to MLflow
import mlflow
mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("DSPy")
- Enable tracing.
mlflow.dspy.autolog()
To learn more about the integration, visit the MLflow DSPy Documentation.
Dataset
For this tutorial, we will use the Banking77 dataset.
import dspy
import random
from dspy.datasets import DataLoader
from datasets import load_dataset
# Load the Banking77 dataset.
CLASSES = load_dataset("PolyAI/banking77", split="train", trust_remote_code=True).features['label'].names
kwargs = dict(fields=("text", "label"), input_keys=("text",), split="train", trust_remote_code=True)
# Load the first 1000 examples from the dataset and replace each integer label with its class name.
raw_data = [
dspy.Example(x, label=CLASSES[x.label]).with_inputs("text")
for x in DataLoader().from_huggingface(dataset_name="PolyAI/banking77", **kwargs)[:1000]
]
random.Random(0).shuffle(raw_data)
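The shuffle above uses a seeded `random.Random(0)` instance, so the order (and therefore the train/dev slices taken from it later) is the same on every run. A minimal sketch of that behavior, using a hypothetical list of ten integers in place of the Banking77 examples and hypothetical slice sizes:

```python
import random

# Hypothetical stand-in for raw_data: ten items instead of 1000 examples.
items = list(range(10))

# Two independent generators with the same seed produce the same permutation.
first = items.copy()
random.Random(0).shuffle(first)
second = items.copy()
random.Random(0).shuffle(second)

# Non-overlapping slices of the shuffled list give disjoint train and dev sets.
train, dev = first[:5], first[5:7]

print(first == second)           # same order on every run
print(set(train).isdisjoint(dev))  # no example appears in both splits
```

Because the generator is created locally, the global `random` state is left untouched.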
The dataset has 77 different categories for classification. Let's review some of them.
len(CLASSES), CLASSES[:10]
(77, ['activate_my_card', 'age_limit', 'apple_pay_or_google_pay', 'atm_support', 'automatic_top_up', 'balance_not_updated_after_bank_transfer', 'balance_not_updated_after_cheque_or_cash_deposit', 'beneficiary_not_allowed', 'cancel_transfer', 'card_about_to_expire'])
Let's sample 500 (unlabeled) queries from Banking77. We'll use these for our bootstrapped fine-tuning.
unlabeled_trainset = [dspy.Example(text=x.text).with_inputs("text") for x in raw_data[:500]]
unlabeled_trainset[0]
Example({'text': 'What if there is an error on the exchange rate?'}) (input_keys={'text'})
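DSPy program
Let's say that we want a program that takes the text and reasons step by step and then selects one of the classes from Banking77.
Note that this is meant mainly for illustration, or for cases where you want to inspect the model's reasoning, e.g. for a small degree of explainability. In other words, this type of task is not necessarily likely to benefit very much from explicit reasoning.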
from typing import Literal
classify = dspy.ChainOfThought(f"text -> label: Literal{CLASSES}")
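The signature string above relies on how Python formats a list inside an f-string. A small sketch of what that expansion looks like, using a hypothetical three-class subset rather than the full 77 classes:

```python
from typing import Literal, get_args

# Hypothetical three-class subset, standing in for the 77 Banking77 classes.
CLASSES = ["activate_my_card", "age_limit", "atm_support"]

# A list inside an f-string is rendered with its repr, so the class names
# end up inside the brackets of a Literal[...] type expression.
signature = f"text -> label: Literal{CLASSES}"
print(signature)

# For comparison, the same type built directly in Python.
LabelType = Literal[tuple(CLASSES)]
print(get_args(LabelType))
```

The printed signature is the plain string that gets passed to `dspy.ChainOfThought` in the line above; the `LabelType` construction is only here for comparison and is not part of the tutorial's code.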
Bootstrapped fine-tuning
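There are many ways to go about this, e.g. allowing the model to teach itself or using inference-time compute (e.g., ensembling) to identify cases of high confidence without labels.
Perhaps the simplest is to use a model that we'd expect to do a reasonable job at this task as a teacher of reasoning and classification, and to distill that into our small model. All of these patterns can be expressed in a handful of lines of code.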
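To make the ensembling idea concrete, here is a minimal sketch of majority voting over several sampled predictions, keeping only queries where the votes mostly agree. This is not what the tutorial goes on to do (it uses a teacher model instead); the helper name `high_confidence_label`, the 0.8 threshold, the queries, and the votes are all hypothetical.

```python
from collections import Counter


def high_confidence_label(votes, min_agreement=0.8):
    """Return the majority label if enough of the votes agree, else None."""
    label, count = Counter(votes).most_common(1)[0]
    return label if count / len(votes) >= min_agreement else None


# Hypothetical predictions from five sampled runs on two unlabeled queries.
votes_by_query = {
    "query A": ["card_arrival"] * 5,
    "query B": [
        "fiat_currency_support",
        "exchange_rate",
        "fiat_currency_support",
        "country_support",
        "exchange_rate",
    ],
}

# Keep a pseudo-label only where the ensemble is confident.
pseudo_labels = {q: high_confidence_label(v) for q, v in votes_by_query.items()}
kept = {q: label for q, label in pseudo_labels.items() if label is not None}
print(kept)
```

Only the query with unanimous votes survives the filter; the one with split votes is dropped rather than given a noisy label.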
Let's set up the tiny Llama-3.2-1B-Instruct as the student LM. We'll use GPT-4o-mini as the teacher LM.
from dspy.clients.lm_local import LocalProvider
student_lm_name = "meta-llama/Llama-3.2-1B-Instruct"
student_lm = dspy.LM(model=f"openai/local:{student_lm_name}", provider=LocalProvider(), max_tokens=2000)
teacher_lm = dspy.LM('openai/gpt-4o-mini', max_tokens=3000)
Now, let's assign classifiers to our LMs.
student_classify = classify.deepcopy()
student_classify.set_lm(student_lm)
teacher_classify = classify.deepcopy()
teacher_classify.set_lm(teacher_lm)
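Let's now launch the bootstrapped fine-tuning. The word "bootstrapped" here means that the program itself will be invoked on the training inputs, and the resulting traces observed over all modules will be recorded and used for fine-tuning. This is the weight-optimizing variant of the various BootstrapFewShot methods in DSPy.
On every question in the (unlabeled) training set, this will invoke the teacher program, which will produce reasoning and select a class. This will be traced and will then constitute a training set for all modules (in this case, just the one CoT module) in the student program.
When the compile method is called, the BootstrapFinetune optimizer will use the passed teacher program (or programs, you can pass a list!) to create a training dataset.
It will then use this training dataset to create a fine-tuned version of the LM set for the student program, replacing it with the trained LM.
Note that the trained LM will be a new LM instance (the student_lm object we instantiated here will be untouched!)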
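The trace-collection step described above can be sketched in plain Python. This is a conceptual illustration only, not DSPy's implementation: `toy_teacher` and `bootstrap_training_data` are hypothetical names, and the teacher here is a trivial rule rather than an LM.

```python
def toy_teacher(text):
    """Hypothetical stand-in for the teacher program: returns reasoning and a label."""
    label = "card_arrival" if "card" in text else "exchange_rate"
    return {"reasoning": f"The query is most closely related to {label}.", "label": label}


def bootstrap_training_data(teacher, unlabeled_texts):
    """Call the teacher on every input and record each trace as a training pair."""
    records = []
    for text in unlabeled_texts:
        output = teacher(text)
        # Without labels there is nothing to check against, so every trace is kept.
        records.append({"input": text, "target": output})
    return records


unlabeled_texts = [
    "why hasnt my card come in yet?",
    "What if there is an error on the exchange rate?",
]
training_data = bootstrap_training_data(toy_teacher, unlabeled_texts)
for record in training_data:
    print(record["input"], "->", record["target"]["label"])
```

The student is then trained to reproduce each recorded target (reasoning plus label) given the corresponding input.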
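Note: If you have labels, you can pass metric to the constructor of BootstrapFinetune. If you want to apply this in practice, you can pass train_kwargs to the constructor to control the local LM training settings: device, use_peft, num_train_epochs, per_device_train_batch_size, gradient_accumulation_steps, learning_rate, max_seq_length, packing, bf16, and output_dir.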
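As a rough sketch of what such a train_kwargs dictionary might look like: the keys below come from the list above, but every value is an illustrative assumption rather than a recommended or default setting, and `device` is left out because its expected format is not stated here.

```python
# Illustrative values only; the keys are the ones named in the note above.
train_kwargs = dict(
    use_peft=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    max_seq_length=2048,
    packing=True,
    bf16=True,
    output_dir="./finetune_output",  # hypothetical path
)

# Gradients are accumulated over several small batches before each update,
# so the effective batch size per device is the product of the two settings.
effective_batch_size = (
    train_kwargs["per_device_train_batch_size"]
    * train_kwargs["gradient_accumulation_steps"]
)
print(effective_batch_size)

# The dictionary would then be passed to the optimizer's constructor, e.g.:
# optimizer = dspy.BootstrapFinetune(num_threads=16, train_kwargs=train_kwargs)
```

Raising gradient_accumulation_steps while lowering per_device_train_batch_size keeps the effective batch size constant at a lower peak memory cost, which can matter on a single small GPU.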
# Optional:
# [1] You can set the `DSPY_FINETUNEDIR` environment variable to control which directory is used to store the
# checkpoints and fine-tuning data. If this is not set, `DSPY_CACHEDIR` is used by default.
# [2] You can set the `CUDA_VISIBLE_DEVICES` environment variable to control the GPU that will be used for fine-tuning
# and inference. If this is not set and the default GPU that's used by HuggingFace's `transformers` library is
# occupied, an OutOfMemoryError might be raised.
#
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# os.environ["DSPY_FINETUNEDIR"] = "/path/to/dir"
dspy.settings.experimental = True # fine-tuning is an experimental feature, so we set a flag to enable it
optimizer = dspy.BootstrapFinetune(num_threads=16) # if you *do* have labels, pass metric=your_metric here!
classify_ft = optimizer.compile(student_classify, teacher=teacher_classify, trainset=unlabeled_trainset)
Since this is a local model, we need to explicitly launch it.
classify_ft.get_lm().launch()
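Validating the fine-tuned program
Let's now figure out whether this was successful. We can ask the system one question and inspect its behavior.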
classify_ft(text="I didn't receive my money earlier and it says the transaction is still in progress. Can you fix it?")
Prediction(
reasoning='The user is inquiring about a specific issue, which they did not receive and is still showing as a pending transaction. This situation typically indicates a problem with the cash withdrawal process, as the user is not receiving the money they attempted to withdraw. The appropriate label for this scenario is "pending_cash_withdrawal," as it directly relates to the status of the cash withdrawal transaction.',
label='pending_cash_withdrawal'
)
We could also get a small set of gold labels and see if the system can generalize to unseen queries.
devset = raw_data[500:600]
devset[0]
Example({'text': 'Which fiat currencies do you currently support? Will this change in this future?', 'label': 'fiat_currency_support'}) (input_keys={'text'})
Let's define an evaluator on this small dev set, where the metric ignores the reasoning and checks that the label is exactly correct.
metric = (lambda x, y, trace=None: x.label == y.label)
evaluate = dspy.Evaluate(devset=devset, metric=metric, display_progress=True, display_table=5, num_threads=16)
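To see what this metric does in isolation, here is a small sketch that applies the same lambda to stand-in objects. `SimpleNamespace` is used here only as a hypothetical substitute for `dspy.Example` and `dspy.Prediction`; the example and predictions are made up.

```python
from types import SimpleNamespace

# Same metric as above: compare labels only, ignore the reasoning.
metric = lambda x, y, trace=None: x.label == y.label

# Hypothetical gold example and two hypothetical predictions.
example = SimpleNamespace(text="why hasnt my card come in yet?", label="card_arrival")
right = SimpleNamespace(reasoning="Asks about a card that has not arrived.", label="card_arrival")
wrong = SimpleNamespace(reasoning="Asks when the card will arrive.", label="card_delivery_estimate")

scores = [metric(example, right), metric(example, wrong)]
accuracy = 100 * sum(scores) / len(scores)
print(scores, accuracy)
```

A near-miss such as card_delivery_estimate counts as fully wrong under exact match, however plausible the reasoning is.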
Now, let's evaluate the fine-tuned 1B classifier.
evaluate(classify_ft)
Average Metric: 51.00 / 99 (51.5%): 100%|██████████| 100/100 [00:35<00:00, 2.79it/s]
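| | text | example_label | reasoning | pred_label | <lambda> | label |
|---|---|---|---|---|---|---|
| 0 | Which fiat currencies do you currently support? Will this change in this... | fiat_currency_support | The user is inquiring about the current support for fiat currencies... | fiat_currency_support | ✔️ [True] | NaN |
| 1 | I didn't receive my money earlier and it says the transaction is s... | pending_cash_withdrawal | The user is inquiring about a specific issue, which they did not receive... | pending_cash_withdrawal | ✔️ [True] | NaN |
| 2 | What currencies do you accept? | fiat_currency_support | The user is inquiring about the types of currencies accepted, which... | fiat_currency_support | ✔️ [True] | NaN |
| 3 | Where can I find your exchange rates? | exchange_rate | The user is inquiring about where to find exchange rates, which relates to... | exchange_rate | ✔️ [True] | NaN |
| 4 | why hasnt my card come in yet? | card_arrival | The user is inquiring about the status of their card, which suggests... | card_arrival | ✔️ [True] | NaN |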
51.0
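Tracking Evaluation Results in MLflow Experiment
To track and visualize the evaluation results over time, you can record the results in an MLflow Experiment.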
import mlflow

with mlflow.start_run(run_name="classifier_evaluation"):
    evaluate_correctness = dspy.Evaluate(
        devset=devset,
        metric=metric,
        num_threads=16,
        display_progress=True,
    )
    # Evaluate the fine-tuned classifier as usual
    result = evaluate_correctness(classify_ft)
    # Log the aggregated score
    mlflow.log_metric("exact_match", result.score)
    # Log the detailed evaluation results as a table
    mlflow.log_table(
        {
            "Text": [example.text for example in devset],
            "Expected": [example.label for example in devset],
            "Predicted": [output[1] for output in result.results],
            "Exact match": [output[2] for output in result.results],
        },
        artifact_file="eval_results.json",
    )
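To learn more about the integration, visit the MLflow DSPy Documentation.
Not bad, given that we started with no labels for the task. Even though you have no labels, you can use various strategies to boost the quality of the bootstrapped training data.
To try that next, let's free our GPU memory by killing the fine-tuned LM.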
classify_ft.get_lm().kill()
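Bootstrapped fine-tuning against a metric
If you have labels, you can generally boost this by a large margin. To do so, you can pass a metric to BootstrapFinetune, which it will use to filter the trajectories over your program before it builds the fine-tuning data.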
optimizer = dspy.BootstrapFinetune(num_threads=16, metric=metric)
classify_ft = optimizer.compile(student_classify, teacher=teacher_classify, trainset=raw_data[:500])
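The filtering step can be pictured as follows. This is a conceptual sketch rather than DSPy's internals: the labeled examples and teacher predictions are hypothetical, and `SimpleNamespace` stands in for `dspy.Example` and `dspy.Prediction`.

```python
from types import SimpleNamespace

metric = lambda x, y, trace=None: x.label == y.label

# Hypothetical labeled training examples.
labeled_examples = [
    SimpleNamespace(text="why hasnt my card come in yet?", label="card_arrival"),
    SimpleNamespace(text="Which fiat currencies do you currently support?", label="fiat_currency_support"),
]

# Hypothetical teacher outputs: the first one picks a near-miss class.
teacher_predictions = [
    SimpleNamespace(label="card_delivery_estimate"),
    SimpleNamespace(label="fiat_currency_support"),
]

# Only trajectories that pass the metric are kept as fine-tuning data.
kept = [
    (example, prediction)
    for example, prediction in zip(labeled_examples, teacher_predictions)
    if metric(example, prediction)
]
print(len(kept), "of", len(labeled_examples), "trajectories kept")
```

The student never sees the teacher's mistakes, which is one plausible reason a fine-tuned student can end up more accurate than its teacher on the dev set.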
Let's now launch and evaluate this.
classify_ft.get_lm().launch()
evaluate(classify_ft)
Average Metric: 85.00 / 98 (86.7%): 100%|██████████| 100/100 [00:46<00:00, 2.14it/s]
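| | text | example_label | reasoning | pred_label | <lambda> | label |
|---|---|---|---|---|---|---|
| 0 | Which fiat currencies do you currently support? Will this change in this... | fiat_currency_support | The user is inquiring about the fiat currencies currently supported... | fiat_currency_support | ✔️ [True] | NaN |
| 1 | I didn't receive my money earlier and it says the transaction is s... | pending_cash_withdrawal | The user is inquiring about an unexpected fee on their account, which... | extra_charge_on_statement | | NaN |
| 2 | What currencies do you accept? | fiat_currency_support | The user is inquiring about the types of currencies accepted... | fiat_currency_support | ✔️ [True] | NaN |
| 3 | Where can I find your exchange rates? | exchange_rate | The user is inquiring about where to find exchange rates, which relates to... | exchange_rate | ✔️ [True] | NaN |
| 4 | why hasnt my card come in yet? | card_arrival | The user is inquiring about the delivery status of their card, which... | card_arrival | ✔️ [True] | NaN |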
85.0
That's quite a bit better, given only 500 labels. In fact, it seems to be a lot stronger than the teacher LM is out of the box!
evaluate(teacher_classify)
Average Metric: 55.00 / 100 (55.0%): 100%|██████████| 100/100 [00:11<00:00, 8.88it/s]
2025/01/08 12:38:35 INFO dspy.evaluate.evaluate: Average Metric: 55 / 100 (55.0%)
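| | text | example_label | reasoning | pred_label | <lambda> |
|---|---|---|---|---|---|
| 0 | Which fiat currencies do you currently support? Will this change in this... | fiat_currency_support | The user is inquiring about the fiat currencies supported by the system... | fiat_currency_support | ✔️ [True] |
| 1 | I didn't receive my money earlier and it says the transaction is s... | pending_cash_withdrawal | The user is experiencing an issue with a transaction that is still in progress... | pending_transfer | |
| 2 | What currencies do you accept? | fiat_currency_support | The question is about the types of currencies accepted, which... | fiat_currency_support | ✔️ [True] |
| 3 | Where can I find your exchange rates? | exchange_rate | The user is inquiring about where to find exchange rates, which relates to... | exchange_rate | ✔️ [True] |
| 4 | why hasnt my card come in yet? | card_arrival | The user is inquiring about the delivery status of their card, which... | card_delivery_estimate | |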
55.0
Thanks to bootstrapping, the model learns to apply our modules to get the right label, in this case by reasoning explicitly:
classify_ft(text="why hasnt my card come in yet?")
dspy.inspect_history()
[2025-01-08T12:39:42.143798] System message: Your input fields are: 1. `text` (str) Your output fields are: 1. `reasoning` (str) 2. `label` (Literal[activate_my_card, age_limit, apple_pay_or_google_pay, atm_support, automatic_top_up, balance_not_updated_after_bank_transfer, balance_not_updated_after_cheque_or_cash_deposit, beneficiary_not_allowed, cancel_transfer, card_about_to_expire, card_acceptance, card_arrival, card_delivery_estimate, card_linking, card_not_working, card_payment_fee_charged, card_payment_not_recognised, card_payment_wrong_exchange_rate, card_swallowed, cash_withdrawal_charge, cash_withdrawal_not_recognised, change_pin, compromised_card, contactless_not_working, country_support, declined_card_payment, declined_cash_withdrawal, declined_transfer, direct_debit_payment_not_recognised, disposable_card_limits, edit_personal_details, exchange_charge, exchange_rate, exchange_via_app, extra_charge_on_statement, failed_transfer, fiat_currency_support, get_disposable_virtual_card, get_physical_card, getting_spare_card, getting_virtual_card, lost_or_stolen_card, lost_or_stolen_phone, order_physical_card, passcode_forgotten, pending_card_payment, pending_cash_withdrawal, pending_top_up, pending_transfer, pin_blocked, receiving_money, Refund_not_showing_up, request_refund, reverted_card_payment?, supported_cards_and_currencies, terminate_account, top_up_by_bank_transfer_charge, top_up_by_card_charge, top_up_by_cash_or_cheque, top_up_failed, top_up_limits, top_up_reverted, topping_up_by_card, transaction_charged_twice, transfer_fee_charged, transfer_into_account, transfer_not_received_by_recipient, transfer_timing, unable_to_verify_identity, verify_my_identity, verify_source_of_funds, verify_top_up, virtual_card_not_working, visa_or_mastercard, why_verify_identity, wrong_amount_of_cash_received, wrong_exchange_rate_for_cash_withdrawal]) All interactions will be structured in the following way, with the appropriate values filled in. 
[[ ## text ## ]] {text} [[ ## reasoning ## ]] {reasoning} [[ ## label ## ]] {label} # note: the value you produce must be one of: activate_my_card; age_limit; apple_pay_or_google_pay; atm_support; automatic_top_up; balance_not_updated_after_bank_transfer; balance_not_updated_after_cheque_or_cash_deposit; beneficiary_not_allowed; cancel_transfer; card_about_to_expire; card_acceptance; card_arrival; card_delivery_estimate; card_linking; card_not_working; card_payment_fee_charged; card_payment_not_recognised; card_payment_wrong_exchange_rate; card_swallowed; cash_withdrawal_charge; cash_withdrawal_not_recognised; change_pin; compromised_card; contactless_not_working; country_support; declined_card_payment; declined_cash_withdrawal; declined_transfer; direct_debit_payment_not_recognised; disposable_card_limits; edit_personal_details; exchange_charge; exchange_rate; exchange_via_app; extra_charge_on_statement; failed_transfer; fiat_currency_support; get_disposable_virtual_card; get_physical_card; getting_spare_card; getting_virtual_card; lost_or_stolen_card; lost_or_stolen_phone; order_physical_card; passcode_forgotten; pending_card_payment; pending_cash_withdrawal; pending_top_up; pending_transfer; pin_blocked; receiving_money; Refund_not_showing_up; request_refund; reverted_card_payment?; supported_cards_and_currencies; terminate_account; top_up_by_bank_transfer_charge; top_up_by_card_charge; top_up_by_cash_or_cheque; top_up_failed; top_up_limits; top_up_reverted; topping_up_by_card; transaction_charged_twice; transfer_fee_charged; transfer_into_account; transfer_not_received_by_recipient; transfer_timing; unable_to_verify_identity; verify_my_identity; verify_source_of_funds; verify_top_up; virtual_card_not_working; visa_or_mastercard; why_verify_identity; wrong_amount_of_cash_received; wrong_exchange_rate_for_cash_withdrawal [[ ## completed ## ]] In adhering to this structure, your objective is: Given the fields `text`, produce the fields `label`. 
User message: [[ ## text ## ]] why hasnt my card come in yet? Respond with the corresponding output fields, starting with the field `[[ ## reasoning ## ]]`, then `[[ ## label ## ]]` (must be formatted as a valid Python Literal[activate_my_card, age_limit, apple_pay_or_google_pay, atm_support, automatic_top_up, balance_not_updated_after_bank_transfer, balance_not_updated_after_cheque_or_cash_deposit, beneficiary_not_allowed, cancel_transfer, card_about_to_expire, card_acceptance, card_arrival, card_delivery_estimate, card_linking, card_not_working, card_payment_fee_charged, card_payment_not_recognised, card_payment_wrong_exchange_rate, card_swallowed, cash_withdrawal_charge, cash_withdrawal_not_recognised, change_pin, compromised_card, contactless_not_working, country_support, declined_card_payment, declined_cash_withdrawal, declined_transfer, direct_debit_payment_not_recognised, disposable_card_limits, edit_personal_details, exchange_charge, exchange_rate, exchange_via_app, extra_charge_on_statement, failed_transfer, fiat_currency_support, get_disposable_virtual_card, get_physical_card, getting_spare_card, getting_virtual_card, lost_or_stolen_card, lost_or_stolen_phone, order_physical_card, passcode_forgotten, pending_card_payment, pending_cash_withdrawal, pending_top_up, pending_transfer, pin_blocked, receiving_money, Refund_not_showing_up, request_refund, reverted_card_payment?, supported_cards_and_currencies, terminate_account, top_up_by_bank_transfer_charge, top_up_by_card_charge, top_up_by_cash_or_cheque, top_up_failed, top_up_limits, top_up_reverted, topping_up_by_card, transaction_charged_twice, transfer_fee_charged, transfer_into_account, transfer_not_received_by_recipient, transfer_timing, unable_to_verify_identity, verify_my_identity, verify_source_of_funds, verify_top_up, virtual_card_not_working, visa_or_mastercard, why_verify_identity, wrong_amount_of_cash_received, wrong_exchange_rate_for_cash_withdrawal]), and then ending with the marker for `[[ ## completed ## ]]`.
Response: [[ ## reasoning ## ]] The user is inquiring about the status of their card delivery, which suggests they are concerned about when they will receive their card. This aligns with the topic of card arrival and delivery estimates. [[ ## label ## ]] card_arrival [[ ## completed ## ]]
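Saving fine-tuned programs in MLflow Experiment
To deploy the fine-tuned program in production or share it with your team, you can save it in an MLflow Experiment. Compared to simply saving it to a local file, MLflow offers the following benefits:
- Dependency Management: MLflow automatically saves the frozen environment metadata along with the program to ensure reproducibility.
- Experiment Tracking: With MLflow, you can track the program's performance and cost along with the program itself.
- Collaboration: You can share the program and results with your team members by sharing the MLflow experiment.
To save the program in MLflow, run the following code: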
import mlflow
# Start an MLflow Run and save the program
with mlflow.start_run(run_name="optimized_classifier"):
    model_info = mlflow.dspy.log_model(
        classify_ft,
        artifact_path="model",  # Any name to save the program in MLflow
    )
# Load the program back from MLflow
loaded = mlflow.dspy.load_model(model_info.model_uri)
To learn more about the integration, visit the MLflow DSPy Documentation.