一个与SQL数据库交互的代理¶
在本教程中,我们将逐步介绍如何构建一个能够回答关于SQL数据库问题的代理。
从高层次来看,代理将:
1. 从数据库中获取可用的表
2. 决定哪些表与问题相关
3. 获取相关表的DDL(数据定义语言)
4. 根据问题和DDL信息生成查询
5. 使用LLM(大型语言模型)对查询进行常见错误的复查
6. 执行查询并返回结果
7. 纠正数据库引擎反馈的错误,直到查询成功
8. 根据结果形成响应
端到端的工作流程大致如下所示:
设置¶
首先,让我们安装所需的软件包并设置我们的 API 密钥。
import getpass
import os
def _set_env(key: str):
if key not in os.environ:
os.environ[key] = getpass.getpass(f"{key}:")
# Make sure the OpenAI API key is available before the ChatOpenAI clients used
# later in this notebook are created (prompts only if it is not already set).
_set_env("OPENAI_API_KEY")
为 LangGraph 开发设置 LangSmith
注册 LangSmith 以快速发现问题并提高您的 LangGraph 项目的性能。LangSmith 让您可以使用追踪数据来调试、测试和监控您使用 LangGraph 构建的 LLM 应用程序 — 阅读有关如何开始的更多信息 这里。
配置数据库¶
我们将为本教程创建一个SQLite数据库。SQLite是一个轻量级的数据库,易于设置和使用。我们将加载chinook数据库,这是一个代表数字媒体商店的示例数据库。
有关该数据库的更多信息,请查看这里。
为了方便,我们将数据库(Chinook.db)托管在一个公共的GCS存储桶中。
import requests
url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"

# requests never times out unless told to; without a timeout a stalled
# connection would hang this cell indefinitely.
response = requests.get(url, timeout=60)
if response.status_code == 200:
    # Open the local file in binary write mode and store the downloaded bytes.
    with open("Chinook.db", "wb") as file:
        file.write(response.content)
    print("File downloaded and saved as Chinook.db")
else:
    print(f"Failed to download the file. Status code: {response.status_code}")
我们将使用langchain_community包中提供的一个方便的SQL数据库包装器与数据库进行交互。该包装器提供了一个简单的接口来执行SQL查询并获取结果。在教程后面,我们还将使用langchain_openai包与OpenAI API进行语言模型的交互。
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# Sanity-check the connection: SQL dialect, then the tables the wrapper exposes,
# each on its own line.
print(db.dialect, db.get_usable_table_names(), sep="\n")

# Run a small query; in a notebook the returned string is shown as cell output.
db.run("SELECT * FROM Artist LIMIT 10;")
sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"
实用函数¶
我们将定义一些实用函数来帮助我们实现代理。具体来说,我们将用一个回退机制包装 ToolNode 以处理错误并将其呈现给代理。
from typing import Any
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
from langgraph.prebuilt import ToolNode
def create_tool_node_with_fallback(tools: list) -> RunnableWithFallbacks[Any, dict]:
    """Wrap ``tools`` in a ToolNode whose failures are reported back to the agent.

    If the ToolNode raises, the fallback (``handle_tool_error``) runs instead.
    With ``exception_key="error"`` the caught exception is handed to the fallback
    under the ``"error"`` key, which is the key ``handle_tool_error`` reads.
    """
    primary = ToolNode(tools)
    error_reporter = RunnableLambda(handle_tool_error)
    return primary.with_fallbacks([error_reporter], exception_key="error")
def handle_tool_error(state) -> dict:
    """Answer every pending tool call of the last message with an error ToolMessage.

    Args:
        state: Graph state (dict) with a ``"messages"`` list and, when invoked as
            a fallback, the caught exception under ``"error"``.

    Returns:
        ``{"messages": [...]}`` with one ToolMessage per tool call of the last
        message, each carrying ``repr`` of the error and a request to fix it.
    """
    failure = state.get("error")
    pending_calls = state["messages"][-1].tool_calls
    replies = []
    for call in pending_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(failure)}\n please fix your mistakes.",
                tool_call_id=call["id"],
            )
        )
    return {"messages": replies}
定义代理的工具¶
我们将定义几个工具,代理将使用这些工具与数据库进行交互。
- list_tables_tool:获取数据库中的可用表
- get_schema_tool:获取表的DDL
- db_query_tool:执行查询并获取结果,或者在查询失败时返回错误信息
对于前两个工具,我们将从SQLDatabaseToolkit中获取,它也可以在langchain_community包中找到。
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_openai import ChatOpenAI
toolkit = SQLDatabaseToolkit(db=db, llm=ChatOpenAI(model="gpt-4o"))
tools = toolkit.get_tools()

# Pull the two ready-made tools out of the toolkit by their registered names
# (first match wins; StopIteration if a name is missing).
list_tables_tool = next(
    filter(lambda candidate: candidate.name == "sql_db_list_tables", tools)
)
get_schema_tool = next(
    filter(lambda candidate: candidate.name == "sql_db_schema", tools)
)

# Smoke-test both tools: list every table, then show the DDL and sample rows of Artist.
print(list_tables_tool.invoke(""))
print(get_schema_tool.invoke("Artist"))
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
CREATE TABLE "Artist" (
"ArtistId" INTEGER NOT NULL,
"Name" NVARCHAR(120),
PRIMARY KEY ("ArtistId")
)
/*
3 rows from Artist table:
ArtistId Name
1 AC/DC
2 Accept
3 Aerosmith
*/
第三个将手动定义。对于 db_query_tool,我们将对数据库执行查询并返回结果。
from langchain_core.tools import tool
@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and get back the result.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """
    # NOTE: the docstring above is the tool description the LLM sees; keep it stable.
    # run_no_throw returns "Error: ..." when execution fails, and that message is
    # passed through unchanged. An empty string means the query ran but matched no
    # rows. Report that honestly instead of claiming the query failed, while still
    # nudging the model to retry if it expected data.
    result = db.run_no_throw(query)
    if not result:
        return (
            "Query executed successfully but returned no rows. "
            "If you expected results, rewrite your query and try again."
        )
    return result


print(db_query_tool.invoke("SELECT * FROM Artist LIMIT 10;"))
[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]
虽然这严格来说不是一个工具,但我们将提示一个LLM检查查询中的常见错误,随后将其作为一个节点添加到工作流程中。
from langchain_core.prompts import ChatPromptTemplate
# System prompt for the query-checking step. This text is sent to the model
# verbatim, so its wording is part of the behavior.
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
You will call the appropriate tool to execute the query after running this check."""

query_check_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", query_check_system),
        # The caller's messages are inserted here.
        ("placeholder", "{messages}"),
    ]
)

# prompt -> gpt-4o at temperature 0. tool_choice="required" forces the reply to
# be a db_query_tool call carrying the (possibly rewritten) query.
query_check = query_check_prompt.pipe(
    ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(
        [db_query_tool], tool_choice="required"
    )
)

# Smoke test: a query with no mistakes should come back unchanged inside the tool call.
query_check.invoke({"messages": [("user", "SELECT * FROM Artist LIMIT 10;")]})
AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_la8JTjHox6P1VjTqc15GSgdk', 'function': {'arguments': '{"query":"SELECT * FROM Artist LIMIT 10;"}', 'name': 'db_query_tool'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 20, 'prompt_tokens': 221, 'total_tokens': 241}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_a2ff031fb5', 'finish_reason': 'stop', 'logprobs': None}, id='run-dd7873ef-d2f7-4769-a5c0-e6776ec2c515-0', tool_calls=[{'name': 'db_query_tool', 'args': {'query': 'SELECT * FROM Artist LIMIT 10;'}, 'id': 'call_la8JTjHox6P1VjTqc15GSgdk', 'type': 'tool_call'}], usage_metadata={'input_tokens': 221, 'output_tokens': 20, 'total_tokens': 241})
定义工作流程¶
我们将定义代理的工作流程。代理将首先强制调用 list_tables_tool 以从数据库中获取可用的表,然后遵循教程开头提到的步骤。
在LangChain中使用Pydantic
本笔记本使用Pydantic v2 BaseModel,这需要langchain-core >= 0.3。使用langchain-core < 0.3会导致错误,因为混合使用了Pydantic v1和v2的BaseModels。
from typing import Annotated, Literal
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from langgraph.graph import END, StateGraph, START
from langgraph.graph.message import AnyMessage, add_messages
# Define the agent's graph state: a single running list of chat messages.
class State(TypedDict):
    # The add_messages annotation makes LangGraph merge each node's returned
    # messages into this list instead of overwriting it (LangGraph reducer
    # semantics — confirm against the installed version).
    messages: Annotated[list[AnyMessage], add_messages]
# Create a new graph over State; nodes and edges are registered on it below.
workflow = StateGraph(State)
# Node for the first tool call: always list the database tables before anything else.
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """Emit a hard-coded AIMessage that calls ``sql_db_list_tables``.

    No model is consulted; the state is ignored. The fixed call id
    ``"tool_abcd123"`` is echoed back by the list-tables ToolMessage.
    """
    forced_call = {"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}
    return {"messages": [AIMessage(content="", tool_calls=[forced_call])]}
def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """Graph node: double-check the latest generated query before it is executed.

    Only the most recent message (the query produced by ``query_gen``) is sent
    to the ``query_check`` chain. Its reply, a forced ``db_query_tool`` call, is
    appended to the state.
    """
    latest = state["messages"][-1]
    checked = query_check.invoke({"messages": [latest]})
    return {"messages": [checked]}
workflow.add_node("first_tool_call", first_tool_call)
# Add nodes for the first two tools; each is wrapped so that tool errors are
# returned to the agent as ToolMessages instead of raising.
workflow.add_node(
    "list_tables_tool", create_tool_node_with_fallback([list_tables_tool])
)
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
# Add a model node that picks the relevant tables from the question and the
# available tables; it can only call get_schema_tool.
model_get_schema = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(
    [get_schema_tool]
)
workflow.add_node(
    "model_get_schema",
    # The whole message history so far is passed to the model.
    lambda state: {
        "messages": [model_get_schema.invoke(state["messages"])],
    },
)
# A tool whose call marks the end state: the model submits its final answer with it.
# NOTE(review): this class is passed to bind_tools below, so its docstring and
# the Field description are presumably sent to the model as the tool schema.
# They are runtime text and are deliberately left as-is (docstring untranslated).
class SubmitFinalAnswer(BaseModel):
    """根据查询结果向用户提交最终答案。"""
    # The answer text; read back later via tool_calls[0]["args"]["final_answer"].
    final_answer: str = Field(..., description="The final answer to the user")
# Model that writes the SQL query from the question and the schema, and later
# turns the query results into the final answer. This prompt is sent verbatim.
query_gen_system = """You are a SQL expert with a strong attention to detail.
Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.
When generating the query:
Output the SQL query that answers the input question without a tool call.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
If you get an error while executing a query, rewrite the query and try again.
If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.
If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database."""

query_gen_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", query_gen_system),
        # The full conversation so far is inserted here.
        ("placeholder", "{messages}"),
    ]
)

# Only SubmitFinalAnswer is bound; queries are expected as plain text.
query_gen = query_gen_prompt.pipe(
    ChatOpenAI(model="gpt-4o", temperature=0).bind_tools([SubmitFinalAnswer])
)
def query_gen_node(state: State):
    """Generate the next SQL query (or the final answer) and police tool usage.

    The ``query_gen`` chain is bound only to ``SubmitFinalAnswer``; queries are
    expected as plain message text. The model sometimes calls a different tool
    anyway (e.g. ``sql_db_schema``), so each such call gets an error
    ``ToolMessage`` whose content starts with ``"Error:"``. ``should_continue``
    routes that back to this node.

    When any wrong tool is called, *every* tool call in the message is answered,
    including a ``SubmitFinalAnswer`` call made in the same turn. Leaving one
    unanswered would make the next ChatOpenAI request invalid, because each
    assistant tool call must be followed by a tool message with its id.

    Args:
        state: Current graph state; passed whole to the ``query_gen`` prompt.

    Returns:
        ``{"messages": [...]}``: the model message, followed by any error
        tool messages.
    """
    message = query_gen.invoke(state)

    tool_calls = message.tool_calls or []
    # Plain-text query (no tool calls) or a clean SubmitFinalAnswer: pass through.
    if all(tc["name"] == "SubmitFinalAnswer" for tc in tool_calls):
        return {"messages": [message]}

    tool_messages = []
    for tc in tool_calls:
        if tc["name"] != "SubmitFinalAnswer":
            content = f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call."
        else:
            # Still starts with "Error:" so routing is correct even if this reply is last.
            content = (
                "Error: SubmitFinalAnswer was ignored because another tool was "
                "called in the same turn. Call SubmitFinalAnswer on its own to "
                "submit the final answer."
            )
        tool_messages.append(ToolMessage(content=content, tool_call_id=tc["id"]))
    return {"messages": [message] + tool_messages}
workflow.add_node("query_gen", query_gen_node)
# Add a node where a model checks the query before it is executed.
workflow.add_node("correct_query", model_check_query)
# Add a node that executes the query; errors come back as ToolMessages via the fallback.
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
# Conditional edge out of query_gen: continue the workflow or end it.
def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
    """Choose the next node after ``query_gen``.

    - The last message has tool calls: END. ``query_gen_node`` appends an error
      ToolMessage after any wrong tool call, so an AI message with tool calls
      is only last when it submitted the final answer.
    - The last message text starts with ``"Error:"``: back to ``query_gen``.
    - Otherwise the text is a candidate query: go to ``correct_query``.
    """
    latest = state["messages"][-1]
    if getattr(latest, "tool_calls", None):
        return END
    return "query_gen" if latest.content.startswith("Error:") else "correct_query"
# Specify the edges between nodes. Fixed path:
# START -> first_tool_call -> list_tables_tool -> model_get_schema
#       -> get_schema_tool -> query_gen
workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")
# From query_gen, should_continue picks END, correct_query, or query_gen again.
workflow.add_conditional_edges(
    "query_gen",
    should_continue,
)
# Checked query -> execute -> back to query_gen to read the results.
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")
# Compile the workflow into a runnable app.
app = workflow.compile()
可视化图形¶
from IPython.display import Image, display
from langchain_core.runnables.graph import MermaidDrawMethod
# Render the compiled graph's Mermaid diagram to PNG bytes and show it inline.
graph_png = app.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
display(Image(graph_png))
运行代理¶
# Run the agent end to end. The run only ends on a SubmitFinalAnswer tool call,
# whose "final_answer" argument holds the answer text.
messages = app.invoke(
    {"messages": [("user", "Which sales agent made the most in sales in 2009?")]}
)
final_call = messages["messages"][-1].tool_calls[0]
json_str = final_call["args"]["final_answer"]
json_str

# Stream the same question to watch each node's state update as it happens.
for event in app.stream(
    {"messages": [("user", "Which sales agent made the most in sales in 2009?")]}
):
    print(event)
{'first_tool_call': {'messages': [AIMessage(content='', tool_calls=[{'name': 'sql_db_list_tables', 'args': {}, 'id': 'tool_abcd123', 'type': 'tool_call'}])]}}
{'list_tables_tool': {'messages': [ToolMessage(content='Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track', name='sql_db_list_tables', tool_call_id='tool_abcd123')]}}
{'model_get_schema': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_z1tyC7cEAawi5oIQn731Uknp', 'function': {'arguments': '{"table_names":"Employee, Invoice"}', 'name': 'sql_db_schema'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 18, 'prompt_tokens': 177, 'total_tokens': 195}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_a2ff031fb5', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-c91a5aad-fc05-4881-87f9-0662d703c3c8-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'Employee, Invoice'}, 'id': 'call_z1tyC7cEAawi5oIQn731Uknp', 'type': 'tool_call'}], usage_metadata={'input_tokens': 177, 'output_tokens': 18, 'total_tokens': 195})]}}
{'get_schema_tool': {'messages': [ToolMessage(content='\nCREATE TABLE "Employee" (\n\t"EmployeeId" INTEGER NOT NULL, \n\t"LastName" NVARCHAR(20) NOT NULL, \n\t"FirstName" NVARCHAR(20) NOT NULL, \n\t"Title" NVARCHAR(30), \n\t"ReportsTo" INTEGER, \n\t"BirthDate" DATETIME, \n\t"HireDate" DATETIME, \n\t"Address" NVARCHAR(70), \n\t"City" NVARCHAR(40), \n\t"State" NVARCHAR(40), \n\t"Country" NVARCHAR(40), \n\t"PostalCode" NVARCHAR(10), \n\t"Phone" NVARCHAR(24), \n\t"Fax" NVARCHAR(24), \n\t"Email" NVARCHAR(60), \n\tPRIMARY KEY ("EmployeeId"), \n\tFOREIGN KEY("ReportsTo") REFERENCES "Employee" ("EmployeeId")\n)\n\n/*\n3 rows from Employee table:\nEmployeeId\tLastName\tFirstName\tTitle\tReportsTo\tBirthDate\tHireDate\tAddress\tCity\tState\tCountry\tPostalCode\tPhone\tFax\tEmail\n1\tAdams\tAndrew\tGeneral Manager\tNone\t1962-02-18 00:00:00\t2002-08-14 00:00:00\t11120 Jasper Ave NW\tEdmonton\tAB\tCanada\tT5K 2N1\t+1 (780) 428-9482\t+1 (780) 428-3457\tandrew@chinookcorp.com\n2\tEdwards\tNancy\tSales Manager\t1\t1958-12-08 00:00:00\t2002-05-01 00:00:00\t825 8 Ave SW\tCalgary\tAB\tCanada\tT2P 2T3\t+1 (403) 262-3443\t+1 (403) 262-3322\tnancy@chinookcorp.com\n3\tPeacock\tJane\tSales Support Agent\t2\t1973-08-29 00:00:00\t2002-04-01 00:00:00\t1111 6 Ave SW\tCalgary\tAB\tCanada\tT2P 5M5\t+1 (403) 262-3443\t+1 (403) 262-6712\tjane@chinookcorp.com\n*/\n\n\nCREATE TABLE "Invoice" (\n\t"InvoiceId" INTEGER NOT NULL, \n\t"CustomerId" INTEGER NOT NULL, \n\t"InvoiceDate" DATETIME NOT NULL, \n\t"BillingAddress" NVARCHAR(70), \n\t"BillingCity" NVARCHAR(40), \n\t"BillingState" NVARCHAR(40), \n\t"BillingCountry" NVARCHAR(40), \n\t"BillingPostalCode" NVARCHAR(10), \n\t"Total" NUMERIC(10, 2) NOT NULL, \n\tPRIMARY KEY ("InvoiceId"), \n\tFOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")\n)\n\n/*\n3 rows from Invoice table:\nInvoiceId\tCustomerId\tInvoiceDate\tBillingAddress\tBillingCity\tBillingState\tBillingCountry\tBillingPostalCode\tTotal\n1\t2\t2009-01-01 
00:00:00\tTheodor-Heuss-Straße 34\tStuttgart\tNone\tGermany\t70174\t1.98\n2\t4\t2009-01-02 00:00:00\tUllevålsveien 14\tOslo\tNone\tNorway\t0171\t3.96\n3\t8\t2009-01-03 00:00:00\tGrétrystraat 63\tBrussels\tNone\tBelgium\t1000\t5.94\n*/', name='sql_db_schema', tool_call_id='call_z1tyC7cEAawi5oIQn731Uknp')]}}
{'query_gen': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_ErWLktUfxKsHGNGr74m72yYD', 'function': {'arguments': '{"table_names":"Customer"}', 'name': 'sql_db_schema'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 16, 'prompt_tokens': 1179, 'total_tokens': 1195}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_a2ff031fb5', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-19e02169-5e1e-40d9-90a2-384336ca5069-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': 'Customer'}, 'id': 'call_ErWLktUfxKsHGNGr74m72yYD', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1179, 'output_tokens': 16, 'total_tokens': 1195}), ToolMessage(content='Error: The wrong tool was called: sql_db_schema. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.', id='de5d25f5-b891-4e47-8282-d04dc9b93e9e', tool_call_id='call_ErWLktUfxKsHGNGr74m72yYD')]}}
{'query_gen': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_TFaA52SbhgEqm3ElEAd4HCsn', 'function': {'arguments': '{"table_names":["Customer"]}', 'name': 'sql_db_schema'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 17, 'prompt_tokens': 1245, 'total_tokens': 1262}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_a2ff031fb5', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-2c5f800f-43dc-4224-847b-49b5079efd2a-0', tool_calls=[{'name': 'sql_db_schema', 'args': {'table_names': ['Customer']}, 'id': 'call_TFaA52SbhgEqm3ElEAd4HCsn', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1245, 'output_tokens': 17, 'total_tokens': 1262}), ToolMessage(content='Error: The wrong tool was called: sql_db_schema. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.', id='6c962a35-fc24-4f27-86f0-6ec05256d478', tool_call_id='call_TFaA52SbhgEqm3ElEAd4HCsn')]}}
{'query_gen': {'messages': [AIMessage(content="To determine which sales agent made the most in sales in 2009, we need to join the `Invoice`, `Customer`, and `Employee` tables. Here is the query to find the top sales agent:\n\n\`\`\`sql\nSELECT e.FirstName, e.LastName, SUM(i.Total) as TotalSales\nFROM Invoice i\nJOIN Customer c ON i.CustomerId = c.CustomerId\nJOIN Employee e ON c.SupportRepId = e.EmployeeId\nWHERE strftime('%Y', i.InvoiceDate) = '2009'\nGROUP BY e.EmployeeId\nORDER BY TotalSales DESC\nLIMIT 1;\n\`\`\`", additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 125, 'prompt_tokens': 1312, 'total_tokens': 1437}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_3aa7262c27', 'finish_reason': 'stop', 'logprobs': None}, id='run-6cacd10d-d3aa-49ae-b9d7-8cc209fc4ccc-0', usage_metadata={'input_tokens': 1312, 'output_tokens': 125, 'total_tokens': 1437})]}}
{'correct_query': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_FwCE2c7WORU7lKHdSWqMv0ON', 'function': {'arguments': '{"query":"SELECT e.FirstName, e.LastName, SUM(i.Total) as TotalSales\\nFROM Invoice i\\nJOIN Customer c ON i.CustomerId = c.CustomerId\\nJOIN Employee e ON c.SupportRepId = e.EmployeeId\\nWHERE strftime(\'%Y\', i.InvoiceDate) = \'2009\'\\nGROUP BY e.EmployeeId\\nORDER BY TotalSales DESC\\nLIMIT 1;"}', 'name': 'db_query_tool'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 90, 'prompt_tokens': 337, 'total_tokens': 427}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_a2ff031fb5', 'finish_reason': 'stop', 'logprobs': None}, id='run-71067e75-80f6-4356-8239-518e466b3526-0', tool_calls=[{'name': 'db_query_tool', 'args': {'query': "SELECT e.FirstName, e.LastName, SUM(i.Total) as TotalSales\nFROM Invoice i\nJOIN Customer c ON i.CustomerId = c.CustomerId\nJOIN Employee e ON c.SupportRepId = e.EmployeeId\nWHERE strftime('%Y', i.InvoiceDate) = '2009'\nGROUP BY e.EmployeeId\nORDER BY TotalSales DESC\nLIMIT 1;"}, 'id': 'call_FwCE2c7WORU7lKHdSWqMv0ON', 'type': 'tool_call'}], usage_metadata={'input_tokens': 337, 'output_tokens': 90, 'total_tokens': 427})]}}
{'execute_query': {'messages': [ToolMessage(content="[('Steve', 'Johnson', 164.34)]", name='db_query_tool', tool_call_id='call_FwCE2c7WORU7lKHdSWqMv0ON')]}}
{'query_gen': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_fHJ4lvdiFM9HY6gupE6vLZV4', 'function': {'arguments': '{"final_answer":"The sales agent who made the most in sales in 2009 is Steve Johnson with total sales of 164.34."}', 'name': 'SubmitFinalAnswer'}, 'type': 'function'}], 'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 41, 'prompt_tokens': 1553, 'total_tokens': 1594}, 'model_name': 'gpt-4o-2024-05-13', 'system_fingerprint': 'fp_cb7cc8e106', 'finish_reason': 'tool_calls', 'logprobs': None}, id='run-2ec7bf3a-2a16-47bd-aa9c-b7d6dc531c1b-0', tool_calls=[{'name': 'SubmitFinalAnswer', 'args': {'final_answer': 'The sales agent who made the most in sales in 2009 is Steve Johnson with total sales of 164.34.'}, 'id': 'call_fHJ4lvdiFM9HY6gupE6vLZV4', 'type': 'tool_call'}], usage_metadata={'input_tokens': 1553, 'output_tokens': 41, 'total_tokens': 1594})]}}
评估¶
现在,我们可以评估这个代理!我们之前在 LangSmith 评估教程中定义了一个简单的 SQL 代理,并评估了它对数据库中 5 个问题的响应。我们可以在相同的数据集上将这个代理与之前的代理进行比较。代理评估可以关注以下三个方面:
- 响应:输入是一个提示和一组工具。输出是代理的响应。
- 单个工具:与之前一样,输入是一个提示和一组工具。输出是工具调用。
- 轨迹:与之前一样,输入是一个提示和一组工具。输出是工具调用的列表。
响应¶
我们将把代理的端到端响应与参考答案进行对比评估。让我们在相同的数据集上运行响应评估。
import json
def predict_sql_agent_answer(example: dict):
    """Evaluation target for the response eval: question in, answer text out.

    Runs the compiled agent on ``example["input"]`` and returns the
    ``final_answer`` argument of the closing ``SubmitFinalAnswer`` call under
    the ``"response"`` key, which ``answer_evaluator`` reads.
    """
    final_state = app.invoke({"messages": ("user", example["input"])})
    submit_args = final_state["messages"][-1].tool_calls[0]["args"]
    return {"response": submit_args["final_answer"]}
from langchain import hub
from langchain_openai import ChatOpenAI
# Grading prompt (pulled from the LangChain hub).
grade_prompt_answer_accuracy = prompt = hub.pull("langchain-ai/rag-answer-vs-reference")


def answer_evaluator(run, example) -> dict:
    """Score the SQL agent's answer against the dataset's reference answer.

    The question, the reference answer and the agent's answer go through the
    hub grading prompt pulled above into a gpt-4-turbo grader at temperature 0.

    NOTE(review): ``verdict["Score"]`` assumes the hub prompt yields a dict with
    a ``"Score"`` key (structured output) — confirm against the hub prompt.

    Returns:
        LangSmith feedback dict with key ``"answer_v_reference_score"``.
    """
    # Pull out the question, the ground truth and the agent's prediction.
    question = example.inputs["input"]
    expected = example.outputs["output"]
    predicted = run.outputs["response"]

    grader = grade_prompt_answer_accuracy | ChatOpenAI(model="gpt-4-turbo", temperature=0)
    verdict = grader.invoke(
        {
            "question": question,
            "correct_answer": expected,
            "student_answer": predicted,
        }
    )
    return {"key": "answer_v_reference_score", "score": verdict["Score"]}
from langsmith.evaluation import evaluate
dataset_name = "SQL Agent Response"
try:
    experiment_results = evaluate(
        predict_sql_agent_answer,
        data=dataset_name,
        evaluators=[answer_evaluator],
        num_repetitions=3,
        experiment_prefix="sql-agent-multi-step-response-v-reference",
        metadata={"version": "Chinook, gpt-4o multi-step-agent"},
    )
except Exception as exc:
    # Catch Exception rather than everything, so Ctrl-C (KeyboardInterrupt)
    # still stops the cell. Missing LangSmith setup is the expected cause, but
    # show the actual exception so unrelated failures are not hidden.
    print(f"Please setup LangSmith ({exc!r})")
摘要指标(数据集请参见这里):
- 这里的 multi-step 代理在性能上超过了之前定义的基础案例 SQL 代理。
轨迹¶
让我们在这个相同的数据集上运行 轨迹评估。
# The tools we expect the agent to call, in order.
expected_trajectory = [
    "sql_db_list_tables",  # 1st: list-tables tool node
    "sql_db_schema",  # 2nd: get-schema tool node
    "db_query_tool",  # 3rd: execute-query node
    "SubmitFinalAnswer",  # 4th: query_gen submits the final answer
]


def predict_sql_agent_messages(example: dict):
    """Evaluation target for the trajectory evals: return the agent's full final state.

    The whole state (not just the answer) is returned under ``"response"`` so
    the evaluators can walk every message's tool calls.
    """
    final_state = app.invoke({"messages": ("user", example["input"])})
    return {"response": final_state}
from langsmith.schemas import Example, Run
def find_tool_calls(messages):
    """Return the names of all tool calls in the returned state, in message order.

    Args:
        messages: Agent state dict with a ``"messages"`` list. Messages without
            a ``tool_calls`` attribute contribute nothing.
    """
    names = []
    for message in messages["messages"]:
        for call in getattr(message, "tool_calls", []):
            names.append(call["name"])
    return names
def contains_all_tool_calls_in_order_exact_match(
    root_run: Run, example: Example
) -> dict:
    """Score 1 only if the tool-call sequence equals the expected one exactly.

    Order must match and no extra tool calls are allowed. The observed calls
    are printed for debugging.

    Returns:
        LangSmith feedback dict with key ``"multi_tool_call_in_exact_order"``
        and an integer score of 0 or 1.
    """
    wanted = [
        "sql_db_list_tables",
        "sql_db_schema",
        "db_query_tool",
        "SubmitFinalAnswer",
    ]
    observed = find_tool_calls(root_run.outputs["response"])

    # Debug output: show which tools were actually called.
    print("Here are my tool calls:")
    print(observed)

    return {"score": int(observed == wanted), "key": "multi_tool_call_in_exact_order"}
def contains_all_tool_calls_in_order(root_run: Run, example: Example) -> dict:
    """Score 1 if the expected tools appear in order, with extra calls allowed between.

    This is a subsequence check against the module-level ``expected_trajectory``.
    The observed calls are printed for debugging.

    Returns:
        LangSmith feedback dict with key ``"multi_tool_call_in_order"`` and an
        integer score of 0 or 1.
    """
    observed = find_tool_calls(root_run.outputs["response"])

    # Debug output: show which tools were actually called.
    print("Here are my tool calls:")
    print(observed)

    # Each `in` test consumes the shared iterator up to and including the
    # match, so every expected tool must appear after the previous one.
    remaining = iter(observed)
    in_order = all(name in remaining for name in expected_trajectory)
    return {"score": 1 if in_order else 0, "key": "multi_tool_call_in_order"}
try:
    experiment_results = evaluate(
        predict_sql_agent_messages,
        data=dataset_name,
        evaluators=[
            contains_all_tool_calls_in_order,
            contains_all_tool_calls_in_order_exact_match,
        ],
        num_repetitions=3,
        experiment_prefix="sql-agent-multi-step-tool-calling-trajecory-in-order",
        metadata={"version": "Chinook, gpt-4o multi-step-agent"},
    )
except Exception as exc:
    # Catch Exception rather than everything, so Ctrl-C (KeyboardInterrupt)
    # still stops the cell. Missing LangSmith setup is the expected cause, but
    # show the actual exception so unrelated failures are not hidden.
    print(f"Please setup LangSmith ({exc!r})")
汇总得分显示我们从未准确地按顺序调用工具:
查看日志,我们可以看到一些有趣的内容 -
我们似乎在大多数运行中向我们的轨迹中注入了一个虚构的工具调用 sql_db_query。
这就是为什么 multi_tool_call_in_exact_order 失败,但 multi_tool_call_in_order 仍然通过的原因。
我们将在后续的教程中探索使用 LangGraph 解决这个问题的方法!