diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 85166a081..6da5b3318 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -64,7 +64,7 @@ DEFAULT_CONFIG = { "streaming_response": False, "show_tool_use_status": False, "streaming_segmented": False, - "separate_provider": False, + "separate_provider": True, }, "provider_stt_settings": { "enable": False, diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 2abe59d65..05747c3ff 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -40,11 +40,13 @@ class ProviderManager: begin_dialogs = [] user_turn = True for dialog in begin_dialogs: - bd_processed.append({ - "role": "user" if user_turn else "assistant", - "content": dialog, - "_no_save": None, # 不持久化到 db - }) + bd_processed.append( + { + "role": "user" if user_turn else "assistant", + "content": dialog, + "_no_save": None, # 不持久化到 db + } + ) user_turn = not user_turn if mood_imitation_dialogs: if len(mood_imitation_dialogs) % 2 != 0: @@ -93,15 +95,15 @@ class ProviderManager: """加载的 Text To Speech Provider 的实例""" self.embedding_provider_insts: List[Provider] = [] """加载的 Embedding Provider 的实例""" - self.inst_map = {} + self.inst_map: dict[str, Provider] = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools - self.curr_provider_inst: Provider = None + self.curr_provider_inst: Provider | None = None """默认的 Provider 实例""" - self.curr_stt_provider_inst: STTProvider = None + self.curr_stt_provider_inst: STTProvider | None = None """默认的 Speech To Text Provider 实例""" - self.curr_tts_provider_inst: TTSProvider = None + self.curr_tts_provider_inst: TTSProvider | None = None """默认的 Text To Speech Provider 实例""" self.db_helper = db_helper @@ -145,21 +147,24 @@ class ProviderManager: await self.load_provider(provider_config) # 设置默认提供商 - self.curr_provider_inst = self.inst_map.get( - self.provider_settings.get("default_provider_id") + selected_provider_id = sp.get( + "curr_provider", self.provider_settings.get("default_provider_id") ) + selected_stt_provider_id = sp.get( + "curr_provider_stt", self.provider_stt_settings.get("provider_id") + ) + selected_tts_provider_id = sp.get( + "curr_provider_tts", self.provider_tts_settings.get("provider_id") + ) + self.curr_provider_inst = self.inst_map.get(selected_provider_id) if not self.curr_provider_inst and self.provider_insts: self.curr_provider_inst = self.provider_insts[0] - self.curr_stt_provider_inst = self.inst_map.get( - self.provider_stt_settings.get("provider_id") - ) + self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id) if not self.curr_stt_provider_inst and self.stt_provider_insts: self.curr_stt_provider_inst = self.stt_provider_insts[0] - self.curr_tts_provider_inst = self.inst_map.get( - self.provider_tts_settings.get("provider_id") - ) + self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id) if not self.curr_tts_provider_inst and self.tts_provider_insts: self.curr_tts_provider_inst = self.tts_provider_insts[0] @@ -417,7 +422,7 @@ class ProviderManager: self.curr_tts_provider_inst = None if getattr(self.inst_map[provider_id], "terminate", None): - await self.inst_map[provider_id].terminate() + await self.inst_map[provider_id].terminate() # type: ignore logger.info( f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})" @@ -427,6 +432,6 @@ class ProviderManager: async def terminate(self): for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): - await provider_inst.terminate() + await provider_inst.terminate() # type: ignore # 清理 MCP Client 连接 await self.llm_tools.mcp_service_queue.put({"type": "terminate"})