测试与评估
使用PydanticAI和LLM集成,一般有两种不同类型的测试:
- 单元测试 — 你的应用程序代码的测试,以及它是否正常运行
- Evals — LLM的测试,以及它的响应有多好或多差
在大多数情况下,这两种测试有着相对独立的目标和考虑因素。
单元测试
PydanticAI代码的单元测试与其他任何Python代码的单元测试是一样的。
因为大多数情况下它们并没有什么新意,我们已经建立了相当成熟的工具和模式来编写和运行这些类型的测试。
除非你真的确定自己更了解,否则你可能想大致遵循这个策略:
- 使用
pytest作为您的测试工具 - 如果你发现自己在输入长的断言,可以使用 inline-snapshot
- 类似地, dirty-equals 对于比较大型数据结构也很有用
- 使用
TestModel或FunctionModel代替您的实际模型,以避免真实 LLM 调用的使用、延迟和变异性 - 使用
Agent.override在您的应用逻辑中替换模型 - 全局设置
ALLOW_MODEL_REQUESTS=False来阻止意外向非测试模型发起任何请求
使用 TestModel 进行单元测试
最简单和最快的方式来测试您大部分应用程序代码是使用 TestModel,这将(默认情况下)调用代理中的所有工具,然后根据代理的返回类型返回纯文本或结构化响应。
TestModel 不是魔法
“聪明”(但不太聪明)的部分在于TestModel会尝试根据注册工具的模式生成有效的结构化数据,用于function tools和result types。
在TestModel中没有机器学习或人工智能,这只是传统的过程式Python代码,试图生成满足工具JSON架构的数据。
生成的数据看起来可能不太美观或相关,但在大多数情况下,它应该通过Pydantic的验证。
如果你想要更复杂的东西,使用 FunctionModel 并编写你自己的数据生成逻辑。
我们来为以下应用程序代码编写单元测试:
import asyncio
from datetime import date
from pydantic_ai import Agent, RunContext
from fake_database import DatabaseConn # (1)!
from weather_service import WeatherService # (2)!
weather_agent = Agent(
'openai:gpt-4o',
deps_type=WeatherService,
system_prompt='Providing a weather forecast at the locations the user provides.',
)
@weather_agent.tool
def weather_forecast(
ctx: RunContext[WeatherService], location: str, forecast_date: date
) -> str:
if forecast_date < date.today(): # (3)!
return ctx.deps.get_historic_weather(location, forecast_date)
else:
return ctx.deps.get_forecast(location, forecast_date)
async def run_weather_forecast( # (4)!
user_prompts: list[tuple[str, int]], conn: DatabaseConn
):
"""Run weather forecast for a list of user prompts and save."""
async with WeatherService() as weather_service:
async def run_forecast(prompt: str, user_id: int):
result = await weather_agent.run(prompt, deps=weather_service)
await conn.store_forecast(user_id, result.data)
# run all prompts in parallel
await asyncio.gather(
*(run_forecast(prompt, user_id) for (prompt, user_id) in user_prompts)
)
DatabaseConn是一个持有数据库连接的类WeatherService具有获取天气预报和天气历史数据的方法- 我们需要根据日期是过去还是未来调用不同的端点,您将在下面看到这个细微差别的重要性
- 这个函数是我们想要测试的代码,以及它使用的代理
这里有一个函数,它接受一个由(user_prompt, user_id) 元组组成的列表,为每个提示获取天气预报,并将结果存储在数据库中。
我们希望在不需要模拟某些对象或修改我们的代码以传入测试对象的情况下测试这段代码。
以下是我们如何使用 TestModel 编写测试:
from datetime import timezone
import pytest
from dirty_equals import IsNow
from pydantic_ai import models, capture_run_messages
from pydantic_ai.models.test import TestModel
from pydantic_ai.messages import (
ModelResponse,
SystemPromptPart,
TextPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
ModelRequest,
)
from fake_database import DatabaseConn
from weather_app import run_weather_forecast, weather_agent
pytestmark = pytest.mark.anyio # (1)!
models.ALLOW_MODEL_REQUESTS = False # (2)!
async def test_forecast():
conn = DatabaseConn()
user_id = 1
with capture_run_messages() as messages:
with weather_agent.override(model=TestModel()): # (3)!
prompt = 'What will the weather be like in London on 2024-11-28?'
await run_weather_forecast([(prompt, user_id)], conn) # (4)!
forecast = await conn.get_forecast(user_id)
assert forecast == '{"weather_forecast":"Sunny with a chance of rain"}' # (5)!
assert messages == [ # (6)!
ModelRequest(
parts=[
SystemPromptPart(
content='Providing a weather forecast at the locations the user provides.',
),
UserPromptPart(
content='What will the weather be like in London on 2024-11-28?',
timestamp=IsNow(tz=timezone.utc), # (7)!
),
]
),
ModelResponse(
parts=[
ToolCallPart(
tool_name='weather_forecast',
args={
'location': 'a',
'forecast_date': '2024-01-01', # (8)!
},
tool_call_id=None,
)
],
model_name='test',
timestamp=IsNow(tz=timezone.utc),
),
ModelRequest(
parts=[
ToolReturnPart(
tool_name='weather_forecast',
content='Sunny with a chance of rain',
tool_call_id=None,
timestamp=IsNow(tz=timezone.utc),
),
],
),
ModelResponse(
parts=[
TextPart(
content='{"weather_forecast":"Sunny with a chance of rain"}',
)
],
model_name='test',
timestamp=IsNow(tz=timezone.utc),
),
]
- 我们正在使用 anyio 来运行异步测试。
- 这是一个安全措施,以确保我们在测试时不会意外对LLM发出真实请求,更多细节请参见
ALLOW_MODEL_REQUESTS。 - 我们正在使用
Agent.override将代理的模型替换为TestModel,关于override的好处是我们可以在不需要访问代理run*方法调用位置的情况下替换代理内部的模型。 - 现在我们在
override上下文管理器内调用我们想要测试的函数。 - 但是默认情况下,
TestModel将返回一个 JSON 字符串,总结了工具的调用及其返回的结果。如果您想将响应自定义为更贴合该领域的内容,可以在定义TestModel时添加custom_result_text='Sunny'。 - 到目前为止,我们实际上不知道调用了哪些工具以及使用了哪些值,我们可以使用
capture_run_messages来检查最近一次运行的消息,并验证代理与模型之间的交换是否如预期发生。 - 这个
IsNow辅助工具允许我们使用声明性断言,即使是包含会随时间变化的时间戳的数据。 TestModel并没有做任何聪明的事情来从提示中提取值,因此这些值是硬编码的。
使用 FunctionModel 的单元测试
上述测试是一个很好的开始,但细心的读者会注意到,WeatherService.get_forecast 从未被调用,因为 TestModel 使用过去的日期调用了 weather_forecast。
要充分使用 weather_forecast,我们需要使用 FunctionModel 来定制工具的调用方式。
这是一个使用 FunctionModel 测试 weather_forecast 工具的自定义输入示例
import re
import pytest
from pydantic_ai import models
from pydantic_ai.messages import (
ModelMessage,
ModelResponse,
TextPart,
ToolCallPart,
)
from pydantic_ai.models.function import AgentInfo, FunctionModel
from fake_database import DatabaseConn
from weather_app import run_weather_forecast, weather_agent
pytestmark = pytest.mark.anyio
models.ALLOW_MODEL_REQUESTS = False
def call_weather_forecast( # (1)!
messages: list[ModelMessage], info: AgentInfo
) -> ModelResponse:
if len(messages) == 1:
# first call, call the weather forecast tool
user_prompt = messages[0].parts[-1]
m = re.search(r'\d{4}-\d{2}-\d{2}', user_prompt.content)
assert m is not None
args = {'location': 'London', 'forecast_date': m.group()} # (2)!
return ModelResponse(parts=[ToolCallPart('weather_forecast', args)])
else:
# second call, return the forecast
msg = messages[-1].parts[0]
assert msg.part_kind == 'tool-return'
return ModelResponse(parts=[TextPart(f'The forecast is: {msg.content}')])
async def test_forecast_future():
conn = DatabaseConn()
user_id = 1
with weather_agent.override(model=FunctionModel(call_weather_forecast)): # (3)!
prompt = 'What will the weather be like in London on 2032-01-01?'
await run_weather_forecast([(prompt, user_id)], conn)
forecast = await conn.get_forecast(user_id)
assert forecast == 'The forecast is: Rainy with a chance of sun'
- 我们定义了一个函数
call_weather_forecast,该函数将被FunctionModel调用,代替LLM,该函数可以访问构成运行的ModelMessage列表,以及包含代理和函数工具及返回工具信息的AgentInfo。 - 我们的函数稍微智能一些,因为它尝试从提示中提取日期,但只是硬编码了位置。
- 我们使用
FunctionModel来用我们的自定义函数替换代理的模型。
通过 pytest 固件重写模型
如果您编写大量都需要重写模型的测试,您可以使用 pytest fixtures 以可重用的方式用 TestModel 或 FunctionModel 来重写模型。
这是一个用 TestModel 重写模型的示例:
import pytest
from weather_app import weather_agent
from pydantic_ai.models.test import TestModel
@pytest.fixture
def override_weather_agent():
with weather_agent.override(model=TestModel()):
yield
async def test_forecast(override_weather_agent: None):
...
# test code here
评估
"Evals" 指的是评估模型在特定应用中的表现。
警告
与单元测试不同,评估是一门新兴的艺术/科学;任何声称确切知道你的评估应如何定义的人都可以被安全地忽略。
评估通常更像基准测试而不是单元测试,它们从来不会“通过”,虽然它们确实会“失败”;你最关心的是它们随时间如何变化。
由于评估需要在真实模型上运行,因此可能会慢且成本高昂,您一般不希望在每次提交时在CI中运行它们。
性能测量
评估中最困难的部分是衡量模型的表现如何。
在某些情况下(例如,生成SQL的代理),可以使用简单、易于运行的测试来衡量性能(例如,SQL是否有效?它是否返回正确的结果?它是否只返回正确的结果?)。
在其他情况下(例如,提供戒烟建议的代理),很难或不可能对表现进行定量测量——在吸烟的情况下,您确实需要进行几十个月的双盲试验,接着等待 40 年并观察健康结果,以确定对提示的修改是否是改进。
您可以使用几种不同的策略来衡量性能:
- 端到端的自包含测试 — 像SQL示例一样,我们可以几乎立刻测试代理的最终结果
- 综合自包含测试 — 编写单元测试风格的检查,确保输出符合预期,检查类似于
'chewing gum' in response,虽然这些检查看起来简单,但它们很有帮助,一个好的特点是当它们失败时很容易判断出问题所在 - 大型语言模型评估大型语言模型 — 使用其他模型,甚至同一模型的不同提示来评估智能体的性能(就像班级学生互相批改作业,因为老师宿醉),虽然这种方法的缺点和复杂性显而易见,但有些人认为在合适的情况下,它可以是一个有用的工具
- 生产中的评估 — 衡量代理在生产中的最终结果,然后创建一个定量的性能度量,以便您可以轻松衡量在更改提示或使用的模型时随时间变化的情况, logfire 在这种情况下可能非常有用,因为您可以编写自定义查询来衡量代理的性能
系统提示定制
系统提示是开发者控制代理行为的主要工具,因此能够自定义系统提示并观察性能变化通常是有用的。当系统提示包含一系列示例时,尤其 relevant 你想要了解更改该列表如何影响模型的性能。
假设我们有以下应用程序来运行由用户提示生成的SQL(这个示例省略了许多细节,请参阅SQL gen示例以获取更完整的代码):
import json
from pathlib import Path
from typing import Union
from pydantic_ai import Agent, RunContext
from fake_database import DatabaseConn
class SqlSystemPrompt: # (1)!
def __init__(
self, examples: Union[list[dict[str, str]], None] = None, db: str = 'PostgreSQL'
):
if examples is None:
# if examples aren't provided, load them from file, this is the default
with Path('examples.json').open('rb') as f:
self.examples = json.load(f)
else:
self.examples = examples
self.db = db
def build_prompt(self) -> str: # (2)!
return f"""\
Given the following {self.db} table of records, your job is to
write a SQL query that suits the user's request.
Database schema:
CREATE TABLE records (
...
);
{''.join(self.format_example(example) for example in self.examples)}
"""
@staticmethod
def format_example(example: dict[str, str]) -> str: # (3)!
return f"""\
<example>
<request>{example['request']}</request>
<sql>{example['sql']}</sql>
</example>
"""
sql_agent = Agent(
'google-gla:gemini-1.5-flash',
deps_type=SqlSystemPrompt,
)
@sql_agent.system_prompt
async def system_prompt(ctx: RunContext[SqlSystemPrompt]) -> str:
return ctx.deps.build_prompt()
async def user_search(user_prompt: str) -> list[dict[str, str]]:
"""Search the database based on the user's prompts."""
... # (4)!
result = await sql_agent.run(user_prompt, deps=SqlSystemPrompt())
conn = DatabaseConn()
return await conn.execute(result.data)
- 该
SqlSystemPrompt类用于构建系统提示,可以通过示例列表和数据库类型进行定制。我们将其作为一个单独的类实现,并作为依赖项传递给代理,这样在评估期间可以通过依赖注入覆盖输入和逻辑。 - 该
build_prompt方法根据示例和数据库类型构建系统提示。 - 一些人认为,如果示例格式化为XML,LLMs更有可能生成良好的回复,因为这样更容易识别字符串的结束,参见 #93。
- 实际上,您在这里会有更多的逻辑,使得独立于更广泛应用程序运行该代理变得不切实际。
examples.json 看起来像这样:
request: show me error records with the tag "foobar"
response: SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags)
{
"examples": [
{
"request": "Show me all records",
"sql": "SELECT * FROM records;"
},
{
"request": "Show me all records from 2021",
"sql": "SELECT * FROM records WHERE date_trunc('year', date) = '2021-01-01';"
},
{
"request": "show me error records with the tag 'foobar'",
"sql": "SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags);"
},
...
]
}
现在我们想要一种方法来量化SQL生成的成功,以便我们可以判断对代理的更改如何影响其性能。
我们可以使用 Agent.override 来用一个自定义的系统提示替换现有的提示,该自定义提示使用了一组示例,然后运行应用程序代码(在这种情况下为 user_search)。我们还运行示例中的实际 SQL,并将示例 SQL 的“正确”结果与代理生成的 SQL 进行比较。(我们比较运行 SQL 的结果,而不是 SQL 本身,因为 SQL 可能在语义上是等价的,但以不同的方式书写)。
为了获得性能的定量指标,我们为每次运行分配以下分数:
- -100 分数如果生成的 SQL 无效
- -1 分,每返回一行由代理返回的结果(因此不鼓励返回大量结果)
- +5 分数对于每一行由代理返回且与预期结果匹配
我们使用5折交叉验证来评估代理在我们现有示例集上的表现。
import json
import statistics
from pathlib import Path
from itertools import chain
from fake_database import DatabaseConn, QueryError
from sql_app import sql_agent, SqlSystemPrompt, user_search
async def main():
with Path('examples.json').open('rb') as f:
examples = json.load(f)
# split examples into 5 folds
fold_size = len(examples) // 5
folds = [examples[i : i + fold_size] for i in range(0, len(examples), fold_size)]
conn = DatabaseConn()
scores = []
for i, fold in enumerate(folds):
fold_score = 0
# build all other folds into a list of examples
other_folds = list(chain(*(f for j, f in enumerate(folds) if j != i)))
# create a new system prompt with the other fold examples
system_prompt = SqlSystemPrompt(examples=other_folds)
# override the system prompt with the new one
with sql_agent.override(deps=system_prompt):
for case in fold:
try:
agent_results = await user_search(case['request'])
except QueryError as e:
print(f'Fold {i} {case}: {e}')
fold_score -= 100
else:
# get the expected results using the SQL from this case
expected_results = await conn.execute(case['sql'])
agent_ids = [r['id'] for r in agent_results]
# each returned value has a score of -1
fold_score -= len(agent_ids)
expected_ids = {r['id'] for r in expected_results}
# each return value that matches the expected value has a score of 3
fold_score += 5 * len(set(agent_ids) & expected_ids)
scores.append(fold_score)
overall_score = statistics.mean(scores)
print(f'Overall score: {overall_score:0.2f}')
#> Overall score: 12.00
然后我们可以更改提示、模型或示例,观察分数随时间的变化。