From ca50618af6e137ffe8768834055d83de41ce9086 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 27 May 2025 15:01:58 +0800 Subject: [PATCH] perf: load providers when llm config is off and rebooting astrbot fixes: #1466 --- .../core/pipeline/preprocess_stage/stage.py | 51 ++++++++-------- .../process_stage/method/llm_request.py | 4 ++ .../core/pipeline/result_decorate/stage.py | 5 +- astrbot/core/provider/manager.py | 58 ++++++------------- packages/astrbot/main.py | 33 ++++------- 5 files changed, 61 insertions(+), 90 deletions(-) diff --git a/astrbot/core/pipeline/preprocess_stage/stage.py b/astrbot/core/pipeline/preprocess_stage/stage.py index 96d7ff4b7..3e89e1c3e 100644 --- a/astrbot/core/pipeline/preprocess_stage/stage.py +++ b/astrbot/core/pipeline/preprocess_stage/stage.py @@ -46,28 +46,29 @@ class PreProcessStage(Stage): stt_provider = ( self.plugin_manager.context.provider_manager.curr_stt_provider_inst ) - if stt_provider: - message_chain = event.get_messages() - for idx, component in enumerate(message_chain): - if isinstance(component, Record) and component.url: - path = component.url.removeprefix("file://") - retry = 5 - for i in range(retry): - try: - result = await stt_provider.get_text(audio_url=path) - if result: - logger.info("语音转文本结果: " + result) - message_chain[idx] = Plain(result) - event.message_str += result - event.message_obj.message_str += result - break - except FileNotFoundError as e: - # napcat workaround - logger.warning(e) - logger.warning(f"重试中: {i + 1}/{retry}") - await asyncio.sleep(0.5) - continue - except BaseException as e: - logger.error(traceback.format_exc()) - logger.error(f"语音转文本失败: {e}") - break + if not stt_provider: + return + message_chain = event.get_messages() + for idx, component in enumerate(message_chain): + if isinstance(component, Record) and component.url: + path = component.url.removeprefix("file://") + retry = 5 + for i in range(retry): + try: + result = await stt_provider.get_text(audio_url=path) + if result: + logger.info("语音转文本结果: " + result) + message_chain[idx] = Plain(result) + event.message_str += result + event.message_obj.message_str += result + break + except FileNotFoundError as e: + # napcat workaround + logger.warning(e) + logger.warning(f"重试中: {i + 1}/{retry}") + await asyncio.sleep(0.5) + continue + except BaseException as e: + logger.error(traceback.format_exc()) + logger.error(f"语音转文本失败: {e}") + break diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 28745f2c5..d22a1f453 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -67,6 +67,10 @@ class LLMRequestSubStage(Stage): ) -> Union[None, AsyncGenerator[None, None]]: req: ProviderRequest = None + if not self.ctx.astrbot_config["provider_settings"]["enable"]: + logger.debug("未启用 LLM 能力,跳过处理。") + return + provider = self.ctx.plugin_manager.context.get_using_provider() if provider is None: return diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 9adaf2da6..7a12788a0 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -169,11 +169,14 @@ class ResultDecorateStage(Stage): result.chain = new_chain # TTS + tts_provider = ( + self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst + ) if ( self.ctx.astrbot_config["provider_tts_settings"]["enable"] and result.is_llm_result() + and tts_provider ): - tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst new_chain = [] for comp in result.chain: if isinstance(comp, Plain) and len(comp.text) > 1: diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 68aa98e89..78337ce95 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -21,9 +21,9 @@ class ProviderManager: self.selected_provider_id = sp.get("curr_provider") self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id") self.selected_tts_provider_id = self.provider_settings.get("provider_id") - self.provider_enabled = self.provider_settings.get("enable", False) - self.stt_enabled = self.provider_stt_settings.get("enable", False) - self.tts_enabled = self.provider_tts_settings.get("enable", False) + # self.provider_enabled = self.provider_settings.get("enable", False) + # self.stt_enabled = self.provider_stt_settings.get("enable", False) + # self.tts_enabled = self.provider_tts_settings.get("enable", False) # 人格情景管理 # 目前没有拆成独立的模块 @@ -101,6 +101,8 @@ class ProviderManager: self.inst_map = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools + self.default_provider_inst: Provider = None + """默认的 Provider 实例。第 0 个或者用户以前指定的 Provider 实例""" self.curr_provider_inst: Provider = None """当前使用的 Provider 实例""" self.curr_stt_provider_inst: STTProvider = None @@ -119,14 +121,9 @@ class ProviderManager: for provider_config in self.providers_config: await self.load_provider(provider_config) - if not self.curr_provider_inst: - logger.warning("未启用任何用于 文本生成 的提供商适配器。") - - if self.stt_enabled and not self.curr_stt_provider_inst: - logger.warning("未启用任何用于 语音转文本 的提供商适配器。") - - if self.tts_enabled and not self.curr_tts_provider_inst: - logger.warning("未启用任何用于 文本转语音 的提供商适配器。") + self.default_provider_inst = self.inst_map.get(self.selected_provider_id) + if not self.default_provider_inst and self.provider_insts: + self.default_provider_inst = self.provider_insts[0] # 初始化 MCP Client 连接 asyncio.create_task( @@ -245,15 +242,12 @@ class ProviderManager: await inst.initialize() self.stt_provider_insts.append(inst) - if ( - self.selected_stt_provider_id == provider_config["id"] - and self.stt_enabled - ): + if self.selected_stt_provider_id == provider_config["id"]: self.curr_stt_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。" ) - if not self.curr_stt_provider_inst and self.stt_enabled: + if not self.curr_stt_provider_inst: self.curr_stt_provider_inst = inst elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: @@ -266,15 +260,12 @@ class ProviderManager: await inst.initialize() self.tts_provider_insts.append(inst) - if ( - self.selected_tts_provider_id == provider_config["id"] - and self.tts_enabled - ): + if self.selected_tts_provider_id == provider_config["id"]: self.curr_tts_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。" ) - if not self.curr_tts_provider_inst and self.tts_enabled: + if not self.curr_tts_provider_inst: self.curr_tts_provider_inst = inst elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: @@ -291,15 +282,12 @@ class ProviderManager: await inst.initialize() self.provider_insts.append(inst) - if ( - self.selected_provider_id == provider_config["id"] - and self.provider_enabled - ): + if self.selected_provider_id == provider_config["id"]: self.curr_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。" ) - if not self.curr_provider_inst and self.provider_enabled: + if not self.curr_provider_inst: self.curr_provider_inst = inst self.inst_map[provider_config["id"]] = inst @@ -322,11 +310,7 @@ class ProviderManager: if len(self.provider_insts) == 0: self.curr_provider_inst = None - elif ( - self.curr_provider_inst is None - and len(self.provider_insts) > 0 - and self.provider_enabled - ): + elif self.curr_provider_inst is None and len(self.provider_insts) > 0: self.curr_provider_inst = self.provider_insts[0] self.selected_provider_id = self.curr_provider_inst.meta().id logger.info( @@ -335,11 +319,7 @@ class ProviderManager: if len(self.stt_provider_insts) == 0: self.curr_stt_provider_inst = None - elif ( - self.curr_stt_provider_inst is None - and len(self.stt_provider_insts) > 0 - and self.stt_enabled - ): + elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0: self.curr_stt_provider_inst = self.stt_provider_insts[0] self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id logger.info( @@ -348,11 +328,7 @@ class ProviderManager: if len(self.tts_provider_insts) == 0: self.curr_tts_provider_inst = None - elif ( - self.curr_tts_provider_inst is None - and len(self.tts_provider_insts) > 0 - and self.tts_enabled - ): + elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0: self.curr_tts_provider_inst = self.tts_provider_insts[0] self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id logger.info( diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 92ef6bfea..6ea0a25cb 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -61,7 +61,6 @@ class RstScene(Enum): version="4.0.0", ) class Main(star.Star): - def __init__(self, context: star.Context) -> None: self.context = context cfg = context.get_config() @@ -216,9 +215,7 @@ class Main(star.Star): """获取已经安装的插件列表。""" plugin_list_info = "已加载的插件:\n" for plugin in self.context.get_all_stars(): - plugin_list_info += ( - f"- `{plugin.name}` By {plugin.author}: {plugin.desc}" - ) + plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}" if not plugin.activated: plugin_list_info += " (未启用)" plugin_list_info += "\n" @@ -271,9 +268,7 @@ class Main(star.Star): event.set_result(MessageEventResult().message("安装插件成功。")) except Exception as e: logger.error(f"安装插件失败: {e}") - event.set_result( - MessageEventResult().message(f"安装插件失败: {e}") - ) + event.set_result(MessageEventResult().message(f"安装插件失败: {e}")) return @plugin.command("help") @@ -319,7 +314,6 @@ class Main(star.Star): ret += "更多帮助信息请查看插件仓库 README。" event.set_result(MessageEventResult().message(ret).use_t2i(False)) - @filter.command("t2i") async def t2i(self, event: AstrMessageEvent): """开关文本转图片""" @@ -426,18 +420,13 @@ UID: {user_id} 此 ID 可用于设置管理员。 ): """查看或者切换 LLM Provider""" - if not self.context.get_using_provider(): - event.set_result( - MessageEventResult().message("未找到任何 LLM 提供商。请先配置。") - ) - return - if idx is None: ret = "## 载入的 LLM 提供商\n" for idx, llm in enumerate(self.context.get_all_providers()): id_ = llm.meta().id ret += f"{idx + 1}. {id_} ({llm.meta().model})" - if self.context.get_using_provider().meta().id == id_: + provider_using = self.context.get_using_provider() + if provider_using and provider_using.meta().id == id_: ret += " (当前使用)" ret += "\n" @@ -1033,7 +1022,11 @@ UID: {user_id} 此 ID 可用于设置管理员。 message.unified_msg_origin, cid ) if not conversation: - message.set_result(MessageEventResult().message("请先进入一个对话。可以使用 /new 创建。")) + message.set_result( + MessageEventResult().message( + "请先进入一个对话。可以使用 /new 创建。" + ) + ) if not conversation.persona_id and not conversation.persona_id == "[%None]": curr_persona_name = ( self.context.provider_manager.selected_default_persona["name"] @@ -1176,7 +1169,7 @@ UID: {user_id} 此 ID 可用于设置管理员。 @filter.command("gewe_code") async def gewe_code(self, event: AstrMessageEvent, code: str): """保存 gewechat 验证码""" - code_path = os.path.join(get_astrbot_data_path(), "temp","gewe_code") + code_path = os.path.join(get_astrbot_data_path(), "temp", "gewe_code") with open(code_path, "w", encoding="utf-8") as f: f.write(code) yield event.plain_result("验证码已保存。") @@ -1462,9 +1455,3 @@ UID: {user_id} 此 ID 可用于设置管理员。 plugin_cfg["reset"] = reset_cfg alter_cmd_cfg["astrbot"] = plugin_cfg sp.put("alter_cmd", alter_cmd_cfg) - - @filter.command("test") - async def test_to(self, event: AstrMessageEvent): - import asyncio - await asyncio.sleep(10) - yield event.plain_result("OK")