
Reflection Workflow for Structured Outputs

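This notebook walks through setting up a Workflow that provides reliable structured outputs through retries and reflection on mistakes.

This notebook works best with an open-source LLM, so we will use Ollama. If you don't already have Ollama running, visit https://ollama.com to get started and download the model you want to use. (In this case, we ran ollama pull llama3.1 before running this notebook.)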

!pip install -U llama-index llama-index-llms-ollama

Since workflows are async first, this all runs fine in a notebook. If you were running in your own code, you would want to use asyncio.run() to start an async event loop if one isn’t already running.

import asyncio


async def main():
    ...  # your async code goes here, e.g. awaiting a workflow run


if __name__ == "__main__":
    asyncio.run(main())

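To validate the structured output of an LLM, we need only two steps:

  1. Generate the structured output
  2. Validate that the output is proper JSON

The key thing here is that, if the output is invalid, we loop until it is valid, passing the error feedback to the next generation.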
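
The generate/validate loop described above can be sketched in plain Python before any workflow machinery is involved. This is a minimal sketch under stated assumptions: fake_llm is a hypothetical stand-in for a real model call (it answers with extra prose first, and with clean JSON once the prompt contains feedback), and extract_with_reflection is an illustrative name, not part of LlamaIndex.

```python
import json


def fake_llm(prompt: str) -> str:
    """Hypothetical stand-in for an LLM call (an assumption, not a real API)."""
    if "Try again" in prompt:
        return '{"cars": []}'
    return 'Sure! Here is the JSON: {"cars": []}'


def extract_with_reflection(passage: str, max_retries: int = 3) -> dict:
    feedback = ""  # empty on the first attempt
    for _ in range(max_retries):
        output = fake_llm(f"Extract JSON from: {passage}\n{feedback}")
        try:
            return json.loads(output)  # step 2: validate
        except json.JSONDecodeError as err:
            # Feed the bad output and the error into the next generation
            feedback = f"Previous output: {output}\nError: {err}\nTry again."
    raise RuntimeError("Max retries reached")


result = extract_with_reflection("I own a Fiat Panda.")
print(result)
```

The first attempt fails to parse, so the second prompt carries the error text; the retry cap keeps the loop from running forever if the model never produces valid JSON.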

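To handle these steps, we need to define a few events:

  1. An event to pass on the generated extraction
  2. An event to give feedback when the extraction is invalid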
The other steps will use the built-in StartEvent and StopEvent events.

from llama_index.core.workflow import Event


class ExtractionDone(Event):
    output: str
    passage: str


class ValidationErrorEvent(Event):
    error: str
    wrong_output: str
    passage: str

To prompt our model, let's define a pydantic model that we want to extract.

from pydantic import BaseModel


class Car(BaseModel):
    brand: str
    model: str
    power: int


class CarCollection(BaseModel):
    cars: list[Car]

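With our events defined, we can build our workflow and steps.

Note that the workflow automatically validates itself using type annotations, so the type annotations on our steps are very helpful!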

import json

from llama_index.core.workflow import (
    Workflow,
    StartEvent,
    StopEvent,
    Context,
    step,
)
from llama_index.llms.ollama import Ollama

EXTRACTION_PROMPT = """
Context information is below:
---------------------
{passage}
---------------------
Given the context information and not prior knowledge, create a JSON object from the information in the context.
The JSON object must follow the JSON schema:
{schema}
"""
REFLECTION_PROMPT = """
You already created this output previously:
---------------------
{wrong_answer}
---------------------
This caused the JSON decode error: {error}
Try again, the response must contain only valid JSON code. Do not add any sentence before or after the JSON object.
Do not repeat the schema.
"""


class ReflectionWorkflow(Workflow):
    max_retries: int = 3

    @step
    async def extract(
        self, ctx: Context, ev: StartEvent | ValidationErrorEvent
    ) -> StopEvent | ExtractionDone:
        # Count attempts in the shared context so the loop cannot run forever
        current_retries = await ctx.store.get("retries", default=0)
        if current_retries >= self.max_retries:
            return StopEvent(result="Max retries reached")
        else:
            await ctx.store.set("retries", current_retries + 1)

        if isinstance(ev, StartEvent):
            # First attempt: read the passage from the run() arguments
            passage = ev.get("passage")
            if not passage:
                return StopEvent(result="Please provide some text in input")
            reflection_prompt = ""
        elif isinstance(ev, ValidationErrorEvent):
            # Retry: show the model its previous output and the error
            passage = ev.passage
            reflection_prompt = REFLECTION_PROMPT.format(
                wrong_answer=ev.wrong_output, error=ev.error
            )

        llm = Ollama(
            model="llama3.1",  # the model pulled with ollama pull above
            request_timeout=30,
            # Manually set the context window to limit memory usage
            context_window=8000,
        )
        prompt = EXTRACTION_PROMPT.format(
            passage=passage,
            # Pydantic v2 API; schema_json() is deprecated there
            schema=json.dumps(CarCollection.model_json_schema()),
        )
        if reflection_prompt:
            prompt += reflection_prompt

        output = await llm.acomplete(prompt)

        return ExtractionDone(output=str(output), passage=passage)

    @step
    async def validate(
        self, ev: ExtractionDone
    ) -> StopEvent | ValidationErrorEvent:
        try:
            CarCollection.model_validate_json(ev.output)
        except Exception as e:
            print("Validation failed, retrying...")
            return ValidationErrorEvent(
                error=str(e), wrong_output=ev.output, passage=ev.passage
            )

        return StopEvent(result=ev.output)

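That's it! Let's explore the workflow we wrote a bit.

  • We have one entry point, extract (the step that accepts StartEvent)
  • When extract finishes, it emits an ExtractionDone event
  • validate runs and checks the extraction:
    • If it is OK, it emits a StopEvent and halts the workflow
    • If it is not, it returns a ValidationErrorEvent with information about the error
  • Any ValidationErrorEvent emitted triggers the loop, and extract runs again!
  • This continues until the structured output passes validation or max_retries is reached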
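
The event flow in the list above can be mimicked with a small dispatcher that routes each event to the step that accepts it. This is a rough, standard-library-only sketch: the dataclasses, the ROUTES table, and the fake model behaviour (chatty on the first call, clean JSON on a retry) are all assumptions made for illustration and are not how LlamaIndex is implemented.

```python
import json
from dataclasses import dataclass


@dataclass
class Start:
    passage: str


@dataclass
class ExtractionDone:
    output: str


@dataclass
class ValidationFailed:
    error: str


@dataclass
class Stop:
    result: str


def extract(ev):
    # Fake model: extra prose on the first call, clean JSON on a retry
    if isinstance(ev, Start):
        return ExtractionDone('Sure! {"cars": []}')
    return ExtractionDone('{"cars": []}')


def validate(ev):
    try:
        json.loads(ev.output)
    except json.JSONDecodeError as err:
        return ValidationFailed(str(err))
    return Stop(ev.output)


# Which step accepts which event type
ROUTES = {Start: extract, ValidationFailed: extract, ExtractionDone: validate}


def run(event):
    trace = []
    while not isinstance(event, Stop):
        step = ROUTES[type(event)]
        event = step(event)
        trace.append(f"{step.__name__} -> {type(event).__name__}")
    return event.result, trace


result, trace = run(Start("I own a Fiat Panda."))
print("\n".join(trace))
```

The trace shows extract and validate each running twice, with the failed validation sending control back to extract.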

NOTE: When working with loops, we need to keep an eye on runtime. Here, we set a timeout of 120 seconds.
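
As a generic illustration of why a time limit matters for loops, the sketch below uses plain asyncio rather than the workflow's own timeout handling. The retry_forever coroutine is a hypothetical loop that never succeeds, and the short 0.05-second limit is chosen only so the example finishes quickly.

```python
import asyncio


async def retry_forever() -> None:
    """Hypothetical loop whose output never validates."""
    while True:
        await asyncio.sleep(0.01)  # stands in for one generate/validate round


async def run_with_timeout(timeout: float) -> str:
    try:
        await asyncio.wait_for(retry_forever(), timeout=timeout)
    except asyncio.TimeoutError:
        return "timed out"
    return "finished"


# In a notebook, use `await run_with_timeout(0.05)` instead of asyncio.run
outcome = asyncio.run(run_with_timeout(0.05))
print(outcome)
```

Without the limit, this loop would never return; the retry counter in extract and the workflow timeout are two independent guards against the same problem.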

w = ReflectionWorkflow(timeout=120, verbose=True)

# Run the workflow
ret = await w.run(
    passage="I own two cars: a Fiat Panda with 45Hp and a Honda Civic with 330Hp."
)

Running step extract
Step extract produced event ExtractionDone
Running step validate
Validation failed, retrying...
Step validate produced event ValidationErrorEvent
Running step extract
Step extract produced event ExtractionDone
Running step validate
Step validate produced event StopEvent

print(ret)

{ "cars": [ { "brand": "Fiat", "model": "Panda", "power": 45 }, { "brand": "Honda", "model": "Civic", "power": 330 } ] }