
Text Classification with OpenAI and Pydantic

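This tutorial shows how to implement text classification with the OpenAI API and Pydantic models, covering both single-label and multi-label classification. For complete examples, see our single-label classification and multi-label classification examples.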

Motivation

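Text classification is a common problem in many natural language processing applications, such as spam detection or support ticket categorization. The goal here is to provide a systematic way to handle these cases by combining OpenAI's GPT models with Python data structures.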

Single-Label Classification

Defining the Structures

For single-label classification, we define a Pydantic model with a Literal field that lists the possible labels.

Literals vs. Enums

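We prefer Literal types over enums for class labels. Literals provide better type checking and are more straightforward to use with Pydantic models, since the allowed values are written directly in the type annotation and no separate enum class is needed.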
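To make the difference concrete, here is a small sketch that expresses the same label set both ways. It assumes Pydantic v2 (for model_json_schema), and the class names LiteralLabel, SpamLabel and EnumLabel are made up for illustration. Both versions reject an unknown label; the Literal version lists the allowed values directly on the property in the JSON schema, while the Enum version adds a separate definition that the property references through $ref.

```python
from enum import Enum
from typing import Literal

from pydantic import BaseModel, ValidationError


class LiteralLabel(BaseModel):
    # Allowed values are written inline in the annotation
    label: Literal["SPAM", "NOT_SPAM"]


class SpamLabel(str, Enum):
    SPAM = "SPAM"
    NOT_SPAM = "NOT_SPAM"


class EnumLabel(BaseModel):
    # Same constraint, expressed through a separate Enum class
    label: SpamLabel


# Both models reject a label outside the allowed set
rejected = []
for model in (LiteralLabel, EnumLabel):
    try:
        model(label="HAM")
    except ValidationError:
        rejected.append(model.__name__)
print(rejected)

# Literal: the allowed values appear directly on the property
literal_schema = LiteralLabel.model_json_schema()
print(literal_schema["properties"]["label"])

# Enum: the property points at a separate definition under $defs
enum_schema = EnumLabel.model_json_schema()
print(enum_schema["properties"]["label"])
```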

Few-Shot Examples

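Including a few examples in the model's docstring is key to improving classification accuracy. These examples show the AI what the task is and what output is expected.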
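Why does the docstring matter? In Pydantic v2, a model's docstring becomes the description field of its JSON schema, and as far as we understand, instructor sends that schema (description included) to the model together with the field definitions. The sketch below, using a hypothetical SpamCheck model, only demonstrates the first half of that: the few-shot examples end up in the schema text.

```python
from typing import Literal

from pydantic import BaseModel


class SpamCheck(BaseModel):
    """
    Classify a message as SPAM or NOT_SPAM.

    Examples:
    - "Buy cheap watches now!": SPAM
    - "Meeting at 3 PM in the conference room": NOT_SPAM
    """

    label: Literal["SPAM", "NOT_SPAM"]


# Pydantic strips the docstring's indentation and stores it as "description"
description = SpamCheck.model_json_schema()["description"]
print(description)
```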

If you want to learn more prompting techniques, check out our prompting guide.

Chain of Thought

Using Chain of Thought has been shown to improve prediction quality by roughly 10%.

from pydantic import BaseModel, Field
from typing import Literal
from openai import OpenAI
import instructor

# Apply the patch to the OpenAI client
# enables response_model keyword
client = instructor.from_openai(OpenAI())


class ClassificationResponse(BaseModel):
    """
    A few-shot example of text classification:

    Examples:
    - "Buy cheap watches now!": SPAM
    - "Meeting at 3 PM in the conference room": NOT_SPAM
    - "You've won a free iPhone! Click here": SPAM
    - "Can you pick up some milk on your way home?": NOT_SPAM
    - "Increase your followers by 10000 overnight!": SPAM
    """

    chain_of_thought: str = Field(
        ...,
        description="The chain of thought that led to the prediction.",
    )
    label: Literal["SPAM", "NOT_SPAM"] = Field(
        ...,
        description="The predicted class label.",
    )
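The position of the chain_of_thought field matters. Because it is declared before label, it also comes first in the generated JSON schema, and models usually fill in JSON keys in schema order, so the reasoning tends to be written before the label is committed. That last part is a tendency of the model, not something Pydantic enforces. This sketch, using a hypothetical ReasonedLabel model and assuming Pydantic v2, checks only the schema side.

```python
from typing import Literal

from pydantic import BaseModel, Field


class ReasonedLabel(BaseModel):
    # Declared first, so it is listed first in the schema
    chain_of_thought: str = Field(..., description="Reasoning written before the label.")
    label: Literal["SPAM", "NOT_SPAM"] = Field(..., description="The predicted class label.")


schema = ReasonedLabel.model_json_schema()
field_order = list(schema["properties"])
print(field_order)
print(schema["required"])
```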

Classifying Text

The classify function performs single-label classification.

def classify(data: str) -> ClassificationResponse:
    """Perform single-label classification on the input text."""
    return client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=ClassificationResponse,
        messages=[
            {
                "role": "user",
                "content": f"Classify the following text: <text>{data}</text>",
            },
        ],
    )
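Roughly speaking, instructor parses the model's reply into response_model and validates it with Pydantic. The sketch below reproduces only that validation step offline, using a trimmed copy of ClassificationResponse. The hand-written JSON strings stand in for raw API output (they are illustrative, not captured responses) and show that a label outside the Literal set is rejected rather than returned.

```python
from typing import Literal

from pydantic import BaseModel, ValidationError


class ClassificationResponse(BaseModel):
    chain_of_thought: str
    label: Literal["SPAM", "NOT_SPAM"]


# Hand-written JSON strings standing in for raw model output (illustrative only)
good_output = '{"chain_of_thought": "Promises a prize to a stranger.", "label": "SPAM"}'
bad_output = '{"chain_of_thought": "Looks harmless.", "label": "HAM"}'

parsed = ClassificationResponse.model_validate_json(good_output)
print(parsed.label)

errors = []
try:
    ClassificationResponse.model_validate_json(bad_output)
except ValidationError as exc:
    # The out-of-vocabulary label is reported instead of being passed through
    errors = exc.errors()
print([err["loc"] for err in errors])
```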

Testing and Evaluation

Let's run the example and check whether it correctly distinguishes spam from non-spam.

if __name__ == "__main__":
    for text, label in [
        ("Hey Jason! You're awesome", "NOT_SPAM"),
        ("I am a nigerian prince and I need your help.", "SPAM"),
    ]:
        prediction = classify(text)
        assert prediction.label == label
        print(f"Text: {text}, Predicted Label: {prediction.label}")
        #> Text: Hey Jason! You're awesome, Predicted Label: NOT_SPAM
        #> Text: I am a nigerian prince and I need your help., Predicted Label: SPAM
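The loop above stops at the first failing assert. On a larger test set, it is often more useful to measure accuracy and list the misses. The following standard-library-only sketch does that. keyword_stub is a hypothetical stand-in so the harness runs without an API key; in practice you would pass something like lambda t: classify(t).label instead. The example texts are also made up, and they are deliberately kept separate from the few-shot examples in the docstring so the evaluation does not just check whether the model repeats them.

```python
from typing import Callable, List, Tuple


def evaluate(classify_fn: Callable[[str], str], dataset: List[Tuple[str, str]]) -> float:
    """Return the share of examples whose predicted label equals the expected one."""
    correct = 0
    for text, expected in dataset:
        predicted = classify_fn(text)
        if predicted == expected:
            correct += 1
        else:
            print(f"MISS: {text!r} -> {predicted} (expected {expected})")
    return correct / len(dataset)


def keyword_stub(text: str) -> str:
    # Hypothetical stand-in for `lambda t: classify(t).label`, so no API call is made
    spam_markers = ("prince", "free", "click here")
    return "SPAM" if any(marker in text.lower() for marker in spam_markers) else "NOT_SPAM"


# Illustrative examples, kept separate from the few-shot examples in the docstring
dataset = [
    ("Hey Jason! You're awesome", "NOT_SPAM"),
    ("I am a nigerian prince and I need your help.", "SPAM"),
    ("Claim your free gift card, click here", "SPAM"),
    ("Lunch tomorrow at noon?", "NOT_SPAM"),
    ("Limited offer, act now!", "SPAM"),  # the keyword stub misses this one
]

accuracy = evaluate(keyword_stub, dataset)
print(f"Accuracy: {accuracy:.2f}")
```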

Multi-Label Classification

Defining the Structures

For multi-label classification, we follow the same approach: we use Literals instead of enums and include few-shot examples in the model's docstring.

from typing import List, Literal
from pydantic import BaseModel, Field


class MultiClassPrediction(BaseModel):
    """
    Class for a multi-label class prediction.

    Examples:
    - "My account is locked": ["TECH_ISSUE"]
    - "I can't access my billing info": ["TECH_ISSUE", "BILLING"]
    - "When do you close for holidays?": ["GENERAL_QUERY"]
    - "My payment didn't go through and now I can't log in": ["BILLING", "TECH_ISSUE"]
    """

    chain_of_thought: str = Field(
        ...,
        description="The chain of thought that led to the prediction.",
    )

    class_labels: List[Literal["TECH_ISSUE", "BILLING", "GENERAL_QUERY"]] = Field(
        ...,
        description="The predicted class labels for the support ticket.",
    )
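List[Literal[...]] constrains each element individually. The sketch below, using a hypothetical TicketLabels model and assuming Pydantic v2, shows that a valid list passes, that an unknown label is reported at its list index, and that an empty list is accepted as well.

```python
from typing import List, Literal

from pydantic import BaseModel, ValidationError

TicketLabel = Literal["TECH_ISSUE", "BILLING", "GENERAL_QUERY"]


class TicketLabels(BaseModel):
    class_labels: List[TicketLabel]


ok = TicketLabels(class_labels=["BILLING", "TECH_ISSUE"])
print(ok.class_labels)

# Nothing in List[...] requires at least one label
empty = TicketLabels(class_labels=[])
print(empty.class_labels)

bad_locs = []
try:
    TicketLabels(class_labels=["BILLING", "REFUND"])
except ValidationError as exc:
    # Each error points at the offending list index
    bad_locs = [err["loc"] for err in exc.errors()]
print(bad_locs)
```

If every ticket must receive at least one label, Pydantic v2 also accepts Field(..., min_length=1) on a list field. Duplicates such as ["BILLING", "BILLING"] are likewise not rejected by this type alone.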

Classifying Text

The multi_classify function handles multi-label classification.

import instructor
from openai import OpenAI

client = instructor.from_openai(OpenAI())


def multi_classify(data: str) -> MultiClassPrediction:
    """Perform multi-label classification on the input text."""
    return client.chat.completions.create(
        model="gpt-4o-mini",
        response_model=MultiClassPrediction,
        messages=[
            {
                "role": "user",
                "content": f"Classify the following support ticket: <ticket>{data}</ticket>",
            },
        ],
    )

Testing and Evaluation

Finally, we test the multi-label classification function with a sample support ticket.

# Test multi-label classification
ticket = "My account is locked and I can't access my billing info."
prediction = multi_classify(ticket)
assert "TECH_ISSUE" in prediction.class_labels
assert "BILLING" in prediction.class_labels
print(f"Ticket: {ticket}")
#> Ticket: My account is locked and I can't access my billing info.
print(f"Predicted Labels: {prediction.class_labels}")
#> Predicted Labels: ['TECH_ISSUE', 'BILLING']
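The asserts above use `in`, so they catch missing labels but would not notice an extra, wrong label. One common way to score multi-label output is per-example precision, recall and F1 over label sets, which ignores label order. Here is a standard-library-only sketch; label_set_scores is a hypothetical helper name.

```python
from typing import Dict, Iterable


def label_set_scores(predicted: Iterable[str], expected: Iterable[str]) -> Dict[str, float]:
    """Order-insensitive precision, recall and F1 for a single multi-label prediction."""
    pred, gold = set(predicted), set(expected)
    hits = len(pred & gold)
    precision = hits / len(pred) if pred else 0.0
    recall = hits / len(gold) if gold else 0.0
    # hits > 0 implies precision + recall > 0, so the division is safe
    f1 = 2 * precision * recall / (precision + recall) if hits else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Same labels in a different order: a perfect score
perfect = label_set_scores(["BILLING", "TECH_ISSUE"], ["TECH_ISSUE", "BILLING"])
# One correct label plus one spurious label
partial = label_set_scores(["TECH_ISSUE", "GENERAL_QUERY"], ["TECH_ISSUE"])
print(perfect)
print(partial)
```

Averaging these per-ticket scores over a labeled set gives one simple summary. It is one reasonable choice among several multi-label metrics.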

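By using Literals and including few-shot examples, we improved both the single-label and multi-label classification implementations. These changes strengthen type safety and give the AI model better guidance, which can lead to more accurate classifications.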