From 84c450aef9f994faa921855d3ef2072d12dae356 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 27 Jul 2024 04:25:27 -0400 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=20Provider=20?= =?UTF-8?q?=E6=B3=A8=E5=86=8C=E6=8E=A5=E5=8F=A3=EF=BC=9B=E6=96=B0=E5=A2=9E?= =?UTF-8?q?=20provider=20=E6=8C=87=E4=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/bootstrap.py | 1 + astrbot/message/handler.py | 10 ++++++-- model/command/internal_handler.py | 31 ++++++++++++++++++++++++- model/command/manager.py | 3 ++- type/command.py | 10 ++++++++ type/message_event.py | 7 ++---- type/types.py | 9 +++++++ util/agent/func_call.py | 4 ++-- util/agent/web_searcher.py | 10 ++++---- util/t2i/strategies/network_strategy.py | 2 +- 10 files changed, 70 insertions(+), 17 deletions(-) diff --git a/astrbot/bootstrap.py b/astrbot/bootstrap.py index ef548dd39..30a2b5bf4 100644 --- a/astrbot/bootstrap.py +++ b/astrbot/bootstrap.py @@ -101,6 +101,7 @@ class AstrBotBootstrap(): self.openai_command_handler = OpenAIOfficialCommandHandler(self.command_manager) self.llm_instance = ProviderOpenAIOfficial(self.context) self.openai_command_handler.set_provider(self.llm_instance) + self.context.register_provider("internal_openai", self.llm_instance) logger.info("已启用 OpenAI API 支持。") def load_plugins(self): diff --git a/astrbot/message/handler.py b/astrbot/message/handler.py index 0ad9ac39f..59ac1d11f 100644 --- a/astrbot/message/handler.py +++ b/astrbot/message/handler.py @@ -115,6 +115,9 @@ class MessageHandler(): self.nicks = self.context.nick self.provider = provider self.reply_prefix = self.context.reply_prefix + + def set_provider(self, provider: Provider): + self.provider = provider async def handle(self, message: AstrMessageEvent, llm_provider: Provider = None) -> MessageResult: ''' @@ -148,7 +151,8 @@ class MessageHandler(): assert(isinstance(cmd_res, CommandResult)) return MessageResult( cmd_res.message_chain, - is_command_call=True + is_command_call=True, + use_t2i=cmd_res.is_use_t2i ) # check if the message is a llm-wake-up command @@ -178,7 +182,9 @@ class MessageHandler(): llm_result = await web_searcher.web_search(msg_plain, provider, message.session_id, inner_provider) else: llm_result = await provider.text_chat( - msg_plain, message.session_id, image_url + prompt=msg_plain, + session_id=message.session_id, + image_url=image_url ) except BaseException as e: logger.error(traceback.format_exc()) diff --git a/model/command/internal_handler.py b/model/command/internal_handler.py index 50f05527a..2fad5ef22 100644 --- a/model/command/internal_handler.py +++ b/model/command/internal_handler.py @@ -26,7 +26,36 @@ class InternalCommandHandler: self.manager.register("websearch", "网页搜索开关", 10, self.web_search) self.manager.register("t2i", "文本转图片开关", 10, self.t2i_toggle) self.manager.register("myid", "获取你在此平台上的ID", 10, self.myid) - + self.manager.register("provider", "查看和切换当前使用的 LLM 资源来源", 10, self.provider) + + def provider(self, message: AstrMessageEvent, context: Context): + if len(context.llms) == 0: + return CommandResult().message("当前没有加载任何 LLM 资源。") + + tokens = self.manager.command_parser.parse(message.message_str) + + if tokens.len == 1: + ret = "## 当前载入的 LLM 资源\n" + for idx, llm in enumerate(context.llms): + ret += f"{idx}. {llm.llm_name}" + if llm.origin: + ret += f" (来源: {llm.origin})" + if context.message_handler.provider == llm.llm_instance: + ret += " (当前使用)" + ret += "\n" + + ret += "\n使用 provider <序号> 切换 LLM 资源。" + return CommandResult().message(ret) + else: + try: + idx = int(tokens.get(1)) + if idx >= len(context.llms): + return CommandResult().message("provider: 无效的序号。") + context.message_handler.set_provider(context.llms[idx].llm_instance) + return CommandResult().message(f"已经成功切换到 LLM 资源 {context.llms[idx].llm_name}。") + except BaseException as e: + return CommandResult().message("provider: 参数错误。") + def set_nick(self, message: AstrMessageEvent, context: Context): message_str = message.message_str if message.role != "admin": diff --git a/model/command/manager.py b/model/command/manager.py index 4dcfcbbb5..f79677089 100644 --- a/model/command/manager.py +++ b/model/command/manager.py @@ -73,7 +73,8 @@ class CommandManager(): if message_str.startswith(command): logger.info(f"触发 {command} 指令。") command_result = await self.execute_handler(command, message_event, context) - return command_result + if command_result.hit: + return command_result async def execute_handler(self, command: str, diff --git a/type/command.py b/type/command.py index 5411c2991..a8723063f 100644 --- a/type/command.py +++ b/type/command.py @@ -24,6 +24,7 @@ class CommandResult(): self.success = success self.message_chain = message_chain self.command_name = command_name + self.is_use_t2i = None # default def message(self, message: str): ''' @@ -61,6 +62,15 @@ class CommandResult(): ''' self.message_chain = [Image.fromFileSystem(path), ] return self + + # def use_t2i(self, use_t2i: bool): + # ''' + # 设置是否使用文本转图片服务。如果不设置,则跟随用户的设置。 + + # CommandResult().use_t2i(False) + # ''' + # self.is_use_t2i = use_t2i + # return self def _result_tuple(self): return (self.success, self.message_chain, self.command_name) diff --git a/type/message_event.py b/type/message_event.py index 12d3931c8..222ac91e0 100644 --- a/type/message_event.py +++ b/type/message_event.py @@ -43,13 +43,10 @@ class AstrMessageEvent(): context, session_id) return ame - - - - - + @dataclass class MessageResult(): result_message: Union[str, list] is_command_call: Optional[bool] = False + use_t2i: Optional[bool] = None # None 为跟随用户设置 callback: Optional[callable] = None diff --git a/type/types.py b/type/types.py index 4acf0fc3a..e34515a4d 100644 --- a/type/types.py +++ b/type/types.py @@ -9,6 +9,7 @@ from util.updator.astrbot_updator import AstrBotUpdator from util.image_uploader import ImageUploader from util.updator.plugin_updator import PluginUpdator from model.plugin.command import PluginCommandBridge +from model.provider.provider import Provider class Context: @@ -68,6 +69,14 @@ class Context: ''' task = asyncio.create_task(coro, name=task_name) self.ext_tasks.append(task) + + def register_provider(self, llm_name: str, provider: Provider, origin: str = ''): + ''' + 注册一个提供 LLM 资源的 Provider。 + + `provider`: Provider 对象。即你的实现需要继承 Provider 类。至少应该实现 text_chat() 方法。 + ''' + self.llms.append(RegisteredLLM(llm_name, provider, origin)) def find_platform(self, platform_name: str) -> RegisteredPlatform: for platform in self.platforms: diff --git a/util/agent/func_call.py b/util/agent/func_call.py index a642c0181..ffacf242b 100644 --- a/util/agent/func_call.py +++ b/util/agent/func_call.py @@ -122,7 +122,7 @@ class FuncCall(): _c = 0 while _c < 3: try: - res = self.provider.text_chat(prompt, session_id) + res = self.provider.text_chat(prompt=prompt, session_id=session_id) if res.find('```') != -1: res = res[res.find('```json') + 7: res.rfind('```')] gu.log("REVGPT func_call json result", @@ -187,7 +187,7 @@ class FuncCall(): _c = 0 while _c < 5: try: - res = self.provider.text_chat(after_prompt, session_id) + res = self.provider.text_chat(prompt=after_prompt, session_id=session_id) # 截取```之间的内容 gu.log( "DEBUG BEGIN", bg=gu.BG_COLORS["yellow"], fg=gu.FG_COLORS["white"]) diff --git a/util/agent/web_searcher.py b/util/agent/web_searcher.py index 70a6f1121..6badf8188 100644 --- a/util/agent/web_searcher.py +++ b/util/agent/web_searcher.py @@ -127,7 +127,7 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False): function_invoked_ret = "" if official_fc: # we use official function-calling - result = await provider.text_chat(prompt, session_id, tools=new_func_call.get_func()) + result = await provider.text_chat(prompt=prompt, session_id=session_id, tools=new_func_call.get_func()) if isinstance(result, Function): logger.debug(f"web_searcher - function-calling: {result}") func_obj = None @@ -136,14 +136,14 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False): func_obj = i["func_obj"] break if not func_obj: - return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)" + return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)" try: args = json.loads(result.arguments) function_invoked_ret = await func_obj(**args) has_func = True except BaseException as e: traceback.print_exc() - return await provider.text_chat(prompt, session_id) + "\n(网页搜索失败, 此为默认回复)" + return await provider.text_chat(prompt=prompt, session_id=session_id, ) + "\n(网页搜索失败, 此为默认回复)" else: return result else: @@ -162,7 +162,7 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False): has_func = True if has_func: - await provider.forget(session_id) + await provider.forget(session_id=session_id, ) summary_prompt = f""" 你是一个专业且高效的助手,你的任务是 1. 根据下面的相关材料对用户的问题 `{prompt}` 进行总结; @@ -178,6 +178,6 @@ async def web_search(prompt, provider: Provider, session_id, official_fc=False): # 相关材料 {function_invoked_ret}""" - ret = await provider.text_chat(summary_prompt, session_id) + ret = await provider.text_chat(prompt=summary_prompt, session_id=session_id) return ret return function_invoked_ret diff --git a/util/t2i/strategies/network_strategy.py b/util/t2i/strategies/network_strategy.py index 67f406d6a..b05022b3a 100644 --- a/util/t2i/strategies/network_strategy.py +++ b/util/t2i/strategies/network_strategy.py @@ -32,7 +32,7 @@ class NetworkRenderStrategy(RenderStrategy): "options": { "full_page": True, "type": "jpeg", - "quality": 25, + "quality": 40, } }