From 5f0d601baa449edb3b28837150ecc4259be098db Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 3 Jul 2025 09:59:27 +0800 Subject: [PATCH] feat: add support for selecting provider and models in webchat --- .../process_stage/method/llm_request.py | 23 +- astrbot/core/pipeline/waking_check/stage.py | 2 +- .../sources/webchat/webchat_adapter.py | 4 + astrbot/core/provider/entities.py | 3 + astrbot/core/provider/provider.py | 2 + .../core/provider/sources/anthropic_source.py | 6 +- .../core/provider/sources/dashscope_source.py | 2 + astrbot/core/provider/sources/dify_source.py | 15 +- .../core/provider/sources/gemini_source.py | 26 +- .../core/provider/sources/openai_source.py | 19 +- astrbot/core/provider/sources/zhipu_source.py | 3 +- astrbot/dashboard/routes/chat.py | 4 + .../components/chat/ProviderModelSelector.vue | 353 ++++++++++++++++++ dashboard/src/views/ChatPage.vue | 296 +-------------- uv.lock | 2 +- 15 files changed, 450 insertions(+), 310 deletions(-) create mode 100644 dashboard/src/components/chat/ProviderModelSelector.vue diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 961463c7a..a6c772c4f 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -24,6 +24,7 @@ from astrbot.core.provider.entities import ( ) from astrbot.core.star.star_handler import EventType from ..agent_runner.tool_loop_agent import ToolLoopAgent +from astrbot.core.provider import Provider class LLMRequestSubStage(Stage): @@ -51,16 +52,25 @@ class LLMRequestSubStage(Stage): self.conv_manager = ctx.plugin_manager.context.conversation_manager + def _select_provider(self, event: AstrMessageEvent) -> Provider | None: + """选择使用的 LLM 提供商""" + sel_provider = event.get_extra("selected_provider") + _ctx = self.ctx.plugin_manager.context + if sel_provider and isinstance(sel_provider, str): + provider = _ctx.get_provider_by_id(sel_provider) + return provider + + return _ctx.get_using_provider(umo=event.unified_msg_origin) + async def process( self, event: AstrMessageEvent, _nested: bool = False ) -> Union[None, AsyncGenerator[None, None]]: - req: ProviderRequest = None + req: ProviderRequest | None = None if not self.ctx.astrbot_config["provider_settings"]["enable"]: logger.debug("未启用 LLM 能力,跳过处理。") return - umo = event.unified_msg_origin - provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo) + provider = self._select_provider(event) if provider is None: return @@ -75,6 +85,8 @@ class LLMRequestSubStage(Stage): else: req = ProviderRequest(prompt="", image_urls=[]) + if sel_model := event.get_extra("selected_model"): + req.model = sel_model if self.provider_wake_prefix: if not event.message_str.startswith(self.provider_wake_prefix): return @@ -165,7 +177,10 @@ class LLMRequestSubStage(Stage): if self.streaming_response: # 用来标记流式响应需要分节 yield MessageChain(chain=[], type="break") - if self.show_tool_use or event.get_platform_name() == "webchat": + if ( + self.show_tool_use + or event.get_platform_name() == "webchat" + ): resp.data["chain"].type = "tool_call" await event.send(resp.data["chain"]) continue diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index 82e36ca2d..0354260e9 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -164,7 +164,7 @@ class WakingCheckStage(Stage): "parsed_params" ) - event.clear_extra() + event._extras.pop("parsed_params", None) event.set_extra("activated_handlers", activated_handlers) event.set_extra("handlers_parsed_params", handlers_parsed_params) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 41d3e9418..aaac8e289 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -151,6 +151,10 @@ class WebChatAdapter(Platform): session_id=message.session_id, ) + _, _, payload = message.raw_message # type: ignore + message_event.set_extra("selected_provider", payload.get("selected_provider")) + message_event.set_extra("selected_model", payload.get("selected_model")) + self.commit_event(message_event) async def terminate(self): diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index abb01960c..2d120d7f6 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -110,6 +110,9 @@ class ProviderRequest: tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" + model: str | None = None + """模型名称,为 None 时使用提供商的默认模型""" + def __repr__(self): return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})" diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 1ecca3537..98e8fab85 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -88,6 +88,7 @@ class Provider(AbstractProvider): contexts: list = None, system_prompt: str = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, + model: str | None = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 @@ -116,6 +117,7 @@ class Provider(AbstractProvider): contexts: list = None, system_prompt: str = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, + model: str | None = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index 4ea4c2e02..aaff177e5 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -235,6 +235,7 @@ class ProviderAnthropic(Provider): contexts=None, system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ) -> LLMResponse: if contexts is None: @@ -259,7 +260,7 @@ class ProviderAnthropic(Provider): system_prompt, new_messages = self._prepare_payload(context_query) model_config = self.provider_config.get("model_config", {}) - model_config["model"] = self.get_model() + model_config["model"] = model or self.get_model() payloads = {"messages": new_messages, **model_config} @@ -285,6 +286,7 @@ class ProviderAnthropic(Provider): contexts=..., system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ): if contexts is None: @@ -309,7 +311,7 @@ class ProviderAnthropic(Provider): system_prompt, new_messages = self._prepare_payload(context_query) model_config = self.provider_config.get("model_config", {}) - model_config["model"] = self.get_model() + model_config["model"] = model or self.get_model() payloads = {"messages": new_messages, **model_config} diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 3498f8346..46b12726b 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -67,6 +67,7 @@ class ProviderDashscope(ProviderOpenAIOfficial): func_tool: FuncCall = None, contexts: List = None, system_prompt: str = None, + model=None, **kwargs, ) -> LLMResponse: if contexts is None: @@ -163,6 +164,7 @@ class ProviderDashscope(ProviderOpenAIOfficial): contexts=..., system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ): # raise NotImplementedError("This method is not implemented yet.") diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index b3a0ccccf..cc3e8062e 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -60,6 +60,8 @@ class ProviderDify(Provider): func_tool: FuncCall = None, contexts: List = None, system_prompt: str = None, + tool_calls_result=None, + model=None, **kwargs, ) -> LLMResponse: if image_urls is None: @@ -84,11 +86,13 @@ class ProviderDify(Provider): f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" ) continue - files_payload.append({ - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response["id"], - }) + files_payload.append( + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + } + ) # 获得会话变量 payload_vars = self.variables.copy() @@ -195,6 +199,7 @@ class ProviderDify(Provider): contexts=..., system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ): # raise NotImplementedError("This method is not implemented yet.") diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 573fe7684..56526c121 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -259,10 +259,12 @@ class ProviderGoogleGenAI(Provider): contents.append(content_cls(parts=part)) gemini_contents: list[types.Content] = [] - native_tool_enabled = any([ - self.provider_config.get("gm_native_coderunner", False), - self.provider_config.get("gm_native_search", False), - ]) + native_tool_enabled = any( + [ + self.provider_config.get("gm_native_coderunner", False), + self.provider_config.get("gm_native_search", False), + ] + ) for message in payloads["messages"]: role, content = message["role"], message.get("content") @@ -505,6 +507,7 @@ class ProviderGoogleGenAI(Provider): contexts=None, system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ) -> LLMResponse: if contexts is None: @@ -527,7 +530,7 @@ class ProviderGoogleGenAI(Provider): context_query.extend(tcr.to_openai_messages()) model_config = self.provider_config.get("model_config", {}) - model_config["model"] = self.get_model() + model_config["model"] = model or self.get_model() payloads = {"messages": context_query, **model_config} @@ -551,6 +554,7 @@ class ProviderGoogleGenAI(Provider): contexts=None, system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: if contexts is None: @@ -573,7 +577,7 @@ class ProviderGoogleGenAI(Provider): context_query.extend(tcr.to_openai_messages()) model_config = self.provider_config.get("model_config", {}) - model_config["model"] = self.get_model() + model_config["model"] = model or self.get_model() payloads = {"messages": context_query, **model_config} @@ -632,10 +636,12 @@ class ProviderGoogleGenAI(Provider): if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue - user_content["content"].append({ - "type": "image_url", - "image_url": {"url": image_data}, - }) + user_content["content"].append( + { + "type": "image_url", + "image_url": {"url": image_data}, + } + ) return user_content else: return {"role": "user", "content": text} diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 936fc2e34..f4c4987f4 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -99,6 +99,8 @@ class ProviderOpenAIOfficial(Provider): for key in to_del: del payloads[key] + logger.info(f"payloads: {payloads}") + completion = await self.client.chat.completions.create( **payloads, stream=False, extra_body=extra_body ) @@ -222,6 +224,7 @@ class ProviderOpenAIOfficial(Provider): contexts: list | None = None, system_prompt: str | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, + model: str | None = None, **kwargs, ) -> tuple: """准备聊天所需的有效载荷和上下文""" @@ -245,7 +248,7 @@ class ProviderOpenAIOfficial(Provider): context_query.extend(tcr.to_openai_messages()) model_config = self.provider_config.get("model_config", {}) - model_config["model"] = self.get_model() + model_config["model"] = model or self.get_model() payloads = {"messages": context_query, **model_config} @@ -346,6 +349,7 @@ class ProviderOpenAIOfficial(Provider): contexts=None, system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ) -> LLMResponse: payloads, context_query = await self._prepare_chat_payload( @@ -354,6 +358,7 @@ class ProviderOpenAIOfficial(Provider): contexts, system_prompt, tool_calls_result, + model=model, **kwargs, ) @@ -413,6 +418,7 @@ class ProviderOpenAIOfficial(Provider): contexts=[], system_prompt=None, tool_calls_result=None, + model=None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """流式对话,与服务商交互并逐步返回结果""" @@ -422,6 +428,7 @@ class ProviderOpenAIOfficial(Provider): contexts, system_prompt, tool_calls_result, + model=model, **kwargs, ) @@ -525,10 +532,12 @@ class ProviderOpenAIOfficial(Provider): if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue - user_content["content"].append({ - "type": "image_url", - "image_url": {"url": image_data}, - }) + user_content["content"].append( + { + "type": "image_url", + "image_url": {"url": image_data}, + } + ) return user_content else: return {"role": "user", "content": text} diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py index 428dee8f4..cf52e95fc 100644 --- a/astrbot/core/provider/sources/zhipu_source.py +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -28,6 +28,7 @@ class ProviderZhipu(ProviderOpenAIOfficial): func_tool: FuncCall = None, contexts=None, system_prompt=None, + model=None, **kwargs, ) -> LLMResponse: if contexts is None: @@ -38,7 +39,7 @@ class ProviderZhipu(ProviderOpenAIOfficial): context_query = [*contexts, new_record] model_cfgs: dict = self.provider_config.get("model_config", {}) - model = self.get_model() + model = model or self.get_model() # glm-4v-flash 只支持一张图片 if model.lower() == "glm-4v-flash" and image_urls and len(context_query) > 1: logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片") diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index a273bccdc..e7b086cd1 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -120,6 +120,8 @@ class ChatRoute(Route): conversation_id = post_data["conversation_id"] image_url = post_data.get("image_url") audio_url = post_data.get("audio_url") + selected_provider = post_data.get("selected_provider") + selected_model = post_data.get("selected_model") if not message and not image_url and not audio_url: return ( Response() @@ -202,6 +204,8 @@ class ChatRoute(Route): "message": message, "image_url": image_url, # list "audio_url": audio_url, + "selected_provider": selected_provider, + "selected_model": selected_model, }, ) ) diff --git a/dashboard/src/components/chat/ProviderModelSelector.vue b/dashboard/src/components/chat/ProviderModelSelector.vue new file mode 100644 index 000000000..7509b5295 --- /dev/null +++ b/dashboard/src/components/chat/ProviderModelSelector.vue @@ -0,0 +1,353 @@ + + + + + diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index 1d849c8cb..4995babfc 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -80,7 +80,7 @@

{{ getCurrentConversation.title || tm('conversation.newConversation') - }}

+ }}
{{ formatDate(getCurrentConversation.updated_at) }}
@@ -190,17 +190,11 @@ -
+
- - - {{ selectedProviderId }} / {{ selectedModelName }} - - - 选择模型 - - +
- - - - - - 选择提供商和模型 - - -
- -
-
-

提供商

-
- - - {{ provider.id }} - {{ provider.api_base }} - - -
- -
暂无可用提供商
-
-
- - -
-
-

模型

- - -
- - - {{ model }} - {{ model.description }} - - -
- -
请先选择提供商
-
-
- -
该提供商暂无可用模型
-
-
-
-
- - - 取消 - - 确认选择 - - -
-
@@ -1917,82 +1729,4 @@ export default { flex-shrink: 0; /* 防止header被压缩 */ } - -/* 提供商和模型选择对话框样式 */ -.provider-model-container { - display: flex; - height: 500px; - border: 1px solid var(--v-theme-border); - border-radius: 8px; - overflow: hidden; -} - -.provider-list-panel, -.model-list-panel { - flex: 1; - display: flex; - flex-direction: column; - background-color: var(--v-theme-surface); -} - -.provider-list-panel { - border-right: 1px solid var(--v-theme-border); -} - -.panel-header { - display: flex; - align-items: center; - justify-content: space-between; - padding: 16px; - border-bottom: 1px solid var(--v-theme-border); - background-color: var(--v-theme-containerBg); -} - -.panel-header h4 { - margin: 0; - font-size: 16px; - font-weight: 500; - color: var(--v-theme-primaryText); -} - -.provider-list, -.model-list { - flex: 1; - overflow-y: auto; - padding: 8px; -} - -.provider-item, -.model-item { - margin-bottom: 4px; - border-radius: 8px !important; - transition: all 0.2s ease; - cursor: pointer; -} - -.provider-item:hover, -.model-item:hover { - background-color: rgba(103, 58, 183, 0.05); -} - -.provider-item.v-list-item--active, -.model-item.v-list-item--active { - background-color: rgba(103, 58, 183, 0.1); - color: var(--v-theme-secondary); -} - -.empty-state { - display: flex; - flex-direction: column; - align-items: center; - justify-content: center; - height: 200px; - opacity: 0.6; - gap: 12px; -} - -.empty-text { - font-size: 14px; - color: var(--v-theme-secondaryText); -} \ No newline at end of file diff --git a/uv.lock b/uv.lock index 6279381bf..7a245997f 100644 --- a/uv.lock +++ b/uv.lock @@ -204,7 +204,7 @@ wheels = [ [[package]] name = "astrbot" -version = "3.5.17" +version = "3.5.18" source = { editable = "." } dependencies = [ { name = "aiocqhttp" },