使用自定义模型
在使用GroupChatManager
时,我们需要在构造函数中传递一个GroupChat
对象,这是一个负责收集代理、从提示模板准备消息并选择发言人(最终使用here中描述的speaker_selection_method
)的数据类。
为此,GroupChat 内部初始化了两个 ConversableAgent 实例。为了控制 GroupChat 中实例化的代理所使用的模型客户端,GroupChatManager 已经接收了传入的 llm_config
,可以设置可选的 model_client_cls
属性。
示例
首先我们需要定义一个llm_config
并定义一些将参与群聊的代理:
from autogen import GroupChat, ConversableAgent, GroupChatManager, UserProxyAgent
from somewhere import MyModelClient
# Define the custom model configuration
llm_config = {
"config_list": [
{
"model": "gpt-3.5-turbo",
"model_client_cls": "MyModelClient"
}
]
}
# Initialize the agents with the custom model
agent1 = ConversableAgent(
name="Agent 1",
llm_config=llm_config
)
agent1.register_model_client(model_client_cls=MyModelClient)
agent2 = ConversableAgent(
name="Agent 2",
llm_config=llm_config
)
agent2.register_model_client(model_client_cls=MyModelClient)
agent3 = ConversableAgent(
name="Agent 2",
llm_config=llm_config
)
agent3.register_model_client(model_client_cls=MyModelClient)
user_proxy = UserProxyAgent(name="user", llm_config=llm_config, code_execution_config={"use_docker": False})
user_proxy.register_model_client(MyModelClient)
请注意,这里展示的代理定义是最小的,可能不适合您的需求。唯一目的是展示群聊场景的基本设置。
然后我们创建一个GroupChat
,如果我们希望GroupChat使用的底层代理能够使用我们自定义的客户端,我们将在model_client_cls
属性中传递它。
最后我们创建了一个GroupChatManager
的实例,并将配置传递给它。这个相同的配置将被转发到GroupChat,它(如果需要)将自动处理自定义模型的注册。
# Create a GroupChat instance and add the agents
group_chat = GroupChat(agents=[agent1, agent2, agent3], messages=[], model_client_cls=MyModelClient)
# Create the GroupChatManager with the GroupChat, UserProxy, and model configuration
chat_manager = GroupChatManager(groupchat=group_chat, llm_config=llm_config)
chat_manager.register_model_client(model_client_cls=MyModelClient)
# Initiate the chat using the UserProxy
user_proxy.initiate_chat(chat_manager, initial_message="Suggest me the most trending papers in microbiology that you think might interest me")
该属性可以是一个类或一组类,这些类遵循ModelClient
协议(有关定义自定义模型客户端类的更多信息,请参见此链接)。
需要注意的是,在使用Azure OpenAI、OpenAI或库原生支持的其他非自定义模型时,不必定义model_client_cls
。