跳到主要内容

使用自定义模型的AutoGen:让用户能够使用自己的推理机制

· 6 min read
Olga Vrousgou

TL;DR

AutoGen 现在支持自定义模型!此功能使用户能够定义和加载自己的模型,从而实现更灵活和个性化的推理机制。通过遵循特定协议,您可以集成自定义模型以与 AutoGen 一起使用,并通过使用您想要的任何模型/API 调用/硬编码响应来响应提示。

注意:根据您使用的模型,您可能需要调整Agent的默认提示

快速入门

一个交互且简单的方式是通过跟随这里的笔记本来开始,该笔记本从HuggingFace加载一个本地模型到AutoGen中并使用它进行推理,并对提供的类进行更改。

第一步:创建自定义模型客户端类

要开始在AutoGen中使用自定义模型,你需要创建一个遵循client.py中定义的ModelClient协议的模型客户端类。新的模型客户端类应实现这些方法:

  • create(): 返回一个实现了ModelClientResponseProtocol的响应对象(更多详情请参阅协议部分)。
  • message_retrieval(): 处理响应对象并返回一个字符串列表或消息对象列表(更多详细信息请参见协议部分)。
  • cost(): 返回响应的成本。
  • get_usage(): 返回一个字典,其键为RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]

例如,一个简单的虚拟自定义类:

class CustomModelClient:
def __init__(self, config, **kwargs):
print(f"CustomModelClient config: {config}")

def create(self, params):
num_of_responses = params.get("n", 1)

# can create my own data response class
# here using SimpleNamespace for simplicity
# as long as it adheres to the ModelClientResponseProtocol

response = SimpleNamespace()
response.choices = []
response.model = "model_name" # should match the OAI_CONFIG_LIST registration

for _ in range(num_of_responses):
text = "this is a dummy text response"
choice = SimpleNamespace()
choice.message = SimpleNamespace()
choice.message.content = text
choice.message.function_call = None
response.choices.append(choice)
return response

def message_retrieval(self, response):
choices = response.choices
return [choice.message.content for choice in choices]

def cost(self, response) -> float:
response.cost = 0
return 0

@staticmethod
def get_usage(response):
return {}

步骤 2:将配置添加到 OAI_CONFIG_LIST

必要的字段是将model_client_cls设置为新类的名称(作为字符串)"model_client_cls":"CustomModelClient"。其他所有字段将传递给类的构造函数,因此您可以完全控制指定哪些参数以及如何使用它们。例如:

{
"model": "Open-Orca/Mistral-7B-OpenOrca",
"model_client_cls": "CustomModelClient",
"device": "cuda",
"n": 1,
"params": {
"max_length": 1000,
}
}

第三步:将新的自定义模型注册到将使用它的代理

如果已将带有字段 "model_client_cls":"<class name>" 的配置添加到 Agent 的配置列表中,则必须在创建代理之后且在初始化对话之前注册具有所需类的相应模型:

my_agent.register_model_client(model_client_cls=CustomModelClient, [other args that will be forwarded to CustomModelClient constructor])

model_client_cls=CustomModelClient 参数与 OAI_CONFIG_LIST 中指定的参数匹配,CustomModelClient 是遵循 ModelClient 协议的类(有关该协议的更多详细信息如下)。

如果在初始化聊天时新模型客户端在配置列表中但未注册,则会引发错误。

协议详情

一个自定义模型类可以通过多种方式创建,但需要遵循在client.py中定义并如下所示的ModelClient协议和响应结构。

当前的响应协议使用了autogen代码库中与OpenAI响应结构匹配的最小必需字段。任何与OpenAI响应结构匹配的响应协议可能更适应未来的变化,但我们从最小要求开始,以使该功能的采用更加容易。


class ModelClient(Protocol):
"""
A client class must implement the following methods:
- create must return a response object that implements the ModelClientResponseProtocol
- cost must return the cost of the response
- get_usage must return a dict with the following keys:
- prompt_tokens
- completion_tokens
- total_tokens
- cost
- model

This class is used to create a client that can be used by OpenAIWrapper.
The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed.
The message_retrieval method must be implemented to return a list of str or a list of messages from the response.
"""

RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]

class ModelClientResponseProtocol(Protocol):
class Choice(Protocol):
class Message(Protocol):
content: Optional[str]

message: Message

choices: List[Choice]
model: str

def create(self, params) -> ModelClientResponseProtocol:
...

def message_retrieval(
self, response: ModelClientResponseProtocol
) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:
"""
Retrieve and return a list of strings or a list of Choice.Message from the response.

NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,
since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.
"""
...

def cost(self, response: ModelClientResponseProtocol) -> float:
...

@staticmethod
def get_usage(response: ModelClientResponseProtocol) -> Dict:
"""Return usage summary of the response using RESPONSE_USAGE_KEYS."""
...

故障排除步骤

如果某些功能无法正常工作,请按照检查清单进行操作:

  • Make sure you have followed the client protocol and client response protocol when creating the custom model class
    • create() 方法:在 create 调用期间返回推理响应时,必须遵循 ModelClientResponseProtocol
    • message_retrieval() 方法:返回一个字符串列表或消息对象列表。如果返回的是消息对象列表,它们目前必须包含 OpenAI 的 ChatCompletion Message 对象的字段,因为在当前代码库的其余部分中,除非使用自定义代理,否则这是函数或工具调用所期望的。
    • cost()方法:返回一个整数,如果你不关心成本跟踪,你可以直接返回0
    • get_usage(): 返回一个字典,如果不关心使用情况跟踪,可以只返回一个空字典 {}
  • 确保在OAI_CONFIG_LIST中有相应的条目,并且该条目包含"model_client_cls":"<custom-model-class-name>"字段。
  • 确保你已经使用相应的配置条目注册了客户端,以及你的新类 agent.register_model_client(model_client_cls=<class-of-custom-model>, [其他可选参数])
  • 确保在OAI_CONFIG_LIST中定义的所有自定义模型都已注册。
  • 任何其他故障排除可能需要在自定义代码本身中进行。

结论

通过使用自定义模型的能力,AutoGen现在为您的AI应用程序提供了更大的灵活性和功能。无论您是训练了自己的模型还是希望使用特定的预训练模型,AutoGen都能满足您的需求。编程愉快!