Skip to content

使用RAG和自我纠正进行代码生成

AlphaCodium提出了一种使用控制流进行代码生成的方法。

主要思想: 通过迭代方式构建编程问题的答案。

AlphaCodium 借助公开测试和 AI 生成的测试,迭代地测试并改进某个编程问题的答案。

我们将使用 LangGraph 从头实现其中的部分想法:

  1. 我们从用户指定的一组文档开始
  2. 我们使用长上下文的LLM来处理文档,并执行RAG以根据文档回答问题
  3. 我们将调用一个工具来生成结构化输出
  4. 在将解决方案返回给用户之前,我们将执行两个单元测试(检查导入和代码执行)

Screenshot 2024-05-23 at 2.17.42 PM.png

设置

首先,让我们安装所需的包,并设置我们需要的API密钥。

! pip install -U langchain_community langchain-openai langchain-anthropic langchain langgraph bs4
import getpass
import os


def _set_env(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass.getpass(f"{var}: ")


# Prompt for the API keys used by the OpenAI and Anthropic chat models below
# (no prompt when a key is already set in the environment).
_set_env("OPENAI_API_KEY")
_set_env("ANTHROPIC_API_KEY")

为LangGraph开发设置LangSmith

注册LangSmith,以快速发现问题并提高您的LangGraph项目的性能。LangSmith允许您使用追踪数据来调试、测试和监控您基于LangGraph构建的LLM应用——有关如何开始的更多信息,请阅读这里

文档

加载LangChain 表达式语言 (LCEL) 文档作为示例。

from bs4 import BeautifulSoup as Soup
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

# LCEL documentation: crawl the docs starting from this page and keep only the
# visible text of each HTML page.
url = "https://python.langchain.com/docs/concepts/lcel/"
loader = RecursiveUrlLoader(
    url=url,
    max_depth=20,
    extractor=lambda html: Soup(html, "html.parser").text,
)
docs = loader.load()

# Order the pages by source URL (sorted, then reversed) and join their text
# into one large context string for the prompt.
d_sorted = sorted(docs, key=lambda doc: doc.metadata["source"])
d_reversed = d_sorted[::-1]
concatenated_content = "\n\n\n --- \n\n\n".join(doc.page_content for doc in d_reversed)
API Reference: RecursiveUrlLoader

LLMs

代码解决方案

首先,我们将尝试 OpenAI 和 Claude3 的函数调用。

我们将创建一个 code_gen_chain,并使用 OpenAI 或 Claude 在这里进行测试。

在LangChain中使用Pydantic

本笔记本使用Pydantic v2 BaseModel,这需要 langchain-core >= 0.3。使用 langchain-core < 0.3 会导致因混合使用Pydantic v1和v2 BaseModels 而出现错误。

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field

# ## OpenAI

# Code-generation prompt: the system message embeds the full LCEL docs
# (``{context}``) and asks for a description, the imports, and the code block.
code_gen_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are a coding assistant with expertise in LCEL, LangChain expression language. \n 
    Here is a full set of LCEL documentation:  \n ------- \n  {context} \n ------- \n Answer the user 
    question based on the above provided documentation. Ensure any code you provide can be executed \n 
    with all required imports and variables defined. Structure your answer with a description of the code solution. \n
    Then list the imports. And finally list the functioning code block. Here is the user question:""",
        ),
        # Expands to the conversation passed in under the "messages" key.
        ("placeholder", "{messages}"),
    ]
)


# Data model for a structured code solution.
# NOTE: with_structured_output() derives the tool schema from this class; the
# docstring and Field descriptions become the tool/field descriptions sent to
# the model, so edit them with care. The lowercase class name is the tool name
# that the Claude prompt refers to ("Invoke the code tool").
class code(BaseModel):
    """Schema for code solutions to questions about LCEL."""

    # Prose explanation shown before the code.
    prefix: str = Field(description="Description of the problem and approach")
    # Import statements only; code_check executes these on their own first.
    imports: str = Field(description="Code block import statements")
    # The code body, executed after the imports by code_check.
    code: str = Field(description="Code block not including import statements")


# OpenAI model driving the structured-output code generation chain.
expt_llm = "gpt-4o-mini"
llm = ChatOpenAI(model=expt_llm, temperature=0)
code_gen_chain_oai = code_gen_prompt | llm.with_structured_output(code)

# Smoke test: ask a single LCEL question with the whole docs as context.
question = "How do I build a RAG chain in LCEL?"
solution = code_gen_chain_oai.invoke(
    {"messages": [("user", question)], "context": concatenated_content}
)
solution
code(prefix='To build a Retrieval-Augmented Generation (RAG) chain in LCEL, you will need to set up a chain that combines a retriever and a language model (LLM). The retriever will fetch relevant documents based on a query, and the LLM will generate a response using the retrieved documents as context. Here’s how you can do it:', imports='from langchain_core.prompts import ChatPromptTemplate\nfrom langchain_openai import ChatOpenAI\nfrom langchain_core.output_parsers import StrOutputParser\nfrom langchain_core.retrievers import MyRetriever', code='# Define the retriever\nretriever = MyRetriever()  # Replace with your specific retriever implementation\n\n# Define the LLM model\nmodel = ChatOpenAI(model="gpt-4")\n\n# Create a prompt template for the LLM\nprompt_template = ChatPromptTemplate.from_template("Given the following documents, answer the question: {question}\nDocuments: {documents}")\n\n# Create the RAG chain\nrag_chain = prompt_template | retriever | model | StrOutputParser()\n\n# Example usage\nquery = "What are the benefits of using RAG?"\nresponse = rag_chain.invoke({"question": query})\nprint(response)')
API Reference: ChatPromptTemplate | ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate

# ## Anthropic

# Prompt that pushes Claude to answer via the ``code`` tool (structured output).
code_gen_prompt_claude = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """<instructions> You are a coding assistant with expertise in LCEL, LangChain expression language. \n 
    Here is the LCEL documentation:  \n ------- \n  {context} \n ------- \n Answer the user  question based on the \n 
    above provided documentation. Ensure any code you provide can be executed with all required imports and variables \n
    defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n
    Invoke the code tool to structure the output correctly. </instructions> \n Here is the user question:""",
        ),
        # Expands to the conversation passed in under the "messages" key.
        ("placeholder", "{messages}"),
    ]
)


# Large language model: Claude 3 Opus, sent with the tools beta header.
expt_llm = "claude-3-opus-20240229"
llm = ChatAnthropic(
    default_headers={"anthropic-beta": "tools-2024-04-04"},
    model=expt_llm,
)

# include_raw=True makes each result a dict with "raw", "parsed" and
# "parsing_error" instead of raising on a bad tool call.
structured_llm_claude = llm.with_structured_output(code, include_raw=True)


# Optional: check for tool-use errors, in case the tool call is unreliable.
def check_claude_output(tool_output):
    """Validate an ``include_raw=True`` structured-output result.

    Args:
        tool_output (dict): Result with "raw", "parsed" and "parsing_error".

    Returns:
        dict: ``tool_output`` itself, unchanged, when it is usable.

    Raises:
        ValueError: If parsing failed, or if nothing was parsed (the tool was
            not invoked).
    """
    parse_error = tool_output["parsing_error"]
    if parse_error:
        # Surface both the raw model output and the parser's complaint.
        print("Parsing error!")
        raw_text = str(tool_output["raw"].content)
        raise ValueError(
            f"Error parsing your output! Be sure to invoke the tool. Output: {raw_text}. \n Parse error: {parse_error}"
        )

    if not tool_output["parsed"]:
        print("Failed to invoke tool!")
        raise ValueError(
            "You did not use the provided tool! Be sure to invoke the tool to structure the output."
        )

    return tool_output


# Claude chain with output checking: prompt -> structured LLM -> validator,
# which raises ValueError when the tool was not used or parsing failed.
code_chain_claude_raw = (
    code_gen_prompt_claude
    | structured_llm_claude
    | check_claude_output
)


def insert_errors(inputs):
    """Append the previous attempt's error to the conversation for a retry.

    Used as the first step of the fallback chain, where
    ``with_fallbacks(..., exception_key="error")`` supplies the exception
    raised by the previous attempt under ``inputs["error"]``.

    Args:
        inputs (dict): Chain input with "messages", "context" and "error".

    Returns:
        dict: "messages" (a NEW list: the input messages plus a retry
            instruction) and "context", ready for the prompt template.
    """
    error = inputs["error"]
    # Build a new list instead of ``+=``: the input list is shared with the
    # caller (e.g. the graph state passed by ``generate``), which must not
    # silently accumulate retry messages. Each retry therefore sees the
    # original conversation plus the latest error.
    messages = [
        *inputs["messages"],
        (
            "assistant",
            f"Retry. You are required to fix the parsing errors: {error} \n\n You must invoke the provided tool.",
        ),
    ]
    return {
        "messages": messages,
        "context": inputs["context"],
    }


# Fallback: re-run the checked Claude chain with the previous error inserted
# into the conversation. with_fallbacks hands the raised exception to the
# fallback under the "error" key, which insert_errors reads.
fallback_chain = insert_errors | code_chain_claude_raw
N = 3  # maximum number of retries
code_gen_chain_re_try = code_chain_claude_raw.with_fallbacks(
    fallbacks=[fallback_chain for _ in range(N)],
    exception_key="error",
)


def parse_output(solution):
    """Extract the parsed model from an ``include_raw=True`` structured-output result.

    With ``include_raw=True`` the structured-output step returns a dict with
    the keys "raw", "parsed" and "parsing_error"; only "parsed" is kept.
    """
    parsed = solution["parsed"]
    return parsed


# Optional: chain with retries that corrects failures to invoke the tool.
code_gen_chain = code_gen_chain_re_try | parse_output

# Without retries. NOTE: this rebinding replaces the retry variant above, so
# everything after this cell uses the no-retry chain. Drop this line to keep
# the retries.
code_gen_chain = code_gen_prompt_claude | structured_llm_claude | parse_output

# Test
question = "How do I build a RAG chain in LCEL?"
solution = code_gen_chain.invoke(
    {"messages": [("user", question)], "context": concatenated_content}
)
solution
code(prefix="To build a RAG (Retrieval Augmented Generation) chain in LCEL, you can use a retriever to fetch relevant documents and then pass those documents to a chat model to generate a response based on the retrieved context. Here's an example of how to do this:", imports='from langchain_expressions import retrieve, chat_completion', code='question = "What is the capital of France?"\n\nrelevant_docs = retrieve(question)\n\nresult = chat_completion(\n    model=\'openai-gpt35\', \n    messages=[\n        {{{"role": "system", "content": "Answer the question based on the retrieved context.}}},\n        {{{"role": "user", "content": \'\'\'\n            Context: {relevant_docs}\n            Question: {question}\n        \'\'\'}}\n    ]\n)\n\nprint(result)')

状态

我们的状态是一个字典,包含与代码生成相关的键:error(错误标志)、messages(用户问题及对话消息)、generation(生成的代码方案)和 iterations(尝试次数)。

from typing import List
from typing_extensions import TypedDict


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : String flag for control flow. code_check sets it to "yes" when
            a test failed and "no" when both passed; the initial input uses "".
        messages : Conversation list with the user question, error messages,
            and the model's answers
        generation : Code solution object (with prefix / imports / code
            attributes) produced by the generate node
        iterations : Number of generation attempts so far
    """

    error: str
    messages: List
    # NOTE(review): annotated as str, but the nodes store the structured code
    # solution object here (generate returns code_solution as "generation").
    generation: str
    iterations: int

图

我们的图实现了上方示意图中所示的逻辑流程。

# ## Parameters

# Max tries
max_iterations = 3
# Reflect: set flag = 'reflect' to route retries through the reflect node
# flag = 'reflect'
flag = "do not reflect"

# ## Nodes


def generate(state: GraphState):
    """
    Generate a code solution.

    Args:
        state (dict): The current graph state.

    Returns:
        state (dict): Updates for "generation" (the new solution), "messages"
            (conversation with the solution appended) and "iterations"
            (incremented by one).

    Raises:
        ValueError: If the generation chain returns no structured output.
    """

    print("---GENERATING CODE SOLUTION---")

    # Work on a copy so the incoming graph state is not mutated in place; the
    # extended list is returned as this node's update instead.
    messages = list(state["messages"])
    iterations = state["iterations"]
    error = state["error"]

    # We were routed back here after a failed code check: ask for a new attempt.
    if error == "yes":
        messages.append(
            (
                "user",
                "Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:",
            )
        )

    # Solution
    code_solution = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    # The include_raw=True chain can yield parsed=None (tool not invoked or
    # output not parseable) instead of raising; fail with a clear message
    # rather than an AttributeError on ``.prefix`` below.
    if code_solution is None:
        raise ValueError(
            "Code generation returned no structured output "
            "(the model did not invoke the tool or its output failed to parse)."
        )
    messages.append(
        (
            "assistant",
            f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        )
    )

    return {
        "generation": code_solution,
        "messages": messages,
        "iterations": iterations + 1,
    }


def code_check(state: GraphState):
    """
    Check the generated code by executing it.

    Runs two tests on ``state["generation"]``: (1) execute the import block
    alone, then (2) execute the imports followed by the code block.

    Args:
        state (dict): The current graph state.

    Returns:
        state (dict): Updates for "generation", "messages" (with a failure
            message appended when a test fails) and "error" ("yes" on failure,
            "no" when both tests pass). "iterations" is passed through.
    """

    print("---CHECKING CODE---")

    # Copy so the incoming graph state is not mutated in place.
    messages = list(state["messages"])
    code_solution = state["generation"]
    iterations = state["iterations"]

    # Get solution components
    imports = code_solution.imports
    code_body = code_solution.code

    # SECURITY NOTE: exec() runs model-generated code with full interpreter
    # privileges; only run this graph in an environment where that is acceptable.
    #
    # Each test gets a fresh dict used as BOTH globals and locals, i.e. normal
    # module scoping. With exec()'s defaults inside a function, globals and
    # locals differ: functions defined by the snippet could not see the
    # snippet's own imports (spurious NameError), and the snippet could read
    # this notebook's globals (masking missing definitions). "__name__" is
    # provided so ``if __name__ == "__main__":`` guards still run.

    # Check imports
    try:
        exec(imports, {"__name__": "__main__"})
    except Exception as e:
        print("---CODE IMPORT CHECK: FAILED---")
        messages.append(("user", f"Your solution failed the import test: {e}"))
        return {
            "generation": code_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "yes",
        }

    # Check execution
    try:
        exec(imports + "\n" + code_body, {"__name__": "__main__"})
    except Exception as e:
        print("---CODE BLOCK CHECK: FAILED---")
        messages.append(("user", f"Your solution failed the code execution test: {e}"))
        return {
            "generation": code_solution,
            "messages": messages,
            "iterations": iterations,
            "error": "yes",
        }

    # No errors
    print("---NO CODE TEST FAILURES---")
    return {
        "generation": code_solution,
        "messages": messages,
        "iterations": iterations,
        "error": "no",
    }


def reflect(state: GraphState):
    """
    Reflect on errors.

    Args:
        state (dict): The current graph state.

    Returns:
        state (dict): Updates for "messages" (reflection appended); "generation"
            and "iterations" are passed through unchanged.
    """

    # Was a copy-paste of generate's "GENERATING CODE SOLUTION" message, which
    # made the trace indistinguishable from a generation step.
    print("---REFLECTING ON ERRORS---")

    # Copy so the incoming graph state is not mutated in place.
    messages = list(state["messages"])
    iterations = state["iterations"]
    code_solution = state["generation"]

    # Ask the model again with the conversation so far, which includes the
    # failed-test messages appended by code_check.
    # NOTE(review): this reuses code_gen_chain, which returns a structured code
    # solution rather than free-form reflection; its string form is appended.
    reflections = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": messages}
    )
    messages.append(("assistant", f"Here are reflections on the error: {reflections}"))
    return {"generation": code_solution, "messages": messages, "iterations": iterations}


# ## Edges


def decide_to_finish(state: GraphState):
    """
    Determine whether to finish.

    Args:
        state (dict): The current graph state.

    Returns:
        str: Next node to call: "end", "reflect" or "generate".
    """
    error = state["error"]
    iterations = state["iterations"]

    # ``>=`` rather than ``==``: a run that starts at (or somehow passes) the
    # limit still stops instead of looping until the graph's recursion limit.
    if error == "no" or iterations >= max_iterations:
        print("---DECISION: FINISH---")
        return "end"

    print("---DECISION: RE-TRY SOLUTION---")
    return "reflect" if flag == "reflect" else "generate"
from langgraph.graph import END, StateGraph, START

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("generate", generate)  # generate a solution
workflow.add_node("check_code", code_check)  # check the code
workflow.add_node("reflect", reflect)  # reflect on errors

# Build graph: START -> generate -> check_code, then decide_to_finish picks
# END, a direct retry ("generate"), or "reflect" (which feeds back to generate).
workflow.add_edge(START, "generate")
workflow.add_edge("generate", "check_code")
workflow.add_conditional_edges(
    "check_code",
    decide_to_finish,
    {
        "end": END,
        "reflect": "reflect",
        "generate": "generate",
    },
)
workflow.add_edge("reflect", "generate")
app = workflow.compile()
API Reference: END | StateGraph | START

question = (
    "How can I directly pass a string to a runnable and use it to construct the"
    " input needed for my prompt?"
)
# Run the whole graph once: the user question, zero attempts, no error yet.
solution = app.invoke(
    {"messages": [("user", question)], "error": "", "iterations": 0}
)
---GENERATING CODE SOLUTION---
---CHECKING CODE---
---CODE IMPORT CHECK: FAILED---
---DECISION: RE-TRY SOLUTION---
---GENERATING CODE SOLUTION---
---CHECKING CODE---
---CODE IMPORT CHECK: FAILED---
---DECISION: RE-TRY SOLUTION---
---GENERATING CODE SOLUTION---
---CHECKING CODE---
---CODE BLOCK CHECK: FAILED---
---DECISION: FINISH---

# Display the final code solution object from the graph run.
solution["generation"]
code(prefix='To directly pass a string to a runnable and use it to construct the input needed for a prompt, you can use the `_from_value` method on a PromptTemplate in LCEL. Create a PromptTemplate with the desired template string, then call `_from_value` on it with a dictionary mapping the input variable names to their values. This will return a PromptValue that you can pass directly to any chain or model that accepts a prompt input.', imports='from langchain_core.prompts import PromptTemplate', code='user_string = "langchain is awesome"\n\nprompt_template = PromptTemplate.from_template("Tell me more about how {user_input}.")\n\nprompt_value = prompt_template._from_value({"user_input": user_string})\n\n# Pass the PromptValue directly to a model or chain \nchain.run(prompt_value)')

评估

这里是一个关于 LCEL 问题的公共数据集。

我将其保存为 lcel-teacher-eval。

你也可以在这里找到对应的 CSV 文件。

import langsmith

client = langsmith.Client()
# Clone the public dataset into your tenant so it can be used for evaluation.
public_dataset = (
    "https://smith.langchain.com/public/326674a6-62bd-462d-88ae-eea49d503f9d/d"
)
try:
    client.clone_public_dataset(public_dataset)
except Exception as e:  # was a bare ``except:``, which also swallowed KeyboardInterrupt
    # Best-effort: keep the notebook going, but show the actual cause
    # (e.g. missing API key or a network error) instead of hiding it.
    print("Please setup LangSmith")
    print(f"Dataset clone failed: {e!r}")
Dataset(name='lcel-teacher-eval', description='Eval set for LCEL teacher', data_type=<DataType.kv: 'kv'>, id=UUID('8b57696d-14ea-4f00-9997-b3fc74a16846'), created_at=datetime.datetime(2024, 9, 16, 22, 50, 4, 169288, tzinfo=datetime.timezone.utc), modified_at=datetime.datetime(2024, 9, 16, 22, 50, 4, 169288, tzinfo=datetime.timezone.utc), example_count=0, session_count=0, last_session_start_time=None, inputs_schema=None, outputs_schema=None)

自定义评估器。

from langsmith.schemas import Example, Run


def check_import(run: Run, example: Example) -> dict:
    """LangSmith evaluator: score 1 if the predicted import block executes, else 0.

    ``example`` is unused; it is accepted to match the (run, example)
    evaluator signature.
    """
    # Tolerate runs that produced no outputs (e.g. the target raised).
    outputs = run.outputs or {}
    imports = outputs.get("imports")
    try:
        # SECURITY: executes model-generated code. A fresh dict serves as both
        # globals and locals (normal module scoping), isolating the snippet
        # from this notebook's globals; "__name__" keeps main-guards working.
        exec(imports, {"__name__": "__main__"})
    except Exception:
        return {"key": "import_check", "score": 0}
    return {"key": "import_check", "score": 1}


def check_execution(run: Run, example: Example) -> dict:
    """LangSmith evaluator: score 1 if imports + code execute without raising, else 0.

    ``example`` is unused; it is accepted to match the (run, example)
    evaluator signature.
    """
    # Tolerate runs that produced no outputs (e.g. the target raised).
    outputs = run.outputs or {}
    imports = outputs.get("imports")
    code_body = outputs.get("code")
    try:
        # SECURITY: executes model-generated code. A fresh dict serves as both
        # globals and locals, so functions defined by the snippet can see its
        # own imports and notebook globals cannot mask missing definitions.
        # A missing part (None) raises TypeError here and scores 0.
        exec(imports + "\n" + code_body, {"__name__": "__main__"})
    except Exception:
        return {"key": "code_execution_check", "score": 0}
    return {"key": "code_execution_check", "score": 1}

比较 LangGraph 与上下文填充(Context Stuffing)。

def predict_base_case(example: dict):
    """Context stuffing: one call to the generation chain (no graph, no retries)."""
    conversation = [("user", example["question"])]
    result = code_gen_chain.invoke(
        {"context": concatenated_content, "messages": conversation}
    )
    return {"imports": result.imports, "code": result.code}


def predict_langgraph(example: dict):
    """LangGraph: run the generate/check loop and return the final solution's parts."""
    initial_state = {
        "messages": [("user", example["question"])],
        "iterations": 0,
        "error": "",
    }
    final_state = app.invoke(initial_state)
    answer = final_state["generation"]
    return {"imports": answer.imports, "code": answer.code}
from langsmith.evaluation import evaluate

# Evaluators (name kept as-is, typo included, in case other cells use it)
code_evalulator = [check_import, check_execution]

# Dataset
dataset_name = "lcel-teacher-eval"

# Both runs are best-effort: a failure is reported rather than aborting the
# notebook. ``except Exception`` (instead of a bare ``except:``) still lets
# KeyboardInterrupt stop a long evaluation, and the real error is printed.

# Run the base case (context stuffing)
try:
    experiment_results_ = evaluate(
        predict_base_case,
        data=dataset_name,
        evaluators=code_evalulator,
        experiment_prefix=f"test-without-langgraph-{expt_llm}",
        max_concurrency=2,
        metadata={
            "llm": expt_llm,
        },
    )
except Exception as e:
    print("Please setup LangSmith")
    print(f"Base-case evaluation failed: {e!r}")
# Run with LangGraph
try:
    experiment_results = evaluate(
        predict_langgraph,
        data=dataset_name,
        evaluators=code_evalulator,
        experiment_prefix=f"test-with-langgraph-{expt_llm}-{flag}",
        max_concurrency=2,
        metadata={
            "llm": expt_llm,
            "feedback": flag,
        },
    )
except Exception as e:
    print("Please setup LangSmith")
    print(f"LangGraph evaluation failed: {e!r}")

结果:

  • LangGraph 优于基线方案:加入重试循环提升了性能
  • 反思没有帮助:与直接把错误传回给 LLM 相比,在重试前先进行反思反而导致效果退步
  • GPT-4 优于 Claude3:由于工具调用错误,Claude3 Opus 和 Haiku 分别失败了 3 次和 1 次

https://smith.langchain.com/public/78a3d858-c811-4e46-91cb-0f10ef56260b/d

优云智算