chore: stage

This commit is contained in:
Soulter
2025-06-05 23:30:18 +08:00
parent a03af55edd
commit 98d2e9bd27
5 changed files with 53 additions and 20 deletions
+8 -1
View File
@@ -42,10 +42,12 @@ DEFAULT_CONFIG = {
"empty_mention_waiting": True,
"friend_message_needs_wake_prefix": False,
"ignore_bot_self_message": False,
"seperate_provider": False,
},
"provider": [],
"provider_settings": {
"enable": True,
"default_provider_id": "",
"wake_prefix": "",
"web_search": False,
"web_search_link": False,
@@ -1379,9 +1381,14 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用大语言模型聊天",
"type": "bool",
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
"hint": "如需切换大语言模型提供商,请使用 /provider 命令。",
"obvious_hint": True,
},
"default_provider_id": {
"description": "默认模型提供商 ID",
"type": "str",
"hint": "可选。每个聊天会话的默认提供商 ID。"
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
"type": "string",
@@ -52,6 +52,9 @@ class LLMRequestSubStage(Stage):
self.streaming_response = ctx.astrbot_config["provider_settings"][
"streaming_response"
] # bool
self.seperate_provider = ctx.astrbot_config["provider_settings"][
"seperate_provider"
] # bool
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -70,8 +73,10 @@ class LLMRequestSubStage(Stage):
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
provider = self.ctx.plugin_manager.context.get_using_provider()
umo = None
if self.seperate_provider:
umo = event.unified_msg_origin
provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
if provider is None:
return
+1 -1
View File
@@ -6,7 +6,7 @@ from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core import logger
from astrbot.core import logger, sp
@register_stage
+7 -11
View File
@@ -18,12 +18,9 @@ class ProviderManager:
self.persona_configs: list = config.get("persona", [])
self.astrbot_config = config
self.selected_provider_id = sp.get("curr_provider")
self.selected_provider_id = self.provider_settings.get("default_provider_id")
self.selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
self.selected_tts_provider_id = self.provider_settings.get("provider_id")
# self.provider_enabled = self.provider_settings.get("enable", False)
# self.stt_enabled = self.provider_stt_settings.get("enable", False)
# self.tts_enabled = self.provider_tts_settings.get("enable", False)
# 人格情景管理
# 目前没有拆成独立的模块
@@ -103,12 +100,10 @@ class ProviderManager:
self.inst_map = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
self.default_provider_inst: Provider = None
"""默认的 Provider 实例。第 0 个或者用户以前指定的 Provider 实例"""
self.curr_provider_inst: Provider = None
"""当前使用的 Provider 实例"""
"""默认设置的 Provider 实例"""
self.curr_stt_provider_inst: STTProvider = None
"""当前使用的 Speech To Text Provider 实例"""
"""默认设置的 Speech To Text Provider 实例"""
self.curr_tts_provider_inst: TTSProvider = None
"""当前使用的 Text To Speech Provider 实例"""
self.db_helper = db_helper
@@ -123,9 +118,10 @@ class ProviderManager:
for provider_config in self.providers_config:
await self.load_provider(provider_config)
self.default_provider_inst = self.inst_map.get(self.selected_provider_id)
if not self.default_provider_inst and self.provider_insts:
self.default_provider_inst = self.provider_insts[0]
self.curr_provider_inst = self.inst_map.get(self.selected_provider_id)
if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0]
# 初始化 MCP Client 连接
asyncio.create_task(
+30 -5
View File
@@ -140,24 +140,49 @@ class Context:
"""获取所有用于 STT 任务的 Provider。"""
return self.provider_manager.stt_provider_insts
def get_using_provider(self) -> Provider:
def get_using_provider(self, umo: str = None) -> Provider:
"""
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
通过 /provider 指令切换。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
if umo:
perf = sp.get("session_provider_perf", {})
provider_id = perf.get(umo, None)
inst = self.provider_manager.inst_map.get(provider_id, None)
if inst:
return inst
return self.provider_manager.curr_provider_inst
def get_using_tts_provider(self) -> TTSProvider:
def get_using_tts_provider(self, umo: str = None) -> TTSProvider:
"""
获取当前使用的用于 TTS 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
if umo:
perf = sp.get("session_tts_provider_perf", {})
provider_id = perf.get(umo, None)
inst = self.provider_manager.inst_map.get(provider_id, None)
if inst:
return inst
return self.provider_manager.curr_tts_provider_inst
def get_using_stt_provider(self) -> STTProvider:
def get_using_stt_provider(self, umo: str = None) -> STTProvider:
"""
获取当前使用的用于 STT 任务的 Provider。
Args:
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
"""
if umo:
perf = sp.get("session_stt_provider_perf", {})
provider_id = perf.get(umo, None)
inst = self.provider_manager.inst_map.get(provider_id, None)
if inst:
return inst
return self.provider_manager.curr_stt_provider_inst
def get_config(self) -> AstrBotConfig: