From 5f0d601baa449edb3b28837150ecc4259be098db Mon Sep 17 00:00:00 2001
From: Soulter <905617992@qq.com>
Date: Thu, 3 Jul 2025 09:59:27 +0800
Subject: [PATCH] feat: add support for selecting provider and models in
webchat
---
.../process_stage/method/llm_request.py | 23 +-
astrbot/core/pipeline/waking_check/stage.py | 2 +-
.../sources/webchat/webchat_adapter.py | 4 +
astrbot/core/provider/entities.py | 3 +
astrbot/core/provider/provider.py | 2 +
.../core/provider/sources/anthropic_source.py | 6 +-
.../core/provider/sources/dashscope_source.py | 2 +
astrbot/core/provider/sources/dify_source.py | 15 +-
.../core/provider/sources/gemini_source.py | 26 +-
.../core/provider/sources/openai_source.py | 19 +-
astrbot/core/provider/sources/zhipu_source.py | 3 +-
astrbot/dashboard/routes/chat.py | 4 +
.../components/chat/ProviderModelSelector.vue | 353 ++++++++++++++++++
dashboard/src/views/ChatPage.vue | 296 +--------------
uv.lock | 2 +-
15 files changed, 450 insertions(+), 310 deletions(-)
create mode 100644 dashboard/src/components/chat/ProviderModelSelector.vue
diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py
index 961463c7a..a6c772c4f 100644
--- a/astrbot/core/pipeline/process_stage/method/llm_request.py
+++ b/astrbot/core/pipeline/process_stage/method/llm_request.py
@@ -24,6 +24,7 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.star.star_handler import EventType
from ..agent_runner.tool_loop_agent import ToolLoopAgent
+from astrbot.core.provider import Provider
class LLMRequestSubStage(Stage):
@@ -51,16 +52,25 @@ class LLMRequestSubStage(Stage):
self.conv_manager = ctx.plugin_manager.context.conversation_manager
+ def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
+ """选择使用的 LLM 提供商"""
+ sel_provider = event.get_extra("selected_provider")
+ _ctx = self.ctx.plugin_manager.context
+ if sel_provider and isinstance(sel_provider, str):
+ provider = _ctx.get_provider_by_id(sel_provider)
+ return provider
+
+ return _ctx.get_using_provider(umo=event.unified_msg_origin)
+
async def process(
self, event: AstrMessageEvent, _nested: bool = False
) -> Union[None, AsyncGenerator[None, None]]:
- req: ProviderRequest = None
+ req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
- umo = event.unified_msg_origin
- provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
+ provider = self._select_provider(event)
if provider is None:
return
@@ -75,6 +85,8 @@ class LLMRequestSubStage(Stage):
else:
req = ProviderRequest(prompt="", image_urls=[])
+ if sel_model := event.get_extra("selected_model"):
+ req.model = sel_model
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
@@ -165,7 +177,10 @@ class LLMRequestSubStage(Stage):
if self.streaming_response:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
- if self.show_tool_use or event.get_platform_name() == "webchat":
+ if (
+ self.show_tool_use
+ or event.get_platform_name() == "webchat"
+ ):
resp.data["chain"].type = "tool_call"
await event.send(resp.data["chain"])
continue
diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py
index 82e36ca2d..0354260e9 100644
--- a/astrbot/core/pipeline/waking_check/stage.py
+++ b/astrbot/core/pipeline/waking_check/stage.py
@@ -164,7 +164,7 @@ class WakingCheckStage(Stage):
"parsed_params"
)
- event.clear_extra()
+ event._extras.pop("parsed_params", None)
event.set_extra("activated_handlers", activated_handlers)
event.set_extra("handlers_parsed_params", handlers_parsed_params)
diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py
index 41d3e9418..aaac8e289 100644
--- a/astrbot/core/platform/sources/webchat/webchat_adapter.py
+++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py
@@ -151,6 +151,10 @@ class WebChatAdapter(Platform):
session_id=message.session_id,
)
+ _, _, payload = message.raw_message # type: ignore
+ message_event.set_extra("selected_provider", payload.get("selected_provider"))
+ message_event.set_extra("selected_model", payload.get("selected_model"))
+
self.commit_event(message_event)
async def terminate(self):
diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py
index abb01960c..2d120d7f6 100644
--- a/astrbot/core/provider/entities.py
+++ b/astrbot/core/provider/entities.py
@@ -110,6 +110,9 @@ class ProviderRequest:
tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
+ model: str | None = None
+ """模型名称,为 None 时使用提供商的默认模型"""
+
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py
index 1ecca3537..98e8fab85 100644
--- a/astrbot/core/provider/provider.py
+++ b/astrbot/core/provider/provider.py
@@ -88,6 +88,7 @@ class Provider(AbstractProvider):
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
+ model: str | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -116,6 +117,7 @@ class Provider(AbstractProvider):
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
+ model: str | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py
index 4ea4c2e02..aaff177e5 100644
--- a/astrbot/core/provider/sources/anthropic_source.py
+++ b/astrbot/core/provider/sources/anthropic_source.py
@@ -235,6 +235,7 @@ class ProviderAnthropic(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -259,7 +260,7 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
- model_config["model"] = self.get_model()
+ model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
@@ -285,6 +286,7 @@ class ProviderAnthropic(Provider):
contexts=...,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
):
if contexts is None:
@@ -309,7 +311,7 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
- model_config["model"] = self.get_model()
+ model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py
index 3498f8346..46b12726b 100644
--- a/astrbot/core/provider/sources/dashscope_source.py
+++ b/astrbot/core/provider/sources/dashscope_source.py
@@ -67,6 +67,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
+ model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -163,6 +164,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
contexts=...,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py
index b3a0ccccf..cc3e8062e 100644
--- a/astrbot/core/provider/sources/dify_source.py
+++ b/astrbot/core/provider/sources/dify_source.py
@@ -60,6 +60,8 @@ class ProviderDify(Provider):
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
+ tool_calls_result=None,
+ model=None,
**kwargs,
) -> LLMResponse:
if image_urls is None:
@@ -84,11 +86,13 @@ class ProviderDify(Provider):
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
)
continue
- files_payload.append({
- "type": "image",
- "transfer_method": "local_file",
- "upload_file_id": file_response["id"],
- })
+ files_payload.append(
+ {
+ "type": "image",
+ "transfer_method": "local_file",
+ "upload_file_id": file_response["id"],
+ }
+ )
# 获得会话变量
payload_vars = self.variables.copy()
@@ -195,6 +199,7 @@ class ProviderDify(Provider):
contexts=...,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
):
# raise NotImplementedError("This method is not implemented yet.")
diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py
index 573fe7684..56526c121 100644
--- a/astrbot/core/provider/sources/gemini_source.py
+++ b/astrbot/core/provider/sources/gemini_source.py
@@ -259,10 +259,12 @@ class ProviderGoogleGenAI(Provider):
contents.append(content_cls(parts=part))
gemini_contents: list[types.Content] = []
- native_tool_enabled = any([
- self.provider_config.get("gm_native_coderunner", False),
- self.provider_config.get("gm_native_search", False),
- ])
+ native_tool_enabled = any(
+ [
+ self.provider_config.get("gm_native_coderunner", False),
+ self.provider_config.get("gm_native_search", False),
+ ]
+ )
for message in payloads["messages"]:
role, content = message["role"], message.get("content")
@@ -505,6 +507,7 @@ class ProviderGoogleGenAI(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -527,7 +530,7 @@ class ProviderGoogleGenAI(Provider):
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
- model_config["model"] = self.get_model()
+ model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -551,6 +554,7 @@ class ProviderGoogleGenAI(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
@@ -573,7 +577,7 @@ class ProviderGoogleGenAI(Provider):
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
- model_config["model"] = self.get_model()
+ model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -632,10 +636,12 @@ class ProviderGoogleGenAI(Provider):
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
- user_content["content"].append({
- "type": "image_url",
- "image_url": {"url": image_data},
- })
+ user_content["content"].append(
+ {
+ "type": "image_url",
+ "image_url": {"url": image_data},
+ }
+ )
return user_content
else:
return {"role": "user", "content": text}
diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py
index 936fc2e34..f4c4987f4 100644
--- a/astrbot/core/provider/sources/openai_source.py
+++ b/astrbot/core/provider/sources/openai_source.py
@@ -99,6 +99,8 @@ class ProviderOpenAIOfficial(Provider):
for key in to_del:
del payloads[key]
+ logger.info(f"payloads: {payloads}")
+
completion = await self.client.chat.completions.create(
**payloads, stream=False, extra_body=extra_body
)
@@ -222,6 +224,7 @@ class ProviderOpenAIOfficial(Provider):
contexts: list | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
+ model: str | None = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
@@ -245,7 +248,7 @@ class ProviderOpenAIOfficial(Provider):
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
- model_config["model"] = self.get_model()
+ model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
@@ -346,6 +349,7 @@ class ProviderOpenAIOfficial(Provider):
contexts=None,
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload(
@@ -354,6 +358,7 @@ class ProviderOpenAIOfficial(Provider):
contexts,
system_prompt,
tool_calls_result,
+ model=model,
**kwargs,
)
@@ -413,6 +418,7 @@ class ProviderOpenAIOfficial(Provider):
contexts=[],
system_prompt=None,
tool_calls_result=None,
+ model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
@@ -422,6 +428,7 @@ class ProviderOpenAIOfficial(Provider):
contexts,
system_prompt,
tool_calls_result,
+ model=model,
**kwargs,
)
@@ -525,10 +532,12 @@ class ProviderOpenAIOfficial(Provider):
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
- user_content["content"].append({
- "type": "image_url",
- "image_url": {"url": image_data},
- })
+ user_content["content"].append(
+ {
+ "type": "image_url",
+ "image_url": {"url": image_data},
+ }
+ )
return user_content
else:
return {"role": "user", "content": text}
diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py
index 428dee8f4..cf52e95fc 100644
--- a/astrbot/core/provider/sources/zhipu_source.py
+++ b/astrbot/core/provider/sources/zhipu_source.py
@@ -28,6 +28,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
func_tool: FuncCall = None,
contexts=None,
system_prompt=None,
+ model=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
@@ -38,7 +39,7 @@ class ProviderZhipu(ProviderOpenAIOfficial):
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
- model = self.get_model()
+ model = model or self.get_model()
# glm-4v-flash 只支持一张图片
if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py
index a273bccdc..e7b086cd1 100644
--- a/astrbot/dashboard/routes/chat.py
+++ b/astrbot/dashboard/routes/chat.py
@@ -120,6 +120,8 @@ class ChatRoute(Route):
conversation_id = post_data["conversation_id"]
image_url = post_data.get("image_url")
audio_url = post_data.get("audio_url")
+ selected_provider = post_data.get("selected_provider")
+ selected_model = post_data.get("selected_model")
if not message and not image_url and not audio_url:
return (
Response()
@@ -202,6 +204,8 @@ class ChatRoute(Route):
"message": message,
"image_url": image_url, # list
"audio_url": audio_url,
+ "selected_provider": selected_provider,
+ "selected_model": selected_model,
},
)
)
diff --git a/dashboard/src/components/chat/ProviderModelSelector.vue b/dashboard/src/components/chat/ProviderModelSelector.vue
new file mode 100644
index 000000000..7509b5295
--- /dev/null
+++ b/dashboard/src/components/chat/ProviderModelSelector.vue
@@ -0,0 +1,353 @@
+
+ 提供商
+ 模型
+