
AgentOptimizer: An Agentic Way to Train Your LLM Agent


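AutoGen offers conversable agents powered by LLMs, tools, or humans, which can perform tasks collectively through automated chat. The framework supports tool use and human participation via multi-agent conversation. Please find documentation about this feature here.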

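In a traditional machine learning pipeline, we train a model by updating its parameters according to the loss on the training set. In the era of LLM agents, how should we train an agent? Here, we take an initial step toward agent training. Inspired by the function-calling capability provided by OpenAI, we draw an analogy between model parameters and agent functions/skills, and update an agent's functions/skills based on its historical performance on the training set. Because this is an agentic way of training, it can enhance an agent's abilities without access to the LLM's parameters.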

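In this notebook, we introduce a new class, 'AgentOptimizer', which improves the function list of an assistant/user-proxy agent pair based on their historical conversations. This helps the agents get better at solving problems of the same type as earlier tasks. Specifically, given a set of training data, AgentOptimizer iteratively prompts an LLM to optimize the existing function lists of the AssistantAgent and UserProxyAgent, supplying code implementations when necessary. It also includes two strategies, roll-back and early-stop, to streamline the training process. In the example scenario, we test the proposed AgentOptimizer on problems from the MATH dataset.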
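
The roll-back and early-stop strategies can be pictured with a small, self-contained loop. The sketch below is a simplification under our own assumptions, not AgentOptimizer's actual implementation: a function list is reduced to a list of names, and propose/evaluate are hypothetical stand-ins for the LLM-driven update and for replaying the training set. An update is kept only if it raises the score; otherwise it is rolled back, and training stops after patience consecutive rejected updates. The scores are made-up toy values, not results.

```python
from typing import Callable, List, Tuple

# Hypothetical representation: a function list is just a list of function names.
FunctionList = List[str]


def train_with_rollback(
    initial: FunctionList,
    propose: Callable[[FunctionList], FunctionList],
    evaluate: Callable[[FunctionList], float],
    max_epochs: int = 10,
    patience: int = 2,
) -> Tuple[FunctionList, float]:
    """Iteratively update a function list, keeping an update only if it improves the score."""
    best, best_score = list(initial), evaluate(initial)
    epochs_without_improvement = 0
    for _ in range(max_epochs):
        candidate = propose(best)
        score = evaluate(candidate)
        if score > best_score:
            best, best_score = candidate, score
            epochs_without_improvement = 0
        else:
            # Roll back: discard the candidate and keep the previous best list.
            epochs_without_improvement += 1
            # Early stop: quit after `patience` consecutive updates that did not help.
            if epochs_without_improvement >= patience:
                break
    return best, best_score


# Toy stand-ins for the LLM proposer and for scoring on a training set (hypothetical values).
POOL = ["evaluate_expression", "solve_linear_system", "bad_helper"]
TOY_SCORES = {
    (): 0.6,
    ("evaluate_expression",): 0.8,
    ("evaluate_expression", "solve_linear_system"): 0.9,
}


def toy_propose(funcs: FunctionList) -> FunctionList:
    # Add the first pool function that is not yet in the list.
    for name in POOL:
        if name not in funcs:
            return funcs + [name]
    return list(funcs)


def toy_evaluate(funcs: FunctionList) -> float:
    # Unknown function lists get a low default score.
    return TOY_SCORES.get(tuple(funcs), 0.5)


best, score = train_with_rollback([], toy_propose, toy_evaluate)
print(best, score)  # ['evaluate_expression', 'solve_linear_system'] 0.9
```

Because the toy proposer is deterministic, it proposes the same rejected list twice and early stopping fires; an LLM-based optimizer would usually propose something different on each attempt.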

[Figure: AgentOptimizer]

More information can be found in the paper.

Authors:
- Shaokun Zhang, Ph.D. student at The Pennsylvania State University
- Jieyu Zhang, Ph.D. student at the University of Washington

```python
import copy
import json
import os
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

from openai import BadRequestError

import autogen
from autogen import config_list_from_json
from autogen.agentchat import Agent
from autogen.agentchat.contrib.agent_optimizer import AgentOptimizer
from autogen.agentchat.contrib.math_user_proxy_agent import MathUserProxyAgent
from autogen.code_utils import extract_code
from autogen.math_utils import get_answer
```

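MathUserProxy with function_call

This agent is a customized MathUserProxy that inherits from its parent class.

It supports using both function_call and Python to solve math problems, and it checks the final \boxed{} answer against the ground truth to record whether the problem was solved.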

```python
def is_termination_msg_mathchat(message):
    """Check if a message is a termination message."""
    if isinstance(message, dict):
        message = message.get("content")
        if message is None:
            return False
    cb = extract_code(message)
    contain_code = False
    for c in cb:
        if c[0] == "python":
            contain_code = True
            break
    if message.rstrip().find("TERMINATE") >= 0:
        return True
    return not contain_code and get_answer(message) is not None and get_answer(message) != ""
```
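
The termination rule above reads: stop on "TERMINATE", or when the message carries a non-empty boxed answer and no Python code block. The following stdlib-only sketch restates that rule so it can be tried without AutoGen. Here last_boxed is a hypothetical, simplified stand-in for autogen.math_utils.get_answer, and the regex is an assumed approximation of how extract_code detects Python blocks.

```python
import re
from typing import Optional

FENCE = "`" * 3  # built programmatically so this snippet nests safely in Markdown
PY_BLOCK = re.compile(FENCE + r"python\n.*?" + FENCE, re.DOTALL)


def last_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...}, handling nested braces."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    depth = 1
    for j in range(i, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[i:j]
    return None  # unbalanced braces: treat as no answer


def is_termination(message: str) -> bool:
    """Terminate on TERMINATE, or on a non-empty boxed answer in a message without Python code."""
    if "TERMINATE" in message:
        return True
    has_code = PY_BLOCK.search(message) is not None
    answer = last_boxed(message)
    return not has_code and answer is not None and answer != ""
```

Requiring "no code" matters: a message that both proposes code and guesses a boxed answer should still be executed rather than treated as final.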


````python
class MathUserProxyAgent(MathUserProxyAgent):
    MAX_CONSECUTIVE_AUTO_REPLY = 15
    DEFAULT_REPLY = "Continue. Please keep solving the problem until you need to query. (If you get to the answer, put it in \\boxed{}.)"
    PROMPTS = """Let's solve a math problem.
Query requirements:
You should always use the 'print' function for the output and use fractions/radical forms instead of decimals.
You can use packages like sympy to help you.
You must follow the formats below to write your code:
```python
# your code
```
If some packages are missing, you could also suggest a code to install the corresponding package.

Please follow this process:
1. Solve the problem step by step (do not over-divide the steps).
2. Take out any queries that can be asked through Python code (for example, any calculations or equations that can be calculated) and functions you know in the context of this conversation.

Please
(1) do not mix suggested Python codes and function calls in one step.
(2) You MUST remember that you don’t have a function named "python" available.

You must follow the formats below to write your Python code:
```python
# your code
```

3. Wait for me to give the results or wait for the executed results of the function call.
4. Continue if you think the result is correct. If the result is invalid or unexpected, please correct your query or reasoning.

After all the queries are run and you get the answer, put the answer in \\boxed{}.

Problem:
"""

    def __init__(
        self,
        name: Optional[str] = "MathChatAgent",
        is_termination_msg: Optional[Callable[[Dict], bool]] = is_termination_msg_mathchat,
        human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "NEVER",
        default_auto_reply: Optional[Union[str, Dict, None]] = DEFAULT_REPLY,
        max_invalid_q_per_step=3,
        **kwargs,
    ):
        super().__init__(
            name=name,
            is_termination_msg=is_termination_msg,
            human_input_mode=human_input_mode,
            default_auto_reply=default_auto_reply,
            max_invalid_q_per_step=max_invalid_q_per_step,
            **kwargs,
        )
        # Replace two of the reply functions registered by the parent class with this
        # class's versions, and register the answer checker at position 0 so it is tried first.
        del self._reply_func_list[2]
        self.register_reply([Agent, None], MathUserProxyAgent._generate_math_reply, position=4)
        del self._reply_func_list[3]
        self.register_reply(
            trigger=autogen.ConversableAgent, reply_func=MathUserProxyAgent.generate_function_call_reply, position=3
        )
        self.register_reply(
            trigger=autogen.ConversableAgent, reply_func=MathUserProxyAgent._check_final_result, position=0
        )

        self.max_function_call_trial = 3  # remaining retries for a failing function call
        self.query = None
        self._answer = None  # ground-truth answer of the current problem
        self.is_correct = None  # set to 1 or 0 once a final answer is checked

    def generate_function_call_reply(
        self,
        messages: Optional[List[Dict]] = None,
        sender: Optional[autogen.ConversableAgent] = None,
        config: Optional[Any] = None,
    ) -> Tuple[bool, Union[Dict, None]]:
        """Generate a reply using function call."""
        if messages is None:
            messages = self._oai_messages[sender]
        message = messages[-1]
        if "function_call" in message:
            is_exec_success, func_return = self.execute_function(message["function_call"])
            if is_exec_success:
                self.max_function_call_trial = 3
                return True, func_return
            else:
                if self.max_function_call_trial == 0:
                    error_message = func_return["content"]
                    self.max_function_call_trial = 3
                    return (
                        True,
                        "The func is executed failed many times. "
                        + error_message
                        + ". Please directly reply me with TERMINATE. We need to terminate the conversation.",
                    )
                else:
                    # Use up one retry so that repeated failures eventually reach the limit above.
                    self.max_function_call_trial -= 1
                    revise_prompt = "You may make a wrong function call (It may due the arguments you provided doesn't fit the function arguments like missing required positional argument). \
If you think this error occurs due to you make a wrong function arguments input and you could make it success, please try to call this function again using the correct arguments. \
Otherwise, the error may be caused by the function itself. Please directly reply me with TERMINATE. We need to terminate the conversation. "
                    error_message = func_return["content"]
                    return True, "The func is executed failed." + error_message + revise_prompt
        return False, None

    def initiate_chat(
        self,
        recipient,
        answer: Any,
        silent: Optional[bool] = False,
        **context,
    ):
        self.query = context["problem"]
        # Normalize the ground truth to a string, dropping a trailing ".0" from numeric answers.
        if not isinstance(answer, str):
            answer = str(answer)
            if answer.endswith(".0"):
                answer = answer[:-2]
            self._answer = answer
        else:
            self._answer = answer

        self.is_correct = None

        self._prepare_chat(recipient, True)
        error_message = None
        try:
            prompt = self.PROMPTS + context["problem"]
            self.send(prompt, recipient, silent=silent)
        except BadRequestError as e:
            error_message = str(e)
            self.is_correct = 0
            print("error information: {}".format(error_message))

        recipient.reset()
        # If no final answer was ever checked, count the problem as unsolved (0) rather than None.
        is_correct = copy.deepcopy(self.is_correct) if self.is_correct is not None else 0
        self._reset()
        return is_correct

    def _check_final_result(
        self,
        messages: Optional[List[Dict]] = None,
        sender: Optional[autogen.Agent] = None,
        config: Optional[Any] = None,
    ):
        if messages is None:
            messages = self._oai_messages[sender]
        messages = messages[-1]
        if isinstance(messages, dict):
            messages = messages.get("content")
            if messages is None:
                return False, None

        cb = extract_code(messages)
        contain_code = False
        for c in cb:
            if c[0] == "python":
                contain_code = True
                break
        # Only grade messages that contain a boxed answer and no code still to be executed.
        if not contain_code and get_answer(messages) is not None and get_answer(messages) != "":
            if get_answer(messages) == self._answer:
                self.is_correct = 1
                return True, "The result is Correct. Please reply me with TERMINATE."
            else:
                self.is_correct = 0
                return False, None
        else:
            return False, None

    def _reset(self):
        super()._reset()
        self.max_function_call_trial = 3
        self.is_correct = None
        self.query = None
        self._answer = None
````
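
Grading in this agent is an exact string comparison between the extracted boxed answer and the normalized ground truth. The sketch below isolates that logic without AutoGen; normalize_ground_truth and grade are hypothetical helper names that mirror initiate_chat and _check_final_result. It makes one consequence visible: mathematically equal answers written differently (for example \frac{1}{2} versus 0.5) are graded as wrong, which is presumably one reason the prompt asks for fractions and radicals instead of decimals.

```python
from typing import Any, Optional


def normalize_ground_truth(answer: Any) -> str:
    """Mirror initiate_chat: stringify non-strings and drop a trailing '.0'."""
    if not isinstance(answer, str):
        answer = str(answer)
        if answer.endswith(".0"):
            answer = answer[:-2]
    return answer


def grade(model_answer: Optional[str], ground_truth: Any) -> Optional[int]:
    """Return 1 or 0 like _check_final_result, or None when there is no boxed answer yet."""
    if model_answer is None or model_answer == "":
        return None
    return 1 if model_answer == normalize_ground_truth(ground_truth) else 0


print(grade("3", 3.0))              # 1: 3.0 is normalized to "3"
print(grade("\\frac{1}{2}", 0.5))   # 0: equal value, different string
```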

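Load dataset

The MATH dataset contains 12,500 challenging competition mathematics problems. Each problem in MATH has a full step-by-step solution, which can be used to teach models to generate answer derivations and explanations.

We strictly follow the train/test split used by Craft. Please specify your own dataset path. Here we take the first 10 algebra problems from each split as examples.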

```python
test_data, train_data = [], []
with open("MATH/dataset/algebra.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        test_data.append(json.loads(line))
with open("MATH/dataset/train/algebra.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        train_data.append(json.loads(line))
test_data, train_data = test_data[0:10], train_data[0:10]
```
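
Each JSONL line is one problem, and the rest of the notebook reads its "question" and "answer" fields. The sketch below is an assumed variant of the loading step, not the notebook's code: load_problems is a hypothetical helper that stops reading once it has enough records, skips blank lines, and fails early if a record lacks either field. It accepts any text stream, so it is shown here with an in-memory io.StringIO.

```python
import io
import json
from typing import IO, Dict, List


def load_problems(fp: IO[str], limit: int = 10) -> List[Dict]:
    """Read up to `limit` JSONL records and check the fields used later in the notebook."""
    records: List[Dict] = []
    for line in fp:
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)
        missing = [key for key in ("question", "answer") if key not in record]
        if missing:
            raise ValueError(f"record {len(records)} is missing fields: {missing}")
        records.append(record)
        if len(records) == limit:
            break  # no need to parse the rest of the file
    return records


sample = io.StringIO('{"question": "What is 1+1?", "answer": 2}\n\n{"question": "Solve x+1=3.", "answer": 2}\n')
problems = load_problems(sample, limit=10)
print(len(problems))  # 2
```

With a real file, the same call would be wrapped in open(path, "r", encoding="utf-8") as in the cell above.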

Agents construction

Construct the MathUserProxyAgent and AssistantAgent used to solve these problems. Here, we use gpt-4-1106-preview to build the AssistantAgent.

```python
llm_config = {
    "config_list": [
        {
            "model": "gpt-4-1106-preview",
            "api_type": "azure",
            "api_key": os.environ["AZURE_OPENAI_API_KEY"],
            "base_url": "https://ENDPOINT.openai.azure.com/",  # replace ENDPOINT with your Azure resource
            "api_version": "2023-07-01-preview",
        }
    ]
}

assistant = autogen.AssistantAgent(
    name="assistant",
    system_message="You are a helpful assistant.",
    llm_config=llm_config,
)
user_proxy = MathUserProxyAgent(
    name="mathproxyagent",
    human_input_mode="NEVER",
    code_execution_config={"work_dir": "_output", "use_docker": False},
)
```
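
The configuration above is Azure-specific. If you switch between Azure OpenAI and the OpenAI API, a small helper can build the config_list from environment variables. This is a hypothetical helper, not part of AutoGen: the variable name AZURE_OPENAI_ENDPOINT is our assumption, and the plain OpenAI entry assumes that "model" plus "api_key" is enough, which is the usual minimal AutoGen config entry.

```python
from typing import Dict, List, Mapping


def build_config_list(env: Mapping[str, str], model: str = "gpt-4-1106-preview") -> List[Dict[str, str]]:
    """Hypothetical helper: prefer Azure OpenAI when its key is set, otherwise fall back to OpenAI."""
    if "AZURE_OPENAI_API_KEY" in env:
        return [
            {
                "model": model,
                "api_type": "azure",
                "api_key": env["AZURE_OPENAI_API_KEY"],
                # AZURE_OPENAI_ENDPOINT is an assumed variable name; the notebook hard-codes the URL.
                "base_url": env.get("AZURE_OPENAI_ENDPOINT", "https://ENDPOINT.openai.azure.com/"),
                "api_version": "2023-07-01-preview",
            }
        ]
    if "OPENAI_API_KEY" in env:
        return [{"model": model, "api_key": env["OPENAI_API_KEY"]}]
    raise KeyError("Set AZURE_OPENAI_API_KEY or OPENAI_API_KEY before building llm_config.")
```

Usage: llm_config = {"config_list": build_config_list(os.environ)}. Passing the environment in as a mapping keeps the helper easy to test with a plain dict.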

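Test without agent optimizations

Below is the code to measure performance without the agent optimization process.

In this case, the AssistantAgent and MathUserProxyAgent have no function calls and solve the problems with Python only.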

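```python
correct_count = 0
for index, query in enumerate(test_data):
    is_correct = user_proxy.initiate_chat(recipient=assistant, answer=query["answer"], problem=query["question"])
    print(is_correct)
    # Count only correct answers; 0 or None (no answer checked) counts as a failure.
    correct_count += 1 if is_correct else 0
success_rate_without_agent_training = correct_count / len(test_data)
```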
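
The success rate is the fraction of test problems whose final answer was graded correct. A minimal pure-Python restatement, using a hypothetical success_rate helper, shows why a missing result should count as a failure rather than break the sum, and how coarse the metric is: with 10 test problems, each problem moves the rate by 10 percentage points.

```python
from typing import Optional, Sequence


def success_rate(results: Sequence[Optional[int]]) -> float:
    """Fraction of problems solved; None (no answer checked) counts as a failure."""
    if not results:
        raise ValueError("no results to score")
    return sum(1 for r in results if r == 1) / len(results)


print(success_rate([1, 0, None, 1]))  # 0.5
print(success_rate([1] * 6 + [0] * 4))  # 0.6
```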

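Agent Training

Then, we use AgentOptimizer to iteratively optimize the agents, refining their function calls according to the historical conversations and their performance. In each iteration, AgentOptimizer produces register_for_llm (function signature updates, applied to the assistant via update_function_signature) and register_for_executor (function implementations, registered with user_proxy via register_function). Here we optimize the two agents for ten epochs.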

```python
EPOCH = 10
optimizer_model = "gpt-4-1106-preview"
optimizer = AgentOptimizer(max_actions_per_step=3, llm_config=llm_config, optimizer_model=optimizer_model)
for i in range(EPOCH):
    # Solve every training problem and record each conversation with its outcome.
    for index, query in enumerate(train_data):
        is_correct = user_proxy.initiate_chat(assistant, answer=query["answer"], problem=query["question"])
        history = assistant.chat_messages_for_summary(user_proxy)
        optimizer.record_one_conversation(history, is_satisfied=bool(is_correct))
    # Ask the optimizer for updated signatures (for the LLM) and implementations (for execution).
    register_for_llm, register_for_executor = optimizer.step()
    for item in register_for_llm:
        assistant.update_function_signature(**item)
    if len(register_for_executor.keys()) > 0:
        user_proxy.register_function(function_map=register_for_executor)
```
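
The loop consumes two kinds of updates per step: keyword-argument dicts for update_function_signature and a name-to-callable map for register_function, with at most max_actions_per_step changes. The sketch below shows how add, revise, and remove actions could be turned into those two structures. It is an illustration under stated assumptions, not AgentOptimizer's internals: the action format, apply_actions, make_sig, and the example function names are all hypothetical, and we assume the signature updates use the keyword names func_sig and is_remove.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical action format, e.g. {"op": "add", "sig": {...}, "impl": callable}
# or {"op": "remove", "name": "..."}.
Action = Dict[str, Any]


def apply_actions(
    functions: Dict[str, Dict[str, Any]],
    actions: List[Action],
    max_actions_per_step: int = 3,
) -> Tuple[List[Dict[str, Any]], Dict[str, Callable[..., Any]]]:
    """Apply at most max_actions_per_step actions; return LLM-side and executor-side updates."""
    register_for_llm: List[Dict[str, Any]] = []
    register_for_executor: Dict[str, Callable[..., Any]] = {}
    for action in actions[:max_actions_per_step]:
        op = action["op"]
        if op in ("add", "revise"):
            sig = action["sig"]
            functions[sig["name"]] = sig  # a revision overwrites the old signature
            register_for_llm.append({"func_sig": sig, "is_remove": False})
            register_for_executor[sig["name"]] = action["impl"]
        elif op == "remove":
            functions.pop(action["name"], None)
            register_for_llm.append({"func_sig": action["name"], "is_remove": True})
        else:
            raise ValueError(f"unknown action: {op}")
    return register_for_llm, register_for_executor


def make_sig(name: str, description: str) -> Dict[str, Any]:
    # OpenAI-style function schema with two numeric parameters.
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": {"a": {"type": "number"}, "b": {"type": "number"}},
            "required": ["a", "b"],
        },
    }


functions = {"old_helper": make_sig("old_helper", "Deprecated helper.")}
llm_updates, executor_updates = apply_actions(
    functions,
    [
        {"op": "add", "sig": make_sig("add_numbers", "Add two numbers."), "impl": lambda a, b: a + b},
        {"op": "remove", "name": "old_helper"},
        {"op": "add", "sig": make_sig("mul_numbers", "Multiply two numbers."), "impl": lambda a, b: a * b},
        {"op": "add", "sig": make_sig("sub_numbers", "Subtract."), "impl": lambda a, b: a - b},  # over the cap
    ],
)
print(sorted(functions))  # ['add_numbers', 'mul_numbers']
```

Capping the number of actions per step keeps each update small, which makes a bad step cheaper to roll back.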

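Test with agent optimizations

After agent optimization, the agents have obtained a list of functions from AgentOptimizer over 10 optimization iterations, as shown below.

We then compare the final performance with and without the agent optimization process. On these 10 test problems, the optimized agents solve 9 problems instead of 6, a success rate of 90% versus 60%.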

```python
correct_count = 0
for index, query in enumerate(test_data):
    is_correct = user_proxy.initiate_chat(recipient=assistant, answer=query["answer"], problem=query["question"])
    correct_count += 1 if is_correct else 0
success_rate_with_agent_training = correct_count / len(test_data)
print(
    "------------------------------------------------Functions learned------------------------------------------------"
)
for func in assistant.llm_config["functions"]:
    print(func["name"] + ": " + func["description"] + "\n")
print("------------------------------------------------Summary------------------------------------------------\n")
print("success_rate_without_agent_training: {average}%\n".format(average=success_rate_without_agent_training * 100))
print("success_rate_with_agent_training: {average}%\n".format(average=success_rate_with_agent_training * 100))
```
```
------------------------------------------------Functions learned------------------------------------------------
evaluate_expression: Evaluate arithmetic or mathematical expressions provided as strings.

calculate_compound_interest_principal: Calculate the principal amount needed to achieve a certain future value with quarterly compound interest.

solve_linear_system: Solve a system of linear equations represented as coefficients and variables.

------------------------------------------------Summary------------------------------------------------

success_rate_without_agent_training: 60.0%

success_rate_with_agent_training: 90.0%
```
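
The learned functions are only listed by name and description; their generated code is not shown. As an illustration of how a function like evaluate_expression could be executed when the assistant calls it, the sketch below pairs a hypothetical exact-arithmetic implementation (using ast and fractions, rejecting anything other than numbers and + - * / **) with a dispatcher that parses the JSON arguments of a function_call. The (success, message) return shape is modelled loosely on what generate_function_call_reply reads via func_return["content"]; none of this is AutoGen's own executor.

```python
import ast
import json
import operator
from fractions import Fraction
from typing import Any, Callable, Dict, Tuple

# Only these operators are allowed; names, calls, and attributes are rejected.
_BIN_OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul, ast.Div: operator.truediv}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _eval(node: ast.AST) -> Fraction:
    if isinstance(node, ast.Expression):
        return _eval(node.body)
    if isinstance(node, ast.Constant) and type(node.value) in (int, float):
        return Fraction(str(node.value))  # "0.1" becomes exactly 1/10
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Pow):
        base, exponent = _eval(node.left), _eval(node.right)
        if exponent.denominator != 1 or abs(exponent.numerator) > 1000:
            raise ValueError("only small integer exponents are supported")
        return base ** exponent.numerator
    if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
        return _BIN_OPS[type(node.op)](_eval(node.left), _eval(node.right))
    if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
        return _UNARY_OPS[type(node.op)](_eval(node.operand))
    raise ValueError(f"unsupported expression: {ast.dump(node)}")


def evaluate_expression(expression: str) -> str:
    """Hypothetical implementation of the learned evaluate_expression (exact fractions)."""
    return str(_eval(ast.parse(expression, mode="eval")))


def execute_function_call(
    function_call: Dict[str, str], function_map: Dict[str, Callable[..., Any]]
) -> Tuple[bool, Dict[str, str]]:
    """Look up the function, parse its JSON arguments, and return (success, message)."""
    name = function_call.get("name", "")
    func = function_map.get(name)
    if func is None:
        return False, {"name": name, "role": "function", "content": f"Error: Function {name} not found."}
    try:
        arguments = json.loads(function_call.get("arguments", "{}"))
        return True, {"name": name, "role": "function", "content": str(func(**arguments))}
    except Exception as exc:  # report the failure back to the LLM instead of crashing
        return False, {"name": name, "role": "function", "content": f"Error: {exc}"}


FUNCTION_MAP = {"evaluate_expression": evaluate_expression}
ok, reply = execute_function_call(
    {"name": "evaluate_expression", "arguments": json.dumps({"expression": "1/3 + 1/6"})}, FUNCTION_MAP
)
print(ok, reply["content"])  # True 1/2
```

Returning fractions rather than floats matches the prompt's request for fractions instead of decimals, which also helps the exact-string answer check.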