perf: load providers when llm config is off and rebooting astrbot

fixes: #1466
This commit is contained in:
Soulter
2025-05-27 15:01:58 +08:00
parent 29c07ba83e
commit ca50618af6
5 changed files with 61 additions and 90 deletions
+26 -25
View File
@@ -46,28 +46,29 @@ class PreProcessStage(Stage):
stt_provider = (
self.plugin_manager.context.provider_manager.curr_stt_provider_inst
)
if stt_provider:
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break
if not stt_provider:
return
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break
@@ -67,6 +67,10 @@ class LLMRequestSubStage(Stage):
) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
provider = self.ctx.plugin_manager.context.get_using_provider()
if provider is None:
return
@@ -169,11 +169,14 @@ class ResultDecorateStage(Stage):
result.chain = new_chain
# TTS
tts_provider = (
self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
)
if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result()
and tts_provider
):
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
+17 -41
View File
@@ -21,9 +21,9 @@ class ProviderManager:
self.selected_provider_id = sp.get("curr_provider")
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
self.provider_enabled = self.provider_settings.get("enable", False)
self.stt_enabled = self.provider_stt_settings.get("enable", False)
self.tts_enabled = self.provider_tts_settings.get("enable", False)
# self.provider_enabled = self.provider_settings.get("enable", False)
# self.stt_enabled = self.provider_stt_settings.get("enable", False)
# self.tts_enabled = self.provider_tts_settings.get("enable", False)
# 人格情景管理
# 目前没有拆成独立的模块
@@ -101,6 +101,8 @@ class ProviderManager:
self.inst_map = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
self.default_provider_inst: Provider = None
"""默认的 Provider 实例。第 0 个或者用户以前指定的 Provider 实例"""
self.curr_provider_inst: Provider = None
"""当前使用的 Provider 实例"""
self.curr_stt_provider_inst: STTProvider = None
@@ -119,14 +121,9 @@ class ProviderManager:
for provider_config in self.providers_config:
await self.load_provider(provider_config)
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if self.stt_enabled and not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
if self.tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
self.default_provider_inst = self.inst_map.get(self.selected_provider_id)
if not self.default_provider_inst and self.provider_insts:
self.default_provider_inst = self.provider_insts[0]
# 初始化 MCP Client 连接
asyncio.create_task(
@@ -245,15 +242,12 @@ class ProviderManager:
await inst.initialize()
self.stt_provider_insts.append(inst)
if (
self.selected_stt_provider_id == provider_config["id"]
and self.stt_enabled
):
if self.selected_stt_provider_id == provider_config["id"]:
self.curr_stt_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。"
)
if not self.curr_stt_provider_inst and self.stt_enabled:
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
@@ -266,15 +260,12 @@ class ProviderManager:
await inst.initialize()
self.tts_provider_insts.append(inst)
if (
self.selected_tts_provider_id == provider_config["id"]
and self.tts_enabled
):
if self.selected_tts_provider_id == provider_config["id"]:
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。"
)
if not self.curr_tts_provider_inst and self.tts_enabled:
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
@@ -291,15 +282,12 @@ class ProviderManager:
await inst.initialize()
self.provider_insts.append(inst)
if (
self.selected_provider_id == provider_config["id"]
and self.provider_enabled
):
if self.selected_provider_id == provider_config["id"]:
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。"
)
if not self.curr_provider_inst and self.provider_enabled:
if not self.curr_provider_inst:
self.curr_provider_inst = inst
self.inst_map[provider_config["id"]] = inst
@@ -322,11 +310,7 @@ class ProviderManager:
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
elif (
self.curr_provider_inst is None
and len(self.provider_insts) > 0
and self.provider_enabled
):
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
self.selected_provider_id = self.curr_provider_inst.meta().id
logger.info(
@@ -335,11 +319,7 @@ class ProviderManager:
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
elif (
self.curr_stt_provider_inst is None
and len(self.stt_provider_insts) > 0
and self.stt_enabled
):
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
logger.info(
@@ -348,11 +328,7 @@ class ProviderManager:
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
elif (
self.curr_tts_provider_inst is None
and len(self.tts_provider_insts) > 0
and self.tts_enabled
):
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
logger.info(
+10 -23
View File
@@ -61,7 +61,6 @@ class RstScene(Enum):
version="4.0.0",
)
class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
cfg = context.get_config()
@@ -216,9 +215,7 @@ class Main(star.Star):
"""获取已经安装的插件列表。"""
plugin_list_info = "已加载的插件:\n"
for plugin in self.context.get_all_stars():
plugin_list_info += (
f"- `{plugin.name}` By {plugin.author}: {plugin.desc}"
)
plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}"
if not plugin.activated:
plugin_list_info += " (未启用)"
plugin_list_info += "\n"
@@ -271,9 +268,7 @@ class Main(star.Star):
event.set_result(MessageEventResult().message("安装插件成功。"))
except Exception as e:
logger.error(f"安装插件失败: {e}")
event.set_result(
MessageEventResult().message(f"安装插件失败: {e}")
)
event.set_result(MessageEventResult().message(f"安装插件失败: {e}"))
return
@plugin.command("help")
@@ -319,7 +314,6 @@ class Main(star.Star):
ret += "更多帮助信息请查看插件仓库 README。"
event.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent):
"""开关文本转图片"""
@@ -426,18 +420,13 @@ UID: {user_id} 此 ID 可用于设置管理员。
):
"""查看或者切换 LLM Provider"""
if not self.context.get_using_provider():
event.set_result(
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")
)
return
if idx is None:
ret = "## 载入的 LLM 提供商\n"
for idx, llm in enumerate(self.context.get_all_providers()):
id_ = llm.meta().id
ret += f"{idx + 1}. {id_} ({llm.meta().model})"
if self.context.get_using_provider().meta().id == id_:
provider_using = self.context.get_using_provider()
if provider_using and provider_using.meta().id == id_:
ret += " (当前使用)"
ret += "\n"
@@ -1033,7 +1022,11 @@ UID: {user_id} 此 ID 可用于设置管理员。
message.unified_msg_origin, cid
)
if not conversation:
message.set_result(MessageEventResult().message("请先进入一个对话。可以使用 /new 创建。"))
message.set_result(
MessageEventResult().message(
"请先进入一个对话。可以使用 /new 创建。"
)
)
if not conversation.persona_id and not conversation.persona_id == "[%None]":
curr_persona_name = (
self.context.provider_manager.selected_default_persona["name"]
@@ -1176,7 +1169,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
@filter.command("gewe_code")
async def gewe_code(self, event: AstrMessageEvent, code: str):
"""保存 gewechat 验证码"""
code_path = os.path.join(get_astrbot_data_path(), "temp","gewe_code")
code_path = os.path.join(get_astrbot_data_path(), "temp", "gewe_code")
with open(code_path, "w", encoding="utf-8") as f:
f.write(code)
yield event.plain_result("验证码已保存。")
@@ -1462,9 +1455,3 @@ UID: {user_id} 此 ID 可用于设置管理员。
plugin_cfg["reset"] = reset_cfg
alter_cmd_cfg["astrbot"] = plugin_cfg
sp.put("alter_cmd", alter_cmd_cfg)
@filter.command("test")
async def test_to(self, event: AstrMessageEvent):
import asyncio
await asyncio.sleep(10)
yield event.plain_result("OK")