From ef99f6429153867789bcbaa899f84d19b5442b5a Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 21 Oct 2025 00:47:04 +0800 Subject: [PATCH 01/29] =?UTF-8?q?feat(config):=20=E6=B7=BB=E5=8A=A0=20agen?= =?UTF-8?q?t=20=E8=BF=90=E8=A1=8C=E5=99=A8=E7=B1=BB=E5=9E=8B=E5=8F=8A?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E9=85=8D=E7=BD=AE=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 69 ++++++++++- .../src/components/shared/AstrBotConfigV4.vue | 110 +++++++++--------- 2 files changed, 122 insertions(+), 57 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 7977e4392..361f557fd 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -71,6 +71,9 @@ DEFAULT_CONFIG = { "streaming_response": False, "show_tool_use_status": False, "streaming_segmented": False, + "agent_runner_type": "local", + "dify_runner_provider_id": "", + "coze_runner_provider_id": "", "max_agent_step": 30, "tool_call_timeout": 60, }, @@ -1931,12 +1934,19 @@ CONFIG_METADATA_2 = { "streaming_segmented": { "type": "bool", }, + "agent_runner_type": { + "type": "string", + }, + "dify_runner_provider_id": { + "type": "string", + }, + "coze_runner_provider_id": { + "type": "string", + }, "max_agent_step": { - "description": "工具调用轮数上限", "type": "int", }, "tool_call_timeout": { - "description": "工具调用超时时间(秒)", "type": "int", }, }, @@ -2070,6 +2080,46 @@ CONFIG_METADATA_3 = { "ai_group": { "name": "AI 配置", "metadata": { + "agent_runner": { + "description": "Agent", + "type": "object", + "items": { + "provider_settings.agent_runner_type": { + "description": "执行器", + "type": "string", + "options": ["local", "dify", "coze"], + "labels": ["内置 Agent", "Dify", "Coze"], + }, + }, + }, + "dify_runner": { + "description": "Dify", + "type": "object", + "items": { + "provider_settings.dify_runner_provider_id": { + "description": "Dify 执行器提供商 ID", + "type": "string", + "_special": "select_dify_runner_provider", + }, + }, + "condition": { + "provider_settings.agent_runner_type": "dify", + }, + }, + "coze_runner": { + "description": "Coze", + "type": "object", + "items": { + "provider_settings.coze_runner_provider_id": { + "description": "Coze 执行器提供商 ID", + "type": "string", + "_special": "select_coze_runner_provider", + }, + }, + "condition": { + "provider_settings.agent_runner_type": "coze", + }, + }, "ai": { "description": "模型", "type": "object", @@ -2123,6 +2173,9 @@ CONFIG_METADATA_3 = { "type": "text", }, }, + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "persona": { "description": "人格", @@ -2134,6 +2187,9 @@ CONFIG_METADATA_3 = { "_special": "select_persona", }, }, + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "knowledgebase": { "description": "知识库", @@ -2145,6 +2201,9 @@ CONFIG_METADATA_3 = { "_special": "select_knowledgebase", }, }, + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "websearch": { "description": "网页搜索", @@ -2181,6 +2240,9 @@ CONFIG_METADATA_3 = { "type": "bool", }, }, + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "others": { "description": "其他配置", @@ -2248,6 +2310,9 @@ CONFIG_METADATA_3 = { "type": "bool", }, }, + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, }, }, diff --git a/dashboard/src/components/shared/AstrBotConfigV4.vue b/dashboard/src/components/shared/AstrBotConfigV4.vue index 6ae758dfb..08373595f 100644 --- a/dashboard/src/components/shared/AstrBotConfigV4.vue +++ b/dashboard/src/components/shared/AstrBotConfigV4.vue @@ -101,6 +101,21 @@ function shouldShowItem(itemMeta, itemKey) { return true } +// 检查最外层的 object 是否应该显示 +function shouldShowSection() { + const sectionMeta = props.metadata[props.metadataKey] + if (!sectionMeta?.condition) { + return true + } + for (const [conditionKey, expectedValue] of Object.entries(sectionMeta.condition)) { + const actualValue = getValueBySelector(props.iterable, conditionKey) + if (actualValue !== expectedValue) { + return false + } + } + return true +} + function hasVisibleItemsAfter(items, currentIndex) { const itemEntries = Object.entries(items) @@ -114,12 +129,27 @@ function hasVisibleItemsAfter(items, currentIndex) { return false } + +// 将 options 和 labels 转换为 v-select 的 items 格式 +function getSelectItems(itemMeta) { + if (!itemMeta?.options) { + return [] + } + if (itemMeta?.labels && itemMeta.labels.length === itemMeta.options.length) { + return itemMeta.options.map((value, index) => ({ + title: itemMeta.labels[index], + value: value + })) + } + return itemMeta.options +} - + From e74f6263837970a78ed78b25800431e33d6fc7c4 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 21 Oct 2025 09:55:14 +0800 Subject: [PATCH 02/29] stage --- .../core/agent/runners/dify_agent_runner.py | 326 ++++++++++++++++++ astrbot/core/config/default.py | 8 +- .../process_stage/method/llm_request.py | 17 +- astrbot/core/platform/astr_message_event.py | 2 +- .../components/provider/AddNewProvider.vue | 28 +- .../src/components/shared/AstrBotConfigV4.vue | 6 + .../src/i18n/locales/en-US/core/common.json | 1 + .../i18n/locales/en-US/features/provider.json | 4 +- .../i18n/locales/zh-CN/features/provider.json | 8 +- dashboard/src/views/ProviderPage.vue | 9 +- 10 files changed, 372 insertions(+), 37 deletions(-) create mode 100644 astrbot/core/agent/runners/dify_agent_runner.py diff --git a/astrbot/core/agent/runners/dify_agent_runner.py b/astrbot/core/agent/runners/dify_agent_runner.py new file mode 100644 index 000000000..145e4a627 --- /dev/null +++ b/astrbot/core/agent/runners/dify_agent_runner.py @@ -0,0 +1,326 @@ +import sys +import os +import typing as T +from .base import BaseAgentRunner, AgentResponse, AgentState +from ..hooks import BaseAgentRunHooks +from ..tool_executor import BaseFunctionToolExecutor +from ..run_context import ContextWrapper, TContext +from ..response import AgentResponseData +from astrbot.core.provider.provider import Provider +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import ( + ProviderRequest, + LLMResponse, +) +from astrbot.core.utils.dify_api_client import DifyAPIClient +from astrbot.core.utils.io import download_image_by_url, download_file +from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core import logger, sp +import astrbot.core.message.components as Comp + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class DifyAgentRunner(BaseAgentRunner[TContext]): + """Dify Agent Runner""" + + @override + async def reset( + self, + provider: Provider, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + tool_executor: BaseFunctionToolExecutor[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.provider = provider + self.final_llm_resp = None + self._state = AgentState.IDLE + self.tool_executor = tool_executor + self.agent_hooks = agent_hooks + self.run_context = run_context + + # Dify 特定配置 - 从 provider 或 kwargs 中获取 + self.api_key = kwargs.get("dify_api_key", "") + api_base = kwargs.get("dify_api_base", "https://api.dify.ai/v1") + self.api_type = kwargs.get("dify_api_type", "") + + self.workflow_output_key = kwargs.get( + "dify_workflow_output_key", "astrbot_wf_output" + ) + self.dify_query_input_key = kwargs.get( + "dify_query_input_key", "astrbot_text_query" + ) + if not self.dify_query_input_key: + self.dify_query_input_key = "astrbot_text_query" + if not self.workflow_output_key: + self.workflow_output_key = "astrbot_wf_output" + + self.variables: dict = kwargs.get("variables", {}) + self.timeout = kwargs.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + + self.conversation_ids = {} + """记录当前 session id 的对话 ID""" + + self.api_client = DifyAPIClient(self.api_key, api_base) + + def _transition_state(self, new_state: AgentState) -> None: + """转换 Agent 状态""" + if self._state != new_state: + logger.debug(f"Dify Agent state transition: {self._state} -> {new_state}") + self._state = new_state + + @override + async def step(self): + """ + 执行 Dify Agent 的一个步骤 + """ + if not self.req: + raise ValueError("Request is not set. Please call reset() first.") + + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + + # 开始处理,转换到运行状态 + self._transition_state(AgentState.RUNNING) + + try: + # 执行 Dify 请求并处理结果 + async for response in self._execute_dify_request(): + yield response + except Exception as e: + logger.error(f"Dify 请求失败:{str(e)}") + self._transition_state(AgentState.ERROR) + self.final_llm_resp = LLMResponse( + role="err", completion_text=f"Dify 请求失败:{str(e)}" + ) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message(f"Dify 请求失败:{str(e)}") + ), + ) + + async def _execute_dify_request(self): + """执行 Dify 请求的核心逻辑""" + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + system_prompt = self.req.system_prompt + + conversation_id = self.conversation_ids.get(session_id, "") + result = "" + + # 处理图片上传 + files_payload = [] + for image_url in image_urls: + try: + image_path = ( + await download_image_by_url(image_url) + if image_url.startswith("http") + else image_url + ) + file_response = await self.api_client.file_upload( + image_path, user=session_id + ) + logger.debug(f"Dify 上传图片响应:{file_response}") + if "id" not in file_response: + logger.warning( + f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" + ) + continue + files_payload.append( + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + } + ) + except Exception as e: + logger.warning(f"上传图片失败:{e}") + continue + + # 获得会话变量 + payload_vars = self.variables.copy() + # 动态变量 + session_var = await sp.session_get(session_id, "session_variables", default={}) + payload_vars.update(session_var) + payload_vars["system_prompt"] = system_prompt + + # 处理不同的 API 类型 + match self.api_type: + case "chat" | "agent" | "chatflow": + if not prompt: + prompt = "请描述这张图片。" + + async for chunk in self.api_client.chat_messages( + inputs={ + **payload_vars, + }, + query=prompt, + user=session_id, + conversation_id=conversation_id, + files=files_payload, + timeout=self.timeout, + ): + logger.debug(f"dify resp chunk: {chunk}") + if chunk["event"] == "message" or chunk["event"] == "agent_message": + result += chunk["answer"] + if not conversation_id: + self.conversation_ids[session_id] = chunk["conversation_id"] + conversation_id = chunk["conversation_id"] + + # 如果是流式响应,发送增量数据 + if self.streaming and chunk["answer"]: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message(chunk["answer"]) + ), + ) + elif chunk["event"] == "message_end": + logger.debug("Dify message end") + break + elif chunk["event"] == "error": + logger.error(f"Dify 出现错误:{chunk}") + raise Exception( + f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}" + ) + + case "workflow": + async for chunk in self.api_client.workflow_run( + inputs={ + self.dify_query_input_key: prompt, + "astrbot_session_id": session_id, + **payload_vars, + }, + user=session_id, + files=files_payload, + timeout=self.timeout, + ): + match chunk["event"]: + case "workflow_started": + logger.info( + f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。" + ) + case "node_finished": + logger.debug( + f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。" + ) + case "workflow_finished": + logger.info( + f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束" + ) + logger.debug(f"Dify 工作流结果:{chunk}") + if chunk["data"]["error"]: + logger.error( + f"Dify 工作流出现错误:{chunk['data']['error']}" + ) + raise Exception( + f"Dify 工作流出现错误:{chunk['data']['error']}" + ) + if self.workflow_output_key not in chunk["data"]["outputs"]: + raise Exception( + f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}" + ) + result = chunk + case _: + raise Exception(f"未知的 Dify API 类型:{self.api_type}") + + if not result: + logger.warning("Dify 请求结果为空,请查看 Debug 日志。") + + # 解析结果 + chain = await self.parse_dify_result(result) + + # 创建最终响应 + self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) + self._transition_state(AgentState.DONE) + + try: + await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + + # 返回最终结果 + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=chain), + ) + + async def parse_dify_result(self, chunk: dict | str) -> MessageChain: + """解析 Dify 的响应结果""" + if isinstance(chunk, str): + # Chat + return MessageChain(chain=[Comp.Plain(chunk)]) + + async def parse_file(item: dict): + match item["type"]: + case "image": + return Comp.Image(file=item["url"], url=item["url"]) + case "audio": + # 仅支持 wav + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, f"{item['filename']}.wav") + await download_file(item["url"], path) + return Comp.Image(file=item["url"], url=item["url"]) + case "video": + return Comp.Video(file=item["url"]) + case _: + return Comp.File(name=item["filename"], file=item["url"]) + + output = chunk["data"]["outputs"][self.workflow_output_key] + chains = [] + if isinstance(output, str): + # 纯文本输出 + chains.append(Comp.Plain(output)) + elif isinstance(output, list): + # 主要适配 Dify 的 HTTP 请求结点的多模态输出 + for item in output: + # handle Array[File] + if ( + not isinstance(item, dict) + or item.get("dify_model_identity", "") != "__dify__file__" + ): + chains.append(Comp.Plain(str(output))) + break + else: + chains.append(Comp.Plain(str(output))) + + # scan file + files = chunk["data"].get("files", []) + for item in files: + comp = await parse_file(item) + chains.append(comp) + + return MessageChain(chain=chains) + + @override + def done(self) -> bool: + """检查 Agent 是否已完成工作""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + @override + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp + + async def forget(self, session_id): + """忘记会话上下文""" + self.conversation_ids[session_id] = "" + return True + + async def terminate(self): + """终止并清理资源""" + if hasattr(self, "api_client"): + await self.api_client.close() diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 361f557fd..c4d2793ee 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -975,7 +975,7 @@ CONFIG_METADATA_2 = { "id": "dify_app_default", "provider": "dify", "type": "dify", - "provider_type": "chat_completion", + "provider_type": "agent_runner", "enable": True, "dify_api_type": "chat", "dify_api_key": "", @@ -989,7 +989,7 @@ CONFIG_METADATA_2 = { "Coze": { "id": "coze", "provider": "coze", - "provider_type": "chat_completion", + "provider_type": "agent_runner", "type": "coze", "enable": True, "coze_api_key": "", @@ -2099,7 +2099,7 @@ CONFIG_METADATA_3 = { "provider_settings.dify_runner_provider_id": { "description": "Dify 执行器提供商 ID", "type": "string", - "_special": "select_dify_runner_provider", + "_special": "select_provider_dify_runner", }, }, "condition": { @@ -2113,7 +2113,7 @@ CONFIG_METADATA_3 = { "provider_settings.coze_runner_provider_id": { "description": "Coze 执行器提供商 ID", "type": "string", - "_special": "select_coze_runner_provider", + "_special": "select_provider_coze_runner", }, }, "condition": { diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index b2245e4da..85eaf9ffa 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -44,7 +44,7 @@ except (ModuleNotFoundError, ImportError): AgentContextWrapper = ContextWrapper[AstrAgentContext] -AgentRunner = ToolLoopAgentRunner[AgentContextWrapper] +AgentRunner = ToolLoopAgentRunner[AstrAgentContext] class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): @@ -102,7 +102,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): request = ProviderRequest( prompt=input_, - system_prompt=tool.description, + system_prompt=tool.description or "", image_urls=[], # 暂时不传递原始 agent 的上下文 contexts=[], # 暂时不传递原始 agent 的上下文 func_tool=toolset, @@ -239,7 +239,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): yield res -class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]): +class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]): async def on_agent_done(self, run_context, llm_response): # 执行事件钩子 await call_event_hook( @@ -337,7 +337,7 @@ class LLMRequestSubStage(Stage): self.conv_manager = ctx.plugin_manager.context.conversation_manager - def _select_provider(self, event: AstrMessageEvent) -> Provider | None: + def _select_provider(self, event: AstrMessageEvent): """选择使用的 LLM 提供商""" sel_provider = event.get_extra("selected_provider") _ctx = self.ctx.plugin_manager.context @@ -382,6 +382,9 @@ class LLMRequestSubStage(Stage): provider = self._select_provider(event) if provider is None: return + if not isinstance(provider, Provider): + logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。") + return if event.get_extra("provider_request"): req = event.get_extra("provider_request") @@ -520,8 +523,10 @@ class LLMRequestSubStage(Stage): chain = ( MessageChain().message(final_llm_resp.completion_text).chain ) - else: + elif final_llm_resp.result_chain: chain = final_llm_resp.result_chain.chain + else: + chain = MessageChain().chain event.set_result( MessageEventResult( chain=chain, @@ -553,6 +558,8 @@ class LLMRequestSubStage(Stage): self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider ): """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" + if not req.conversation: + return conversation = await self.conv_manager.get_conversation( event.unified_msg_origin, req.conversation.cid ) diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 05169c4fe..e948ed5bc 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -179,7 +179,7 @@ class AstrMessageEvent(abc.ABC): def get_extra( self, key: str | None = None, default: _VT = None - ) -> dict[str, Any] | _VT: + ) -> Any: """ 获取额外的信息。 """ diff --git a/dashboard/src/components/provider/AddNewProvider.vue b/dashboard/src/components/provider/AddNewProvider.vue index b4cd1eb92..f876779fb 100644 --- a/dashboard/src/components/provider/AddNewProvider.vue +++ b/dashboard/src/components/provider/AddNewProvider.vue @@ -7,6 +7,10 @@ mdi-message-text {{ tm('dialogs.addProvider.tabs.basic') }} + + mdi-cogs + {{ tm('dialogs.addProvider.tabs.agentRunner') }} + mdi-microphone-message {{ tm('dialogs.addProvider.tabs.speechToText') }} @@ -27,7 +31,7 @@
- 接入 {{ name }} + {{ name }} {{ getProviderDescription(template, name) }} @@ -54,7 +58,7 @@ - {{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }} + {{ t('dialogs.addProvider.noTemplates') }} @@ -104,19 +108,6 @@ export default { this.$emit('update:show', value); } }, - - // 翻译消息的计算属性 - messages() { - return { - tabTypes: { - 'chat_completion': this.tm('providers.tabs.chatCompletion'), - 'speech_to_text': this.tm('providers.tabs.speechToText'), - 'text_to_speech': this.tm('providers.tabs.textToSpeech'), - 'embedding': this.tm('providers.tabs.embedding'), - 'rerank': this.tm('providers.tabs.rerank') - } - }; - } }, methods: { closeDialog() { @@ -140,11 +131,6 @@ export default { // 从工具函数导入 getProviderIcon, - // 获取Tab类型的中文名称 - getTabTypeName(tabType) { - return this.messages.tabTypes[tabType] || tabType; - }, - // 获取提供商简介 getProviderDescription(template, name) { return getProviderDescription(template, name, this.tm); diff --git a/dashboard/src/components/shared/AstrBotConfigV4.vue b/dashboard/src/components/shared/AstrBotConfigV4.vue index 08373595f..52f793b96 100644 --- a/dashboard/src/components/shared/AstrBotConfigV4.vue +++ b/dashboard/src/components/shared/AstrBotConfigV4.vue @@ -242,6 +242,12 @@ function getSelectItems(itemMeta) {
+
+ +
+
+ +
diff --git a/dashboard/src/i18n/locales/en-US/core/common.json b/dashboard/src/i18n/locales/en-US/core/common.json index 4aff41001..ff23b255a 100644 --- a/dashboard/src/i18n/locales/en-US/core/common.json +++ b/dashboard/src/i18n/locales/en-US/core/common.json @@ -73,6 +73,7 @@ "disabled": "Disabled", "delete": "Delete", "edit": "Edit", + "copy": "Copy", "noData": "No data available" } } \ No newline at end of file diff --git a/dashboard/src/i18n/locales/en-US/features/provider.json b/dashboard/src/i18n/locales/en-US/features/provider.json index 7888c9e0b..e08177d3d 100644 --- a/dashboard/src/i18n/locales/en-US/features/provider.json +++ b/dashboard/src/i18n/locales/en-US/features/provider.json @@ -9,6 +9,7 @@ "tabs": { "all": "All", "chatCompletion": "Chat Completion", + "agentRunner": "Agent Runner", "speechToText": "Speech to Text", "textToSpeech": "Text to Speech", "embedding": "Embedding", @@ -44,12 +45,13 @@ "title": "Service Provider", "tabs": { "basic": "Basic", + "agentRunner": "Agent Runner", "speechToText": "Speech to Text", "textToSpeech": "Text to Speech", "embedding": "Embedding", "rerank": "Rerank" }, - "noTemplates": "No {type} type provider templates available" + "noTemplates": "No this type provider templates available" }, "config": { "addTitle": "Add", diff --git a/dashboard/src/i18n/locales/zh-CN/features/provider.json b/dashboard/src/i18n/locales/zh-CN/features/provider.json index 913d74c30..234018829 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/provider.json +++ b/dashboard/src/i18n/locales/zh-CN/features/provider.json @@ -8,7 +8,8 @@ "providerType": "提供商类型", "tabs": { "all": "全部", - "chatCompletion": "基本对话", + "chatCompletion": "对话", + "agentRunner": "Agent 执行器", "speechToText": "语音转文字", "textToSpeech": "文字转语音", "embedding": "嵌入(Embedding)", @@ -44,13 +45,14 @@ "addProvider": { "title": "模型提供商", "tabs": { - "basic": "基本", + "basic": "对话", + "agentRunner": "Agent 执行器", "speechToText": "语音转文字", "textToSpeech": "文字转语音", "embedding": "嵌入(Embedding)", "rerank": "重排序(Rerank)" }, - "noTemplates": "暂无{type}类型的提供商模板" + "noTemplates": "暂无该类型的提供商模板" }, "config": { "addTitle": "新增", diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index 7264bd35e..811e294f2 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -29,6 +29,10 @@ mdi-message-text {{ tm('providers.tabs.chatCompletion') }} + + mdi-message-text + {{ tm('providers.tabs.agentRunner') }} + mdi-microphone-message {{ tm('providers.tabs.speechToText') }} @@ -312,8 +316,8 @@ export default { "anthropic_chat_completion": "chat_completion", "googlegenai_chat_completion": "chat_completion", "zhipu_chat_completion": "chat_completion", - "dify": "chat_completion", - "coze": "chat_completion", + "dify": "agent_runner", + "coze": "agent_runner", "dashscope": "chat_completion", "openai_whisper_api": "speech_to_text", "openai_whisper_selfhost": "speech_to_text", @@ -357,6 +361,7 @@ export default { }, tabTypes: { 'chat_completion': this.tm('providers.tabs.chatCompletion'), + 'agent_runner': this.tm('providers.tabs.agentRunner'), 'speech_to_text': this.tm('providers.tabs.speechToText'), 'text_to_speech': this.tm('providers.tabs.textToSpeech'), 'embedding': this.tm('providers.tabs.embedding'), From 61a68477d07dedca029c9b6962ff87ea0736251e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Tue, 21 Oct 2025 14:19:38 +0800 Subject: [PATCH 03/29] stage --- astrbot/core/config/default.py | 68 ++++++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 19 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index c4d2793ee..de0190405 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1804,7 +1804,6 @@ CONFIG_METADATA_2 = { "enable": { "description": "启用", "type": "bool", - "hint": "是否启用。", }, "key": { "description": "API Key", @@ -2081,14 +2080,22 @@ CONFIG_METADATA_3 = { "name": "AI 配置", "metadata": { "agent_runner": { - "description": "Agent", + "description": "Agent 执行方式", "type": "object", "items": { + "provider_settings.enable": { + "description": "启用", + "type": "bool", + "hint": "AI 对话总开关", + }, "provider_settings.agent_runner_type": { "description": "执行器", "type": "string", "options": ["local", "dify", "coze"], "labels": ["内置 Agent", "Dify", "Coze"], + "condition": { + "provider_settings.enable": True, + }, }, }, }, @@ -2104,6 +2111,7 @@ CONFIG_METADATA_3 = { }, "condition": { "provider_settings.agent_runner_type": "dify", + "provider_settings.enable": True, }, }, "coze_runner": { @@ -2118,32 +2126,32 @@ CONFIG_METADATA_3 = { }, "condition": { "provider_settings.agent_runner_type": "coze", + "provider_settings.enable": True, }, }, "ai": { "description": "模型", "type": "object", "items": { - "provider_settings.enable": { - "description": "启用大语言模型聊天", - "type": "bool", - }, "provider_settings.default_provider_id": { "description": "默认聊天模型", "type": "string", "_special": "select_provider", - "hint": "留空时使用第一个模型。", + "hint": "留空时使用第一个模型", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.default_image_caption_provider_id": { "description": "默认图片转述模型", "type": "string", "_special": "select_provider", - "hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。", + "hint": "留空代表不使用,可用于非多模态模型", }, "provider_stt_settings.enable": { "description": "启用语音转文本", "type": "bool", - "hint": "STT 总开关。", + "hint": "STT 总开关", }, "provider_stt_settings.provider_id": { "description": "默认语音转文本模型", @@ -2157,12 +2165,11 @@ CONFIG_METADATA_3 = { "provider_tts_settings.enable": { "description": "启用文本转语音", "type": "bool", - "hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。", + "hint": "TTS 总开关", }, "provider_tts_settings.provider_id": { "description": "默认文本转语音模型", "type": "string", - "hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。", "_special": "select_provider_tts", "condition": { "provider_tts_settings.enable": True, @@ -2174,7 +2181,7 @@ CONFIG_METADATA_3 = { }, }, "condition": { - "provider_settings.agent_runner_type": "local", + "provider_settings.enable": True, }, }, "persona": { @@ -2189,6 +2196,7 @@ CONFIG_METADATA_3 = { }, "condition": { "provider_settings.agent_runner_type": "local", + "provider_settings.enable": True, }, }, "knowledgebase": { @@ -2203,6 +2211,7 @@ CONFIG_METADATA_3 = { }, "condition": { "provider_settings.agent_runner_type": "local", + "provider_settings.enable": True, }, }, "websearch": { @@ -2242,6 +2251,7 @@ CONFIG_METADATA_3 = { }, "condition": { "provider_settings.agent_runner_type": "local", + "provider_settings.enable": True, }, }, "others": { @@ -2255,50 +2265,70 @@ CONFIG_METADATA_3 = { "provider_settings.identifier": { "description": "用户识别", "type": "bool", + "hint": "启用后,会在提示词前包含用户 ID 信息。", }, "provider_settings.group_name_display": { "description": "显示群名称", "type": "bool", - "hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。", + "hint": "启用后,在支持的平台(OneBot v11)上会在提示词前包含群名称信息。", }, "provider_settings.datetime_system_prompt": { "description": "现实世界时间感知", "type": "bool", + "hint": "启用后,会在系统提示词中附带当前时间信息。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.show_tool_use_status": { "description": "输出函数调用状态", "type": "bool", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.max_agent_step": { "description": "工具调用轮数上限", "type": "int", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.tool_call_timeout": { "description": "工具调用超时时间(秒)", "type": "int", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.streaming_response": { - "description": "流式回复", + "description": "流式输出", "type": "bool", }, "provider_settings.streaming_segmented": { - "description": "不支持流式回复的平台采取分段输出", + "description": "不支持流式输出的平台采取分段输出", "type": "bool", }, "provider_settings.max_context_length": { "description": "最多携带对话轮数", "type": "int", - "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条。-1 为不限制。", + "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.dequeue_context_length": { "description": "丢弃对话轮数", "type": "int", - "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数。", + "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.wake_prefix": { "description": "LLM 聊天额外唤醒前缀 ", "type": "string", - "hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。", + "hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求", }, "provider_settings.prompt_prefix": { "description": "用户提示词", @@ -2311,7 +2341,7 @@ CONFIG_METADATA_3 = { }, }, "condition": { - "provider_settings.agent_runner_type": "local", + "provider_settings.enable": True, }, }, }, From 766d6f2becd7810dc98c05ce9f3f034fa609c008 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Nov 2025 19:59:21 +0800 Subject: [PATCH 04/29] fix(conversation): update session configuration retrieval to use unified message origin --- packages/astrbot/commands/conversation.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/packages/astrbot/commands/conversation.py b/packages/astrbot/commands/conversation.py index 9538d8f53..67402c660 100644 --- a/packages/astrbot/commands/conversation.py +++ b/packages/astrbot/commands/conversation.py @@ -38,9 +38,8 @@ class ConversationCommands: async def reset(self, message: AstrMessageEvent): """重置 LLM 会话""" - is_unique_session = self.context.get_config()["platform_settings"][ - "unique_session" - ] + cfg = self.context.get_config(umo=message.unified_msg_origin) + is_unique_session = cfg["platform_settings"]["unique_session"] is_group = bool(message.get_group_id()) scene = RstScene.get_scene(is_group, is_unique_session) @@ -227,9 +226,8 @@ class ConversationCommands: else: ret += "\n当前对话: 无" - unique_session = self.context.get_config()["platform_settings"][ - "unique_session" - ] + cfg = self.context.get_config(umo=message.unified_msg_origin) + unique_session = cfg["platform_settings"]["unique_session"] if unique_session: ret += "\n会话隔离粒度: 个人" else: @@ -399,9 +397,8 @@ class ConversationCommands: async def del_conv(self, message: AstrMessageEvent): """删除当前对话""" - is_unique_session = self.context.get_config()["platform_settings"][ - "unique_session" - ] + cfg = self.context.get_config(umo=message.unified_msg_origin) + is_unique_session = cfg["platform_settings"]["unique_session"] if message.get_group_id() and not is_unique_session and message.role != "admin": # 群聊,没开独立会话,发送人不是管理员 message.set_result( From 910ec6c6953b35983f2fe4c4b32c9e4e02fee6bc Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Nov 2025 20:18:06 +0800 Subject: [PATCH 05/29] feat: implement third party agent sub stage and refactor provider management - Added `ThirdPartyAgentSubStage` to handle interactions with third-party agent runners (Dify, Coze, Dashscope). - Refactored `star_request.py` to ensure consistent return types in the `process` method. - Updated `stage.py` to initialize and utilize the new `AgentRequestSubStage`. - Modified `ProviderManager` to skip loading agent runner providers. - Removed `Dify` source implementation as it is now handled by the new agent runner structure. - Enhanced `DifyAPIClient` to support file uploads via both file path and file data. - Cleaned up shared preferences handling to simplify session preference retrieval. - Updated dashboard configuration to reflect changes in agent runner provider selection. - Refactored conversation commands to accommodate the new agent runner structure and remove direct dependencies on Dify. - Adjusted main application logic to ensure compatibility with the new conversation management approach. --- astrbot/core/agent/runners/base.py | 11 +- .../core/agent/runners/coze_agent_runner.py | 367 ++++++++++++++++++ .../agent/runners/dashscope_agent_runner.py | 273 +++++++++++++ .../core/agent/runners/dify_agent_runner.py | 118 +++--- .../agent/runners/tool_loop_agent_runner.py | 6 - astrbot/core/config/default.py | 67 ++-- .../process_stage/method/agent_request.py | 48 +++ .../internal.py} | 47 +-- .../method/agent_sub_stages/third_party.py | 126 ++++++ .../process_stage/method/star_request.py | 2 +- astrbot/core/pipeline/process_stage/stage.py | 13 +- astrbot/core/provider/manager.py | 10 +- astrbot/core/provider/sources/dify_source.py | 285 -------------- astrbot/core/utils/dify_api_client.py | 22 +- astrbot/core/utils/shared_preferences.py | 29 +- .../src/components/shared/AstrBotConfigV4.vue | 7 +- packages/astrbot/commands/conversation.py | 159 +++----- packages/astrbot/main.py | 38 +- 18 files changed, 1012 insertions(+), 616 deletions(-) create mode 100644 astrbot/core/agent/runners/coze_agent_runner.py create mode 100644 astrbot/core/agent/runners/dashscope_agent_runner.py create mode 100644 astrbot/core/pipeline/process_stage/method/agent_request.py rename astrbot/core/pipeline/process_stage/method/{llm_request.py => agent_sub_stages/internal.py} (91%) create mode 100644 astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py delete mode 100644 astrbot/core/provider/sources/dify_source.py diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index f7e0913b4..c53bdc0dc 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -2,13 +2,12 @@ import abc import typing as T from enum import Enum, auto -from astrbot.core.provider import Provider +from astrbot import logger from astrbot.core.provider.entities import LLMResponse from ..hooks import BaseAgentRunHooks from ..response import AgentResponse from ..run_context import ContextWrapper, TContext -from ..tool_executor import BaseFunctionToolExecutor class AgentState(Enum): @@ -24,9 +23,7 @@ class BaseAgentRunner(T.Generic[TContext]): @abc.abstractmethod async def reset( self, - provider: Provider, run_context: ContextWrapper[TContext], - tool_executor: BaseFunctionToolExecutor[TContext], agent_hooks: BaseAgentRunHooks[TContext], **kwargs: T.Any, ) -> None: @@ -60,3 +57,9 @@ class BaseAgentRunner(T.Generic[TContext]): This method should be called after the agent is done. """ ... + + def _transition_state(self, new_state: AgentState) -> None: + """Transition the agent state.""" + if self._state != new_state: + logger.debug(f"Dify Agent state transition: {self._state} -> {new_state}") + self._state = new_state diff --git a/astrbot/core/agent/runners/coze_agent_runner.py b/astrbot/core/agent/runners/coze_agent_runner.py new file mode 100644 index 000000000..885eef699 --- /dev/null +++ b/astrbot/core/agent/runners/coze_agent_runner.py @@ -0,0 +1,367 @@ +import base64 +import json +import sys +import typing as T + +import astrbot.core.message.components as Comp +from astrbot import logger +from astrbot.core import sp +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import ( + LLMResponse, + ProviderRequest, +) +from astrbot.core.provider.sources.coze_api_client import CozeAPIClient + +from ..hooks import BaseAgentRunHooks +from ..response import AgentResponseData +from ..run_context import ContextWrapper, TContext +from .base import AgentResponse, AgentState, BaseAgentRunner + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class CozeAgentRunner(BaseAgentRunner[TContext]): + """Coze Agent Runner""" + + @override + async def reset( + self, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.final_llm_resp = None + self._state = AgentState.IDLE + self.agent_hooks = agent_hooks + self.run_context = run_context + + self.api_key = provider_config.get("coze_api_key", "") + if not self.api_key: + raise Exception("Coze API Key 不能为空。") + self.bot_id = provider_config.get("bot_id", "") + if not self.bot_id: + raise Exception("Coze Bot ID 不能为空。") + self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn") + + if not isinstance(self.api_base, str) or not self.api_base.startswith( + ("http://", "https://"), + ): + raise Exception( + "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。", + ) + + self.timeout = provider_config.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + self.auto_save_history = provider_config.get("auto_save_history", True) + + # 创建 API 客户端 + self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base) + + # 会话相关缓存 + self.file_id_cache: dict[str, dict[str, str]] = {} + + @override + async def step(self): + """ + 执行 Coze Agent 的一个步骤 + """ + if not self.req: + raise ValueError("Request is not set. Please call reset() first.") + + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + + # 开始处理,转换到运行状态 + self._transition_state(AgentState.RUNNING) + + try: + # 执行 Coze 请求并处理结果 + async for response in self._execute_coze_request(): + yield response + except Exception as e: + logger.error(f"Coze 请求失败:{str(e)}") + self._transition_state(AgentState.ERROR) + self.final_llm_resp = LLMResponse( + role="err", completion_text=f"Coze 请求失败:{str(e)}" + ) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message(f"Coze 请求失败:{str(e)}") + ), + ) + finally: + await self.api_client.close() + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + while not self.done(): + async for resp in self.step(): + yield resp + + async def _execute_coze_request(self): + """执行 Coze 请求的核心逻辑""" + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + contexts = self.req.contexts or [] + system_prompt = self.req.system_prompt + + # 用户ID参数 + user_id = session_id + + # 获取或创建会话ID + conversation_id = await sp.get_async( + scope="umo", + scope_id=user_id, + key="coze_conversation_id", + default="", + ) + + # 构建消息 + additional_messages = [] + + if system_prompt: + if not self.auto_save_history or not conversation_id: + additional_messages.append( + { + "role": "system", + "content": system_prompt, + "content_type": "text", + }, + ) + + # 处理历史上下文 + if not self.auto_save_history and contexts: + for ctx in contexts: + if isinstance(ctx, dict) and "role" in ctx and "content" in ctx: + # 处理上下文中的图片 + content = ctx["content"] + if isinstance(content, list): + # 多模态内容,需要处理图片 + processed_content = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + processed_content.append(item) + elif item.get("type") == "image_url": + # 处理图片上传 + try: + image_data = item.get("image_url", {}) + url = image_data.get("url", "") + if url: + file_id = ( + await self._download_and_upload_image( + url, session_id + ) + ) + processed_content.append( + { + "type": "file", + "file_id": file_id, + "file_url": url, + } + ) + except Exception as e: + logger.warning(f"处理上下文图片失败: {e}") + continue + + if processed_content: + additional_messages.append( + { + "role": ctx["role"], + "content": processed_content, + "content_type": "object_string", + } + ) + else: + # 纯文本内容 + additional_messages.append( + { + "role": ctx["role"], + "content": content, + "content_type": "text", + } + ) + + # 构建当前消息 + if prompt or image_urls: + if image_urls: + # 多模态 + object_string_content = [] + if prompt: + object_string_content.append({"type": "text", "text": prompt}) + + for url in image_urls: + # the url is a base64 string + try: + image_data = base64.b64decode(url) + file_id = await self.api_client.upload_file(image_data) + object_string_content.append( + { + "type": "image", + "file_id": file_id, + } + ) + except Exception as e: + logger.warning(f"处理图片失败 {url}: {e}") + continue + + if object_string_content: + content = json.dumps(object_string_content, ensure_ascii=False) + additional_messages.append( + { + "role": "user", + "content": content, + "content_type": "object_string", + } + ) + elif prompt: + # 纯文本 + additional_messages.append( + { + "role": "user", + "content": prompt, + "content_type": "text", + }, + ) + + # 执行 Coze API 请求 + accumulated_content = "" + message_started = False + + async for chunk in self.api_client.chat_messages( + bot_id=self.bot_id, + user_id=user_id, + additional_messages=additional_messages, + conversation_id=conversation_id, + auto_save_history=self.auto_save_history, + stream=True, + timeout=self.timeout, + ): + event_type = chunk.get("event") + data = chunk.get("data", {}) + + if event_type == "conversation.chat.created": + if isinstance(data, dict) and "conversation_id" in data: + await sp.put_async( + scope="umo", + scope_id=user_id, + key="coze_conversation_id", + value=data["conversation_id"], + ) + + if event_type == "conversation.message.delta": + # 增量消息 + content = data.get("content", "") + if not content and "delta" in data: + content = data["delta"].get("content", "") + if not content and "text" in data: + content = data.get("text", "") + + if content: + accumulated_content += content + message_started = True + + # 如果是流式响应,发送增量数据 + if self.streaming: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message(content) + ), + ) + + elif event_type == "conversation.message.completed": + # 消息完成 + logger.debug("Coze message completed") + message_started = True + + elif event_type == "conversation.chat.completed": + # 对话完成 + logger.debug("Coze chat completed") + break + + elif event_type == "error": + # 错误处理 + error_msg = data.get("msg", "未知错误") + error_code = data.get("code", "UNKNOWN") + logger.error(f"Coze 出现错误: {error_code} - {error_msg}") + raise Exception(f"Coze 出现错误: {error_code} - {error_msg}") + + if not message_started and not accumulated_content: + logger.warning("Coze 未返回任何内容") + accumulated_content = "" + + # 创建最终响应 + chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) + self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) + self._transition_state(AgentState.DONE) + + try: + await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + + # 返回最终结果 + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=chain), + ) + + async def _download_and_upload_image( + self, + image_url: str, + session_id: str | None = None, + ) -> str: + """下载图片并上传到 Coze,返回 file_id""" + import hashlib + + # 计算哈希实现缓存 + cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest() + + if session_id: + if session_id not in self.file_id_cache: + self.file_id_cache[session_id] = {} + + if cache_key in self.file_id_cache[session_id]: + file_id = self.file_id_cache[session_id][cache_key] + logger.debug(f"[Coze] 使用缓存的 file_id: {file_id}") + return file_id + + try: + image_data = await self.api_client.download_image(image_url) + file_id = await self.api_client.upload_file(image_data) + + if session_id: + self.file_id_cache[session_id][cache_key] = file_id + logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}") + + return file_id + + except Exception as e: + logger.error(f"处理图片失败 {image_url}: {e!s}") + raise Exception(f"处理图片失败: {e!s}") + + @override + def done(self) -> bool: + """检查 Agent 是否已完成工作""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + @override + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp diff --git a/astrbot/core/agent/runners/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope_agent_runner.py new file mode 100644 index 000000000..6a7cada5f --- /dev/null +++ b/astrbot/core/agent/runners/dashscope_agent_runner.py @@ -0,0 +1,273 @@ +import asyncio +import functools +import re +import sys +import typing as T + +from dashscope import Application +from dashscope.app.application_response import ApplicationResponse + +import astrbot.core.message.components as Comp +from astrbot.core import logger, sp +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import ( + LLMResponse, + ProviderRequest, +) + +from ..hooks import BaseAgentRunHooks +from ..response import AgentResponseData +from ..run_context import ContextWrapper, TContext +from .base import AgentResponse, AgentState, BaseAgentRunner + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class DashscopeAgentRunner(BaseAgentRunner[TContext]): + """Dashscope Agent Runner""" + + @override + async def reset( + self, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.final_llm_resp = None + self._state = AgentState.IDLE + self.agent_hooks = agent_hooks + self.run_context = run_context + + self.api_key = provider_config.get("dashscope_api_key", "") + if not self.api_key: + raise Exception("阿里云百炼 API Key 不能为空。") + self.app_id = provider_config.get("dashscope_app_id", "") + if not self.app_id: + raise Exception("阿里云百炼 APP ID 不能为空。") + self.dashscope_app_type = provider_config.get("dashscope_app_type", "") + if not self.dashscope_app_type: + raise Exception("阿里云百炼 APP 类型不能为空。") + + self.variables: dict = provider_config.get("variables", {}) or {} + self.rag_options: dict = provider_config.get("rag_options", {}) + self.output_reference = self.rag_options.get("output_reference", False) + self.rag_options = self.rag_options.copy() + self.rag_options.pop("output_reference", None) + + self.timeout = provider_config.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + + def has_rag_options(self): + """判断是否有 RAG 选项 + + Returns: + bool: 是否有 RAG 选项 + + """ + if self.rag_options and ( + len(self.rag_options.get("pipeline_ids", [])) > 0 + or len(self.rag_options.get("file_ids", [])) > 0 + ): + return True + return False + + @override + async def step(self): + """ + 执行 Dashscope Agent 的一个步骤 + """ + if not self.req: + raise ValueError("Request is not set. Please call reset() first.") + + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + + # 开始处理,转换到运行状态 + self._transition_state(AgentState.RUNNING) + + try: + # 执行 Dashscope 请求并处理结果 + async for response in self._execute_dashscope_request(): + yield response + except Exception as e: + logger.error(f"阿里云百炼请求失败:{str(e)}") + self._transition_state(AgentState.ERROR) + self.final_llm_resp = LLMResponse( + role="err", completion_text=f"阿里云百炼请求失败:{str(e)}" + ) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}") + ), + ) + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + while not self.done(): + async for resp in self.step(): + yield resp + + async def _remove_image_from_context(self, contexts: list) -> list: + """移除上下文中的图片内容""" + result = [] + for ctx in contexts: + if isinstance(ctx, dict): + content = ctx.get("content", "") + if isinstance(content, list): + # 只保留文本内容 + text_parts = [ + item.get("text", "") + for item in content + if isinstance(item, dict) and item.get("type") == "text" + ] + if text_parts: + new_ctx = ctx.copy() + new_ctx["content"] = " ".join(text_parts) + result.append(new_ctx) + else: + result.append(ctx) + else: + result.append(ctx) + return result + + async def _execute_dashscope_request(self): + """执行 Dashscope 请求的核心逻辑""" + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + contexts = self.req.contexts or [] + system_prompt = self.req.system_prompt + + # 获得会话变量 + payload_vars = self.variables.copy() + # 动态变量 + session_var = await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_variables", + default={}, + ) + payload_vars.update(session_var) + + if ( + self.dashscope_app_type in ["agent", "dialog-workflow"] + and not self.has_rag_options() + ): + # 支持多轮对话的 + new_record = {"role": "user", "content": prompt} + if image_urls: + logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。") + contexts_no_img = await self._remove_image_from_context(contexts) + context_query = [*contexts_no_img, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) + for part in context_query: + if "_no_save" in part: + del part["_no_save"] + # 调用阿里云百炼 API + payload = { + "app_id": self.app_id, + "api_key": self.api_key, + "messages": context_query, + "biz_params": payload_vars or None, + } + partial = functools.partial( + Application.call, + **payload, + ) + response = await asyncio.get_event_loop().run_in_executor(None, partial) + else: + # 不支持多轮对话的 + # 调用阿里云百炼 API + payload = { + "app_id": self.app_id, + "prompt": prompt, + "api_key": self.api_key, + "biz_params": payload_vars or None, + } + if self.rag_options: + payload["rag_options"] = self.rag_options + partial = functools.partial( + Application.call, + **payload, + ) + response = await asyncio.get_event_loop().run_in_executor(None, partial) + + assert isinstance(response, ApplicationResponse) + + logger.debug(f"dashscope resp: {response}") + + if response.status_code != 200: + logger.error( + f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", + ) + self._transition_state(AgentState.ERROR) + self.final_llm_resp = LLMResponse( + role="err", + result_chain=MessageChain().message( + f"阿里云百炼请求失败: message={response.message} code={response.status_code}", + ), + ) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message( + f"阿里云百炼请求失败: message={response.message} code={response.status_code}" + ) + ), + ) + return + + output_text = response.output.get("text", "") or "" + # RAG 引用脚标格式化 + output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text) + if self.output_reference and response.output.get("doc_references", None): + ref_parts = [] + for ref in response.output.get("doc_references", []) or []: + ref_title = ( + ref.get("title", "") + if ref.get("title") + else ref.get("doc_name", "") + ) + ref_parts.append(f"{ref['index_id']}. {ref_title}\n") + ref_str = "".join(ref_parts) + output_text += f"\n\n回答来源:\n{ref_str}" + + # 创建最终响应 + chain = MessageChain(chain=[Comp.Plain(output_text)]) + self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) + self._transition_state(AgentState.DONE) + + try: + await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + + # 返回最终结果 + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=chain), + ) + + @override + def done(self) -> bool: + """检查 Agent 是否已完成工作""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + @override + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp diff --git a/astrbot/core/agent/runners/dify_agent_runner.py b/astrbot/core/agent/runners/dify_agent_runner.py index 145e4a627..1433cfdf4 100644 --- a/astrbot/core/agent/runners/dify_agent_runner.py +++ b/astrbot/core/agent/runners/dify_agent_runner.py @@ -1,22 +1,23 @@ -import sys +import base64 import os +import sys import typing as T -from .base import BaseAgentRunner, AgentResponse, AgentState -from ..hooks import BaseAgentRunHooks -from ..tool_executor import BaseFunctionToolExecutor -from ..run_context import ContextWrapper, TContext -from ..response import AgentResponseData -from astrbot.core.provider.provider import Provider + +import astrbot.core.message.components as Comp +from astrbot.core import logger, sp from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import ( - ProviderRequest, LLMResponse, + ProviderRequest, ) -from astrbot.core.utils.dify_api_client import DifyAPIClient -from astrbot.core.utils.io import download_image_by_url, download_file from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core import logger, sp -import astrbot.core.message.components as Comp +from astrbot.core.utils.dify_api_client import DifyAPIClient +from astrbot.core.utils.io import download_file + +from ..hooks import BaseAgentRunHooks +from ..response import AgentResponseData +from ..run_context import ContextWrapper, TContext +from .base import AgentResponse, AgentState, BaseAgentRunner if sys.version_info >= (3, 12): from typing import override @@ -30,53 +31,36 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): @override async def reset( self, - provider: Provider, request: ProviderRequest, run_context: ContextWrapper[TContext], - tool_executor: BaseFunctionToolExecutor[TContext], agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, **kwargs: T.Any, ) -> None: self.req = request self.streaming = kwargs.get("streaming", False) - self.provider = provider self.final_llm_resp = None self._state = AgentState.IDLE - self.tool_executor = tool_executor self.agent_hooks = agent_hooks self.run_context = run_context - # Dify 特定配置 - 从 provider 或 kwargs 中获取 - self.api_key = kwargs.get("dify_api_key", "") - api_base = kwargs.get("dify_api_base", "https://api.dify.ai/v1") - self.api_type = kwargs.get("dify_api_type", "") - - self.workflow_output_key = kwargs.get( - "dify_workflow_output_key", "astrbot_wf_output" + self.api_key = provider_config.get("dify_api_key", "") + self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") + self.api_type = provider_config.get("dify_api_type", "chat") + self.workflow_output_key = provider_config.get( + "dify_workflow_output_key", + "astrbot_wf_output", ) - self.dify_query_input_key = kwargs.get( - "dify_query_input_key", "astrbot_text_query" + self.dify_query_input_key = provider_config.get( + "dify_query_input_key", + "astrbot_text_query", ) - if not self.dify_query_input_key: - self.dify_query_input_key = "astrbot_text_query" - if not self.workflow_output_key: - self.workflow_output_key = "astrbot_wf_output" - - self.variables: dict = kwargs.get("variables", {}) - self.timeout = kwargs.get("timeout", 120) + self.variables: dict = provider_config.get("variables", {}) or {} + self.timeout = provider_config.get("timeout", 60) if isinstance(self.timeout, str): self.timeout = int(self.timeout) - self.conversation_ids = {} - """记录当前 session id 的对话 ID""" - - self.api_client = DifyAPIClient(self.api_key, api_base) - - def _transition_state(self, new_state: AgentState) -> None: - """转换 Agent 状态""" - if self._state != new_state: - logger.debug(f"Dify Agent state transition: {self._state} -> {new_state}") - self._state = new_state + self.api_client = DifyAPIClient(self.api_key, self.api_base) @override async def step(self): @@ -111,6 +95,16 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): chain=MessageChain().message(f"Dify 请求失败:{str(e)}") ), ) + finally: + await self.api_client.close() + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + while not self.done(): + async for resp in self.step(): + yield resp async def _execute_dify_request(self): """执行 Dify 请求的核心逻辑""" @@ -119,20 +113,22 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): image_urls = self.req.image_urls or [] system_prompt = self.req.system_prompt - conversation_id = self.conversation_ids.get(session_id, "") + conversation_id = await sp.get_async( + scope="umo", + scope_id=session_id, + key="dify_conversation_id", + default="", + ) result = "" # 处理图片上传 files_payload = [] for image_url in image_urls: + # image_url is a base64 string try: - image_path = ( - await download_image_by_url(image_url) - if image_url.startswith("http") - else image_url - ) + image_data = base64.b64decode(image_url) file_response = await self.api_client.file_upload( - image_path, user=session_id + file_data=image_data, user=session_id ) logger.debug(f"Dify 上传图片响应:{file_response}") if "id" not in file_response: @@ -154,7 +150,12 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): # 获得会话变量 payload_vars = self.variables.copy() # 动态变量 - session_var = await sp.session_get(session_id, "session_variables", default={}) + session_var = await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_variables", + default={}, + ) payload_vars.update(session_var) payload_vars["system_prompt"] = system_prompt @@ -178,7 +179,12 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): if chunk["event"] == "message" or chunk["event"] == "agent_message": result += chunk["answer"] if not conversation_id: - self.conversation_ids[session_id] = chunk["conversation_id"] + await sp.put_async( + scope="umo", + scope_id=session_id, + key="dify_conversation_id", + value=chunk["conversation_id"], + ) conversation_id = chunk["conversation_id"] # 如果是流式响应,发送增量数据 @@ -314,13 +320,3 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): @override def get_final_llm_resp(self) -> LLMResponse | None: return self.final_llm_resp - - async def forget(self, session_id): - """忘记会话上下文""" - self.conversation_ids[session_id] = "" - return True - - async def terminate(self): - """终止并清理资源""" - if hasattr(self, "api_client"): - await self.api_client.close() diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index d74a45982..6f3c813eb 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -69,12 +69,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): ) self.run_context.messages = messages - def _transition_state(self, new_state: AgentState) -> None: - """转换 Agent 状态""" - if self._state != new_state: - logger.debug(f"Agent state transition: {self._state} -> {new_state}") - self._state = new_state - async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: """Yields chunks *and* a final LLMResponse.""" if self.streaming: diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index cf3658670..4e1101010 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -69,8 +69,9 @@ DEFAULT_CONFIG = { "streaming_response": False, "show_tool_use_status": False, "agent_runner_type": "local", - "dify_runner_provider_id": "", - "coze_runner_provider_id": "", + "dify_agent_runner_provider_id": "", + "coze_agent_runner_provider_id": "", + "dashscope_agent_runner_provider_id": "", "unsupported_streaming_strategy": "realtime_segmenting", "max_agent_step": 30, "tool_call_timeout": 60, @@ -1041,7 +1042,7 @@ CONFIG_METADATA_2 = { "id": "dashscope", "provider": "dashscope", "type": "dashscope", - "provider_type": "chat_completion", + "provider_type": "agent_runner", "enable": True, "dashscope_app_type": "agent", "dashscope_api_key": "", @@ -2042,10 +2043,13 @@ CONFIG_METADATA_2 = { "agent_runner_type": { "type": "string", }, - "dify_runner_provider_id": { + "dify_agent_runner_provider_id": { "type": "string", }, - "coze_runner_provider_id": { + "coze_agent_runner_provider_id": { + "type": "string", + }, + "dashscope_agent_runner_provider_id": { "type": "string", }, "max_agent_step": { @@ -2201,42 +2205,36 @@ CONFIG_METADATA_3 = { "provider_settings.agent_runner_type": { "description": "执行器", "type": "string", - "options": ["local", "dify", "coze"], - "labels": ["内置 Agent", "Dify", "Coze"], + "options": ["local", "dify", "coze", "dashscope"], + "labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"], "condition": { "provider_settings.enable": True, }, }, - }, - }, - "dify_runner": { - "description": "Dify", - "type": "object", - "items": { - "provider_settings.dify_runner_provider_id": { - "description": "Dify 执行器提供商 ID", + "provider_settings.coze_agent_runner_provider_id": { + "description": "Coze Agent 执行器提供商 ID", "type": "string", - "_special": "select_provider_dify_runner", + "_special": "select_agent_runner_provider", + "condition": { + "provider_settings.agent_runner_type": "coze", + }, }, - }, - "condition": { - "provider_settings.agent_runner_type": "dify", - "provider_settings.enable": True, - }, - }, - "coze_runner": { - "description": "Coze", - "type": "object", - "items": { - "provider_settings.coze_runner_provider_id": { - "description": "Coze 执行器提供商 ID", + "provider_settings.dify_agent_runner_provider_id": { + "description": "Dify Agent 执行器提供商 ID", "type": "string", - "_special": "select_provider_coze_runner", + "_special": "select_agent_runner_provider", + "condition": { + "provider_settings.agent_runner_type": "dify", + }, + }, + "provider_settings.dashscope_agent_runner_provider_id": { + "description": "阿里云百炼应用 Agent 执行器提供商 ID", + "type": "string", + "_special": "select_agent_runner_provider", + "condition": { + "provider_settings.agent_runner_type": "dashscope", + }, }, - }, - "condition": { - "provider_settings.agent_runner_type": "coze", - "provider_settings.enable": True, }, }, "ai": { @@ -2248,9 +2246,6 @@ CONFIG_METADATA_3 = { "type": "string", "_special": "select_provider", "hint": "留空时使用第一个模型", - "condition": { - "provider_settings.agent_runner_type": "local", - }, }, "provider_settings.default_image_caption_provider_id": { "description": "默认图片转述模型", diff --git a/astrbot/core/pipeline/process_stage/method/agent_request.py b/astrbot/core/pipeline/process_stage/method/agent_request.py new file mode 100644 index 000000000..f6f81631e --- /dev/null +++ b/astrbot/core/pipeline/process_stage/method/agent_request.py @@ -0,0 +1,48 @@ +from collections.abc import AsyncGenerator + +from astrbot.core import logger +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.star.session_llm_manager import SessionServiceManager + +from ...context import PipelineContext +from ..stage import Stage +from .agent_sub_stages.internal import InternalAgentSubStage +from .agent_sub_stages.third_party import ThirdPartyAgentSubStage + + +class AgentRequestSubStage(Stage): + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + self.config = ctx.astrbot_config + + self.bot_wake_prefixs: list[str] = self.config["wake_prefix"] + self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"] + for bwp in self.bot_wake_prefixs: + if self.prov_wake_prefix.startswith(bwp): + logger.info( + f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。", + ) + self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :] + + agent_runner_type = self.config["provider_settings"]["agent_runner_type"] + if agent_runner_type == "local": + self.agent_sub_stage = InternalAgentSubStage() + else: + self.agent_sub_stage = ThirdPartyAgentSubStage() + await self.agent_sub_stage.initialize(ctx) + + async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]: + if not self.ctx.astrbot_config["provider_settings"]["enable"]: + logger.debug( + "This pipeline does not enable AI capability, skip processing." + ) + return + + if not SessionServiceManager.should_process_llm_request(event): + logger.debug( + f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing." + ) + return + + async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix): + yield resp diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py similarity index 91% rename from astrbot/core/pipeline/process_stage/method/llm_request.py rename to astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index bd9e4ce3b..aaada9a19 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -21,27 +21,24 @@ from astrbot.core.provider.entities import ( LLMResponse, ProviderRequest, ) -from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager -from ....astr_agent_context import AgentContextWrapper -from ....astr_agent_hooks import MAIN_AGENT_HOOKS -from ....astr_agent_run_util import AgentRunner, run_agent -from ....astr_agent_tool_exec import FunctionToolExecutor -from ...context import PipelineContext, call_event_hook -from ..stage import Stage -from ..utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base +from .....astr_agent_context import AgentContextWrapper +from .....astr_agent_hooks import MAIN_AGENT_HOOKS +from .....astr_agent_run_util import AgentRunner, run_agent +from .....astr_agent_tool_exec import FunctionToolExecutor +from ....context import PipelineContext, call_event_hook +from ...stage import Stage +from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base -class LLMRequestSubStage(Stage): +class InternalAgentSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx conf = ctx.astrbot_config settings = conf["provider_settings"] - self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list - self.provider_wake_prefix: str = settings["wake_prefix"] # str self.max_context_length = settings["max_context_length"] # int self.dequeue_context_length: int = min( max(1, settings["dequeue_context_length"]), @@ -59,13 +56,6 @@ class LLMRequestSubStage(Stage): self.show_reasoning = settings.get("display_reasoning_text", False) self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False) - for bwp in self.bot_wake_prefixs: - if self.provider_wake_prefix.startswith(bwp): - logger.info( - f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。", - ) - self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :] - self.conv_manager = ctx.plugin_manager.context.conversation_manager def _select_provider(self, event: AstrMessageEvent): @@ -304,21 +294,10 @@ class LLMRequestSubStage(Stage): return fixed_messages async def process( - self, - event: AstrMessageEvent, - _nested: bool = False, - ) -> None | AsyncGenerator[None, None]: + self, event: AstrMessageEvent, provider_wake_prefix: str + ) -> AsyncGenerator[None, None]: req: ProviderRequest | None = None - if not self.ctx.astrbot_config["provider_settings"]["enable"]: - logger.debug("未启用 LLM 能力,跳过处理。") - return - - # 检查会话级别的LLM启停状态 - if not SessionServiceManager.should_process_llm_request(event): - logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。") - return - provider = self._select_provider(event) if provider is None: return @@ -348,12 +327,12 @@ class LLMRequestSubStage(Stage): req.image_urls = [] if sel_model := event.get_extra("selected_model"): req.model = sel_model - if self.provider_wake_prefix and not event.message_str.startswith( - self.provider_wake_prefix + if provider_wake_prefix and not event.message_str.startswith( + provider_wake_prefix ): return - req.prompt = event.message_str[len(self.provider_wake_prefix) :] + req.prompt = event.message_str[len(provider_wake_prefix) :] # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。 # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() for comp in event.message_obj.message: diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py new file mode 100644 index 000000000..d4d0709e7 --- /dev/null +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -0,0 +1,126 @@ +import asyncio +from collections.abc import AsyncGenerator + +from astrbot.core import logger +from astrbot.core.agent.runners.coze_agent_runner import CozeAgentRunner +from astrbot.core.agent.runners.dashscope_agent_runner import DashscopeAgentRunner +from astrbot.core.agent.runners.dify_agent_runner import DifyAgentRunner +from astrbot.core.message.components import Image +from astrbot.core.message.message_event_result import ( + MessageEventResult, + ResultContentType, +) +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider.entities import ( + ProviderRequest, +) +from astrbot.core.star.star_handler import EventType +from astrbot.core.utils.metrics import Metric + +from .....astr_agent_context import AgentContextWrapper, AstrAgentContext +from .....astr_agent_hooks import MAIN_AGENT_HOOKS +from ....context import PipelineContext, call_event_hook +from ...stage import Stage + +AGENT_RUNNER_TYPE_KEY = { + "dify": "dify_agent_runner_provider_id", + "coze": "coze_agent_runner_provider_id", + "dashscope": "dashscope_agent_runner_provider_id", +} + + +class ThirdPartyAgentSubStage(Stage): + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + self.conf = ctx.astrbot_config + self.runner_type = self.conf["provider_settings"]["agent_runner_type"] + self.prov_id = self.conf["provider_settings"].get( + AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""), + "", + ) + self.prov_cfg: dict = next( + (p for p in self.conf["provider"] if p["id"] == self.prov_id), + {}, + ) + + async def process( + self, event: AstrMessageEvent, provider_wake_prefix: str + ) -> AsyncGenerator[None, None]: + req: ProviderRequest | None = None + + if provider_wake_prefix and not event.message_str.startswith( + provider_wake_prefix + ): + return + if not self.prov_id or not self.prov_cfg: + logger.error( + "Third Party Agent Runner provider ID is not configured properly." + ) + return + + # make provider request + req = ProviderRequest() + req.session_id = event.unified_msg_origin + req.prompt = event.message_str[len(provider_wake_prefix) :] + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_path = await comp.convert_to_base64() + req.image_urls.append(image_path) + + if not req.prompt and not req.image_urls: + return + + # call event hook + if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + return + + if self.runner_type == "dify": + runner = DifyAgentRunner[AstrAgentContext]() + elif self.runner_type == "coze": + runner = CozeAgentRunner[AstrAgentContext]() + elif self.runner_type == "dashscope": + runner = DashscopeAgentRunner[AstrAgentContext]() + else: + raise ValueError( + f"Unsupported third party agent runner type: {self.runner_type}", + ) + + astr_agent_ctx = AstrAgentContext( + context=self.ctx.plugin_manager.context, + event=event, + ) + + await runner.reset( + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=60, + ), + agent_hooks=MAIN_AGENT_HOOKS, + provider_config=self.prov_cfg, + ) + + async for _ in runner.step_until_done(): + pass + + final_resp = runner.get_final_llm_resp() + + if not final_resp or not final_resp.result_chain: + logger.warning("Agent Runner 未返回最终结果。") + return + + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=ResultContentType.LLM_RESULT, + ), + ) + yield + + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=self.runner_type, + provider_type=self.runner_type, + ), + ) diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index ff8120b16..56d305de4 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -24,7 +24,7 @@ class StarRequestSubStage(Stage): async def process( self, event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: + ) -> AsyncGenerator[None, None]: activated_handlers: list[StarHandlerMetadata] = event.get_extra( "activated_handlers", ) diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 9f0b5f92a..2eeefcf11 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -7,7 +7,7 @@ from astrbot.core.star.star_handler import StarHandlerMetadata from ..context import PipelineContext from ..stage import Stage, register_stage -from .method.llm_request import LLMRequestSubStage +from .method.agent_request import AgentRequestSubStage from .method.star_request import StarRequestSubStage @@ -17,9 +17,12 @@ class ProcessStage(Stage): self.ctx = ctx self.config = ctx.astrbot_config self.plugin_manager = ctx.plugin_manager - self.llm_request_sub_stage = LLMRequestSubStage() - await self.llm_request_sub_stage.initialize(ctx) + # initialize agent sub stage + self.agent_sub_stage = AgentRequestSubStage() + await self.agent_sub_stage.initialize(ctx) + + # initialize star request sub stage self.star_request_sub_stage = StarRequestSubStage() await self.star_request_sub_stage.initialize(ctx) @@ -39,7 +42,7 @@ class ProcessStage(Stage): # Handler 的 LLM 请求 event.set_extra("provider_request", resp) _t = False - async for _ in self.llm_request_sub_stage.process(event): + async for _ in self.agent_sub_stage.process(event): _t = True yield if not _t: @@ -67,5 +70,5 @@ class ProcessStage(Stage): logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。") return - async for _ in self.llm_request_sub_stage.process(event): + async for _ in self.agent_sub_stage.process(event): yield diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index ec2550415..d665b38ab 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -227,6 +227,8 @@ class ProviderManager: async def load_provider(self, provider_config: dict): if not provider_config["enable"]: return + if provider_config["provider_type"] == "agent_runner": + return logger.info( f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...", @@ -247,14 +249,6 @@ class ProviderManager: from .sources.anthropic_source import ( ProviderAnthropic as ProviderAnthropic, ) - case "dify": - from .sources.dify_source import ProviderDify as ProviderDify - case "coze": - from .sources.coze_source import ProviderCoze as ProviderCoze - case "dashscope": - from .sources.dashscope_source import ( - ProviderDashscope as ProviderDashscope, - ) case "googlegenai_chat_completion": from .sources.gemini_source import ( ProviderGoogleGenAI as ProviderGoogleGenAI, diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py deleted file mode 100644 index 7850a982c..000000000 --- a/astrbot/core/provider/sources/dify_source.py +++ /dev/null @@ -1,285 +0,0 @@ -import os - -import astrbot.core.message.components as Comp -from astrbot.core import logger, sp -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.utils.dify_api_client import DifyAPIClient -from astrbot.core.utils.io import download_file, download_image_by_url - -from .. import Provider -from ..entities import LLMResponse -from ..register import register_provider_adapter - - -@register_provider_adapter("dify", "Dify APP 适配器。") -class ProviderDify(Provider): - def __init__( - self, - provider_config, - provider_settings, - ) -> None: - super().__init__( - provider_config, - provider_settings, - ) - self.api_key = provider_config.get("dify_api_key", "") - if not self.api_key: - raise Exception("Dify API Key 不能为空。") - api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") - self.api_type = provider_config.get("dify_api_type", "") - if not self.api_type: - raise Exception("Dify API 类型不能为空。") - self.model_name = "dify" - self.workflow_output_key = provider_config.get( - "dify_workflow_output_key", - "astrbot_wf_output", - ) - self.dify_query_input_key = provider_config.get( - "dify_query_input_key", - "astrbot_text_query", - ) - if not self.dify_query_input_key: - self.dify_query_input_key = "astrbot_text_query" - if not self.workflow_output_key: - self.workflow_output_key = "astrbot_wf_output" - self.variables: dict = provider_config.get("variables", {}) - self.timeout = provider_config.get("timeout", 120) - if isinstance(self.timeout, str): - self.timeout = int(self.timeout) - self.conversation_ids = {} - """记录当前 session id 的对话 ID""" - - self.api_client = DifyAPIClient(self.api_key, api_base) - - async def text_chat( - self, - prompt: str, - session_id=None, - image_urls=None, - func_tool=None, - contexts=None, - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ) -> LLMResponse: - if image_urls is None: - image_urls = [] - result = "" - session_id = session_id or kwargs.get("user") or "unknown" # 1734 - conversation_id = self.conversation_ids.get(session_id, "") - - files_payload = [] - for image_url in image_urls: - image_path = ( - await download_image_by_url(image_url) - if image_url.startswith("http") - else image_url - ) - file_response = await self.api_client.file_upload( - image_path, - user=session_id, - ) - logger.debug(f"Dify 上传图片响应:{file_response}") - if "id" not in file_response: - logger.warning( - f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。", - ) - continue - files_payload.append( - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response["id"], - }, - ) - - # 获得会话变量 - payload_vars = self.variables.copy() - # 动态变量 - session_var = await sp.session_get(session_id, "session_variables", default={}) - payload_vars.update(session_var) - payload_vars["system_prompt"] = system_prompt - - try: - match self.api_type: - case "chat" | "agent" | "chatflow": - if not prompt: - prompt = "请描述这张图片。" - - async for chunk in self.api_client.chat_messages( - inputs={ - **payload_vars, - }, - query=prompt, - user=session_id, - conversation_id=conversation_id, - files=files_payload, - timeout=self.timeout, - ): - logger.debug(f"dify resp chunk: {chunk}") - if ( - chunk["event"] == "message" - or chunk["event"] == "agent_message" - ): - result += chunk["answer"] - if not conversation_id: - self.conversation_ids[session_id] = chunk[ - "conversation_id" - ] - conversation_id = chunk["conversation_id"] - elif chunk["event"] == "message_end": - logger.debug("Dify message end") - break - elif chunk["event"] == "error": - logger.error(f"Dify 出现错误:{chunk}") - raise Exception( - f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}", - ) - - case "workflow": - async for chunk in self.api_client.workflow_run( - inputs={ - self.dify_query_input_key: prompt, - "astrbot_session_id": session_id, - **payload_vars, - }, - user=session_id, - files=files_payload, - timeout=self.timeout, - ): - match chunk["event"]: - case "workflow_started": - logger.info( - f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。", - ) - case "node_finished": - logger.debug( - f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。", - ) - case "workflow_finished": - logger.info( - f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束", - ) - logger.debug(f"Dify 工作流结果:{chunk}") - if chunk["data"]["error"]: - logger.error( - f"Dify 工作流出现错误:{chunk['data']['error']}", - ) - raise Exception( - f"Dify 工作流出现错误:{chunk['data']['error']}", - ) - if ( - self.workflow_output_key - not in chunk["data"]["outputs"] - ): - raise Exception( - f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}", - ) - result = chunk - case _: - raise Exception(f"未知的 Dify API 类型:{self.api_type}") - except Exception as e: - logger.error(f"Dify 请求失败:{e!s}") - return LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}") - - if not result: - logger.warning("Dify 请求结果为空,请查看 Debug 日志。") - - chain = await self.parse_dify_result(result) - - return LLMResponse(role="assistant", result_chain=chain) - - async def text_chat_stream( - self, - prompt, - session_id=None, - image_urls=..., - func_tool=None, - contexts=..., - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response - - async def parse_dify_result(self, chunk: dict | str) -> MessageChain: - if isinstance(chunk, str): - # Chat - return MessageChain(chain=[Comp.Plain(chunk)]) - - async def parse_file(item: dict): - match item["type"]: - case "image": - return Comp.Image(file=item["url"], url=item["url"]) - case "audio": - # 仅支持 wav - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - path = os.path.join(temp_dir, f"{item['filename']}.wav") - await download_file(item["url"], path) - return Comp.Image(file=item["url"], url=item["url"]) - case "video": - return Comp.Video(file=item["url"]) - case _: - return Comp.File(name=item["filename"], file=item["url"]) - - output = chunk["data"]["outputs"][self.workflow_output_key] - chains = [] - if isinstance(output, str): - # 纯文本输出 - chains.append(Comp.Plain(output)) - elif isinstance(output, list): - # 主要适配 Dify 的 HTTP 请求结点的多模态输出 - for item in output: - # handle Array[File] - if ( - not isinstance(item, dict) - or item.get("dify_model_identity", "") != "__dify__file__" - ): - chains.append(Comp.Plain(str(output))) - break - else: - chains.append(Comp.Plain(str(output))) - - # scan file - files = chunk["data"].get("files", []) - for item in files: - comp = await parse_file(item) - chains.append(comp) - - return MessageChain(chain=chains) - - async def forget(self, session_id): - self.conversation_ids[session_id] = "" - return True - - async def get_current_key(self): - return self.api_key - - async def set_key(self, key): - raise Exception("Dify 适配器不支持设置 API Key。") - - async def get_models(self): - return [self.get_model()] - - async def get_human_readable_context(self, session_id, page, page_size): - raise Exception("暂不支持获得 Dify 的历史消息记录。") - - async def terminate(self): - await self.api_client.close() diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index ea8ff9dff..20efdbbd0 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -101,14 +101,16 @@ class DifyAPIClient: async def file_upload( self, - file_path: str, user: str, + file_path: str | None = None, + file_data: bytes | None = None, ) -> dict[str, Any]: url = f"{self.api_base}/files/upload" - with open(file_path, "rb") as f: + + if file_data is not None: payload = { "user": user, - "file": f, + "file": file_data, } async with self.session.post( url, @@ -116,6 +118,20 @@ class DifyAPIClient: headers=self.headers, ) as resp: return await resp.json() # {"id": "xxx", ...} + elif file_path is not None: + with open(file_path, "rb") as f: + payload = { + "user": user, + "file": f, + } + async with self.session.post( + url, + data=payload, + headers=self.headers, + ) as resp: + return await resp.json() # {"id": "xxx", ...} + else: + raise ValueError("file_path 和 file_data 不能同时为 None") async def close(self): await self.session.close() diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index c6b4c5ede..6b1f52a69 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -40,9 +40,6 @@ class SharedPreferences: else: ret = default return ret - raise ValueError( - "scope_id and key cannot be None when getting a specific preference.", - ) async def range_get_async( self, @@ -56,30 +53,6 @@ class SharedPreferences: ret = await self.db_helper.get_preferences(scope, scope_id, key) return ret - @overload - async def session_get( - self, - umo: None, - key: str, - default: Any = None, - ) -> list[Preference]: ... - - @overload - async def session_get( - self, - umo: str, - key: None, - default: Any = None, - ) -> list[Preference]: ... - - @overload - async def session_get( - self, - umo: None, - key: None, - default: Any = None, - ) -> list[Preference]: ... - async def session_get( self, umo: str | None, @@ -88,7 +61,7 @@ class SharedPreferences: ) -> _VT | list[Preference]: """获取会话范围的偏好设置 - Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。 + Note: 当 umo 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。 """ if umo is None or key is None: return await self.range_get_async("umo", umo, key) diff --git a/dashboard/src/components/shared/AstrBotConfigV4.vue b/dashboard/src/components/shared/AstrBotConfigV4.vue index 77fe39ebc..57c7b36ee 100644 --- a/dashboard/src/components/shared/AstrBotConfigV4.vue +++ b/dashboard/src/components/shared/AstrBotConfigV4.vue @@ -230,11 +230,8 @@ function hasVisibleItemsAfter(items, currentIndex) {
-
- -
-
- +
+
切换对话。" + cfg = self.context.get_config(umo=message.unified_msg_origin) + agent_runner_type = cfg["provider"]["agent_runner_type"] + if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: + message.set_result( + MessageEventResult().message( + f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持。", + ), ) - ret = "".join(parts) - message.set_result(MessageEventResult().message(ret)) return size_per_page = 6 @@ -241,15 +226,15 @@ class ConversationCommands: async def new_conv(self, message: AstrMessageEvent): """创建新对话""" - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type in ["dify", "coze"]: - assert isinstance(provider, (ProviderDify, ProviderCoze)), ( - "provider type is not dify or coze" - ) - await provider.forget(message.unified_msg_origin) - message.set_result( - MessageEventResult().message("成功,下次聊天将是新对话。"), + cfg = self.context.get_config(umo=message.unified_msg_origin) + agent_runner_type = cfg["provider"]["agent_runner_type"] + if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: + await sp.remove_async( + scope="umo", + scope_id=message.unified_msg_origin, + key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type], ) + message.set_result(MessageEventResult().message("已创建新对话。")) return cpersona = await self._get_current_persona_id(message.unified_msg_origin) @@ -272,19 +257,9 @@ class ConversationCommands: async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""): """创建新群聊对话""" - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type in ["dify", "coze"]: - assert isinstance(provider, (ProviderDify, ProviderCoze)), ( - "provider type is not dify or coze" - ) - await provider.forget(message.unified_msg_origin) - message.set_result( - MessageEventResult().message("成功,下次聊天将是新对话。"), - ) - return if sid: session = str( - MessageSesion( + MessageSession( platform_name=message.platform_meta.id, message_type=MessageType("GroupMessage"), session_id=sid, @@ -319,31 +294,6 @@ class ConversationCommands: ) return - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify), "provider type is not dify" - data = await provider.api_client.get_chat_convs(message.unified_msg_origin) - if not data["data"]: - message.set_result(MessageEventResult().message("未找到任何对话。")) - return - selected_conv = None - if index is not None: - try: - selected_conv = data["data"][index - 1] - except IndexError: - message.set_result( - MessageEventResult().message("对话序号错误,请使用 /ls 查看"), - ) - return - else: - selected_conv = data["data"][0] - ret = ( - f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。" - ) - provider.conversation_ids[message.unified_msg_origin] = selected_conv["id"] - message.set_result(MessageEventResult().message(ret)) - return - if index is None: message.set_result( MessageEventResult().message( @@ -376,19 +326,6 @@ class ConversationCommands: if not new_name: message.set_result(MessageEventResult().message("请输入新的对话名称。")) return - - provider = self.context.get_using_provider(message.unified_msg_origin) - - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify) - cid = provider.conversation_ids.get(message.unified_msg_origin, None) - if not cid: - message.set_result(MessageEventResult().message("未找到当前对话。")) - return - await provider.api_client.rename(cid, new_name, message.unified_msg_origin) - message.set_result(MessageEventResult().message("重命名对话成功。")) - return - await self.context.conversation_manager.update_conversation_title( message.unified_msg_origin, new_name, @@ -408,20 +345,14 @@ class ConversationCommands: ) return - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify) - dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None) - if dify_cid: - await provider.api_client.delete_chat_conv( - message.unified_msg_origin, - dify_cid, - ) - message.set_result( - MessageEventResult().message( - "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。", - ), + agent_runner_type = cfg["provider"]["agent_runner_type"] + if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: + await sp.remove_async( + scope="umo", + scope_id=message.unified_msg_origin, + key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type], ) + message.set_result(MessageEventResult().message("重置对话成功。")) return session_curr_cid = ( diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 969343fe0..c1466ef3c 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -5,7 +5,6 @@ from astrbot.api.event import AstrMessageEvent, filter from astrbot.api.message_components import Image, Plain from astrbot.api.provider import LLMResponse, ProviderRequest from astrbot.core import logger -from astrbot.core.provider.sources.dify_source import ProviderDify from .commands import ( AdminCommands, @@ -279,33 +278,20 @@ class Main(star.Star): return try: conv = None - if provider.meta().type != "dify": - session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id( - event.unified_msg_origin, - ) + session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id( + event.unified_msg_origin, + ) - if not session_curr_cid: - logger.error( - "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。", - ) - return + if not session_curr_cid: + logger.error( + "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。", + ) + return - conv = await self.context.conversation_manager.get_conversation( - event.unified_msg_origin, - session_curr_cid, - ) - else: - # Dify 自己有维护对话,不需要 bot 端维护。 - assert isinstance(provider, ProviderDify) - cid = provider.conversation_ids.get( - event.unified_msg_origin, - None, - ) - if cid is None: - logger.error( - "[Dify] 当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。", - ) - return + conv = await self.context.conversation_manager.get_conversation( + event.unified_msg_origin, + session_curr_cid, + ) prompt = event.message_str From 82c9cf4db680bf210f51d6de1984ecaf5f919e1c Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Nov 2025 20:18:46 +0800 Subject: [PATCH 06/29] chore: remove legacy coze and dashscope provider --- astrbot/core/provider/sources/coze_source.py | 650 ------------------ .../core/provider/sources/dashscope_source.py | 207 ------ 2 files changed, 857 deletions(-) delete mode 100644 astrbot/core/provider/sources/coze_source.py delete mode 100644 astrbot/core/provider/sources/dashscope_source.py diff --git a/astrbot/core/provider/sources/coze_source.py b/astrbot/core/provider/sources/coze_source.py deleted file mode 100644 index 6f1355bf7..000000000 --- a/astrbot/core/provider/sources/coze_source.py +++ /dev/null @@ -1,650 +0,0 @@ -import base64 -import hashlib -import json -import os -from collections.abc import AsyncGenerator - -import astrbot.core.message.components as Comp -from astrbot import logger -from astrbot.api.provider import Provider -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.provider.entities import LLMResponse - -from ..register import register_provider_adapter -from .coze_api_client import CozeAPIClient - - -@register_provider_adapter("coze", "Coze (扣子) 智能体适配器") -class ProviderCoze(Provider): - def __init__( - self, - provider_config, - provider_settings, - ) -> None: - super().__init__( - provider_config, - provider_settings, - ) - self.api_key = provider_config.get("coze_api_key", "") - if not self.api_key: - raise Exception("Coze API Key 不能为空。") - self.bot_id = provider_config.get("bot_id", "") - if not self.bot_id: - raise Exception("Coze Bot ID 不能为空。") - self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn") - - if not isinstance(self.api_base, str) or not self.api_base.startswith( - ("http://", "https://"), - ): - raise Exception( - "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。", - ) - - self.timeout = provider_config.get("timeout", 120) - if isinstance(self.timeout, str): - self.timeout = int(self.timeout) - self.auto_save_history = provider_config.get("auto_save_history", True) - self.conversation_ids: dict[str, str] = {} - self.file_id_cache: dict[str, dict[str, str]] = {} - - # 创建 API 客户端 - self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base) - - def _generate_cache_key(self, data: str, is_base64: bool = False) -> str: - """生成统一的缓存键 - - Args: - data: 图片数据或路径 - is_base64: 是否是 base64 数据 - - Returns: - str: 缓存键 - - """ - try: - if is_base64 and data.startswith("data:image/"): - try: - header, encoded = data.split(",", 1) - image_bytes = base64.b64decode(encoded) - cache_key = hashlib.md5(image_bytes).hexdigest() - return cache_key - except Exception: - cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest() - return cache_key - elif data.startswith(("http://", "https://")): - # URL图片,使用URL作为缓存键 - cache_key = hashlib.md5(data.encode("utf-8")).hexdigest() - return cache_key - else: - clean_path = ( - data.split("_")[0] - if "_" in data and len(data.split("_")) >= 3 - else data - ) - - if os.path.exists(clean_path): - with open(clean_path, "rb") as f: - file_content = f.read() - cache_key = hashlib.md5(file_content).hexdigest() - return cache_key - cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest() - return cache_key - - except Exception as e: - cache_key = hashlib.md5(data.encode("utf-8")).hexdigest() - logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}") - return cache_key - - async def _upload_file( - self, - file_data: bytes, - session_id: str | None = None, - cache_key: str | None = None, - ) -> str: - """上传文件到 Coze 并返回 file_id""" - # 使用 API 客户端上传文件 - file_id = await self.api_client.upload_file(file_data) - - # 缓存 file_id - if session_id and cache_key: - if session_id not in self.file_id_cache: - self.file_id_cache[session_id] = {} - self.file_id_cache[session_id][cache_key] = file_id - logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}") - - return file_id - - async def _download_and_upload_image( - self, - image_url: str, - session_id: str | None = None, - ) -> str: - """下载图片并上传到 Coze,返回 file_id""" - # 计算哈希实现缓存 - cache_key = self._generate_cache_key(image_url) if session_id else None - - if session_id and cache_key: - if session_id not in self.file_id_cache: - self.file_id_cache[session_id] = {} - - if cache_key in self.file_id_cache[session_id]: - file_id = self.file_id_cache[session_id][cache_key] - return file_id - - try: - image_data = await self.api_client.download_image(image_url) - - file_id = await self._upload_file(image_data, session_id, cache_key) - - if session_id and cache_key: - self.file_id_cache[session_id][cache_key] = file_id - - return file_id - - except Exception as e: - logger.error(f"处理图片失败 {image_url}: {e!s}") - raise Exception(f"处理图片失败: {e!s}") - - async def _process_context_images( - self, - content: str | list, - session_id: str, - ) -> str: - """处理上下文中的图片内容,将 base64 图片上传并替换为 file_id""" - try: - if isinstance(content, str): - return content - - processed_content = [] - if session_id not in self.file_id_cache: - self.file_id_cache[session_id] = {} - - for item in content: - if not isinstance(item, dict): - processed_content.append(item) - continue - if item.get("type") == "text": - processed_content.append(item) - elif item.get("type") == "image_url": - # 处理图片逻辑 - if "file_id" in item: - # 已经有 file_id - logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}") - processed_content.append(item) - else: - # 获取图片数据 - image_data = "" - if "image_url" in item and isinstance(item["image_url"], dict): - image_data = item["image_url"].get("url", "") - elif "data" in item: - image_data = item.get("data", "") - elif "url" in item: - image_data = item.get("url", "") - - if not image_data: - continue - # 计算哈希用于缓存 - cache_key = self._generate_cache_key( - image_data, - is_base64=image_data.startswith("data:image/"), - ) - - # 检查缓存 - if cache_key in self.file_id_cache[session_id]: - file_id = self.file_id_cache[session_id][cache_key] - processed_content.append( - {"type": "image", "file_id": file_id}, - ) - else: - # 上传图片并缓存 - if image_data.startswith("data:image/"): - # base64 处理 - _, encoded = image_data.split(",", 1) - image_bytes = base64.b64decode(encoded) - file_id = await self._upload_file( - image_bytes, - session_id, - cache_key, - ) - elif image_data.startswith(("http://", "https://")): - # URL 图片 - file_id = await self._download_and_upload_image( - image_data, - session_id, - ) - # 为URL图片也添加缓存 - self.file_id_cache[session_id][cache_key] = file_id - elif os.path.exists(image_data): - # 本地文件 - with open(image_data, "rb") as f: - image_bytes = f.read() - file_id = await self._upload_file( - image_bytes, - session_id, - cache_key, - ) - else: - logger.warning( - f"无法处理的图片格式: {image_data[:50]}...", - ) - continue - - processed_content.append( - {"type": "image", "file_id": file_id}, - ) - - result = json.dumps(processed_content, ensure_ascii=False) - return result - except Exception as e: - logger.error(f"处理上下文图片失败: {e!s}") - if isinstance(content, str): - return content - return json.dumps(content, ensure_ascii=False) - - async def text_chat( - self, - prompt: str, - session_id=None, - image_urls=None, - func_tool=None, - contexts=None, - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ) -> LLMResponse: - """文本对话, 内部使用流式接口实现非流式 - - Args: - prompt (str): 用户提示词 - session_id (str): 会话ID - image_urls (List[str]): 图片URL列表 - func_tool (FuncCall): 函数调用工具(不支持) - contexts (List): 上下文列表 - system_prompt (str): 系统提示语 - tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持) - model (str): 模型名称(不支持) - - Returns: - LLMResponse: LLM响应对象 - - """ - accumulated_content = "" - final_response = None - - async for llm_response in self.text_chat_stream( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - model=model, - **kwargs, - ): - if llm_response.is_chunk: - if llm_response.completion_text: - accumulated_content += llm_response.completion_text - else: - final_response = llm_response - - if final_response: - return final_response - - if accumulated_content: - chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) - return LLMResponse(role="assistant", result_chain=chain) - return LLMResponse(role="assistant", completion_text="") - - async def text_chat_stream( - self, - prompt: str, - session_id=None, - image_urls=None, - func_tool=None, - contexts=None, - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ) -> AsyncGenerator[LLMResponse, None]: - """流式对话接口""" - # 用户ID参数(参考文档, 可以自定义) - user_id = session_id or kwargs.get("user", "default_user") - - # 获取或创建会话ID - conversation_id = self.conversation_ids.get(user_id) - - # 构建消息 - additional_messages = [] - - if system_prompt: - if not self.auto_save_history or not conversation_id: - additional_messages.append( - { - "role": "system", - "content": system_prompt, - "content_type": "text", - }, - ) - - contexts = self._ensure_message_to_dicts(contexts) - if not self.auto_save_history and contexts: - # 如果关闭了自动保存历史,传入上下文 - for ctx in contexts: - if isinstance(ctx, dict) and "role" in ctx and "content" in ctx: - content = ctx["content"] - content_type = ctx.get("content_type", "text") - - # 处理可能包含图片的上下文 - if ( - content_type == "object_string" - or (isinstance(content, str) and content.startswith("[")) - or ( - isinstance(content, list) - and any( - isinstance(item, dict) - and item.get("type") == "image_url" - for item in content - ) - ) - ): - processed_content = await self._process_context_images( - content, - user_id, - ) - additional_messages.append( - { - "role": ctx["role"], - "content": processed_content, - "content_type": "object_string", - }, - ) - else: - # 纯文本 - additional_messages.append( - { - "role": ctx["role"], - "content": ( - content - if isinstance(content, str) - else json.dumps(content, ensure_ascii=False) - ), - "content_type": "text", - }, - ) - else: - logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}") - - if prompt or image_urls: - if image_urls: - # 多模态 - object_string_content = [] - if prompt: - object_string_content.append({"type": "text", "text": prompt}) - - for url in image_urls: - try: - if url.startswith(("http://", "https://")): - # 网络图片 - file_id = await self._download_and_upload_image( - url, - user_id, - ) - else: - # 本地文件或 base64 - if url.startswith("data:image/"): - # base64 - _, encoded = url.split(",", 1) - image_data = base64.b64decode(encoded) - cache_key = self._generate_cache_key( - url, - is_base64=True, - ) - file_id = await self._upload_file( - image_data, - user_id, - cache_key, - ) - # 本地文件 - elif os.path.exists(url): - with open(url, "rb") as f: - image_data = f.read() - # 用文件路径和修改时间来缓存 - file_stat = os.stat(url) - cache_key = self._generate_cache_key( - f"{url}_{file_stat.st_mtime}_{file_stat.st_size}", - is_base64=False, - ) - file_id = await self._upload_file( - image_data, - user_id, - cache_key, - ) - else: - logger.warning(f"图片文件不存在: {url}") - continue - - object_string_content.append( - { - "type": "image", - "file_id": file_id, - }, - ) - except Exception as e: - logger.error(f"处理图片失败 {url}: {e!s}") - continue - - if object_string_content: - content = json.dumps(object_string_content, ensure_ascii=False) - additional_messages.append( - { - "role": "user", - "content": content, - "content_type": "object_string", - }, - ) - # 纯文本 - elif prompt: - additional_messages.append( - { - "role": "user", - "content": prompt, - "content_type": "text", - }, - ) - - try: - accumulated_content = "" - message_started = False - - async for chunk in self.api_client.chat_messages( - bot_id=self.bot_id, - user_id=user_id, - additional_messages=additional_messages, - conversation_id=conversation_id, - auto_save_history=self.auto_save_history, - stream=True, - timeout=self.timeout, - ): - event_type = chunk.get("event") - data = chunk.get("data", {}) - - if event_type == "conversation.chat.created": - if isinstance(data, dict) and "conversation_id" in data: - self.conversation_ids[user_id] = data["conversation_id"] - - elif event_type == "conversation.message.delta": - if isinstance(data, dict): - content = data.get("content", "") - if not content and "delta" in data: - content = data["delta"].get("content", "") - if not content and "text" in data: - content = data.get("text", "") - - if content: - message_started = True - accumulated_content += content - yield LLMResponse( - role="assistant", - completion_text=content, - is_chunk=True, - ) - - elif event_type == "conversation.message.completed": - if isinstance(data, dict): - msg_type = data.get("type") - if msg_type == "answer" and data.get("role") == "assistant": - final_content = data.get("content", "") - if not accumulated_content and final_content: - chain = MessageChain(chain=[Comp.Plain(final_content)]) - yield LLMResponse( - role="assistant", - result_chain=chain, - is_chunk=False, - ) - - elif event_type == "conversation.chat.completed": - if accumulated_content: - chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) - yield LLMResponse( - role="assistant", - result_chain=chain, - is_chunk=False, - ) - break - - elif event_type == "done": - break - - elif event_type == "error": - error_msg = ( - data.get("message", "未知错误") - if isinstance(data, dict) - else str(data) - ) - logger.error(f"Coze 流式响应错误: {error_msg}") - yield LLMResponse( - role="err", - completion_text=f"Coze 错误: {error_msg}", - is_chunk=False, - ) - break - - if not message_started and not accumulated_content: - yield LLMResponse( - role="assistant", - completion_text="LLM 未响应任何内容。", - is_chunk=False, - ) - elif message_started and accumulated_content: - chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) - yield LLMResponse( - role="assistant", - result_chain=chain, - is_chunk=False, - ) - - except Exception as e: - logger.error(f"Coze 流式请求失败: {e!s}") - yield LLMResponse( - role="err", - completion_text=f"Coze 流式请求失败: {e!s}", - is_chunk=False, - ) - - async def forget(self, session_id: str): - """清空指定会话的上下文""" - user_id = session_id - conversation_id = self.conversation_ids.get(user_id) - - if user_id in self.file_id_cache: - self.file_id_cache.pop(user_id, None) - - if not conversation_id: - return True - - try: - response = await self.api_client.clear_context(conversation_id) - - if "code" in response and response["code"] == 0: - self.conversation_ids.pop(user_id, None) - return True - logger.warning(f"清空 Coze 会话上下文失败: {response}") - return False - - except Exception as e: - logger.error(f"清空 Coze 会话失败: {e!s}") - return False - - async def get_current_key(self): - """获取当前API Key""" - return self.api_key - - async def set_key(self, key: str): - """设置新的API Key""" - raise NotImplementedError("Coze 适配器不支持设置 API Key。") - - async def get_models(self): - """获取可用模型列表""" - return [f"bot_{self.bot_id}"] - - def get_model(self): - """获取当前模型""" - return f"bot_{self.bot_id}" - - def set_model(self, model: str): - """设置模型(在Coze中是Bot ID)""" - if model.startswith("bot_"): - self.bot_id = model[4:] - else: - self.bot_id = model - - async def get_human_readable_context( - self, - session_id: str, - page: int = 1, - page_size: int = 10, - ): - """获取人类可读的上下文历史""" - user_id = session_id - conversation_id = self.conversation_ids.get(user_id) - - if not conversation_id: - return [] - - try: - data = await self.api_client.get_message_list( - conversation_id=conversation_id, - order="desc", - limit=page_size, - offset=(page - 1) * page_size, - ) - - if data.get("code") != 0: - logger.warning(f"获取 Coze 消息历史失败: {data}") - return [] - - messages = data.get("data", {}).get("messages", []) - - readable_history = [] - for msg in messages: - role = msg.get("role", "unknown") - content = msg.get("content", "") - msg_type = msg.get("type", "") - - if role == "user": - readable_history.append(f"用户: {content}") - elif role == "assistant" and msg_type == "answer": - readable_history.append(f"助手: {content}") - - return readable_history - - except Exception as e: - logger.error(f"获取 Coze 消息历史失败: {e!s}") - return [] - - async def terminate(self): - """清理资源""" - await self.api_client.close() diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py deleted file mode 100644 index 7c690e048..000000000 --- a/astrbot/core/provider/sources/dashscope_source.py +++ /dev/null @@ -1,207 +0,0 @@ -import asyncio -import functools -import re - -from dashscope import Application -from dashscope.app.application_response import ApplicationResponse - -from astrbot.core import logger, sp -from astrbot.core.message.message_event_result import MessageChain - -from .. import Provider -from ..entities import LLMResponse -from ..register import register_provider_adapter -from .openai_source import ProviderOpenAIOfficial - - -@register_provider_adapter("dashscope", "Dashscope APP 适配器。") -class ProviderDashscope(ProviderOpenAIOfficial): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - ) -> None: - Provider.__init__( - self, - provider_config, - provider_settings, - ) - self.api_key = provider_config.get("dashscope_api_key", "") - if not self.api_key: - raise Exception("阿里云百炼 API Key 不能为空。") - self.app_id = provider_config.get("dashscope_app_id", "") - if not self.app_id: - raise Exception("阿里云百炼 APP ID 不能为空。") - self.dashscope_app_type = provider_config.get("dashscope_app_type", "") - if not self.dashscope_app_type: - raise Exception("阿里云百炼 APP 类型不能为空。") - self.model_name = "dashscope" - self.variables: dict = provider_config.get("variables", {}) - self.rag_options: dict = provider_config.get("rag_options", {}) - self.output_reference = self.rag_options.get("output_reference", False) - self.rag_options = self.rag_options.copy() - self.rag_options.pop("output_reference", None) - - self.timeout = provider_config.get("timeout", 120) - if isinstance(self.timeout, str): - self.timeout = int(self.timeout) - - def has_rag_options(self): - """判断是否有 RAG 选项 - - Returns: - bool: 是否有 RAG 选项 - - """ - if self.rag_options and ( - len(self.rag_options.get("pipeline_ids", [])) > 0 - or len(self.rag_options.get("file_ids", [])) > 0 - ): - return True - return False - - async def text_chat( - self, - prompt: str, - session_id=None, - image_urls=None, - func_tool=None, - contexts=None, - system_prompt=None, - model=None, - **kwargs, - ) -> LLMResponse: - if image_urls is None: - image_urls = [] - if contexts is None: - contexts = [] - # 获得会话变量 - payload_vars = self.variables.copy() - # 动态变量 - session_var = await sp.session_get(session_id, "session_variables", default={}) - payload_vars.update(session_var) - - if ( - self.dashscope_app_type in ["agent", "dialog-workflow"] - and not self.has_rag_options() - ): - # 支持多轮对话的 - new_record = {"role": "user", "content": prompt} - if image_urls: - logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。") - contexts_no_img = await self._remove_image_from_context(contexts) - context_query = [*contexts_no_img, new_record] - if system_prompt: - context_query.insert(0, {"role": "system", "content": system_prompt}) - for part in context_query: - if "_no_save" in part: - del part["_no_save"] - # 调用阿里云百炼 API - payload = { - "app_id": self.app_id, - "api_key": self.api_key, - "messages": context_query, - "biz_params": payload_vars or None, - } - partial = functools.partial( - Application.call, - **payload, - ) - response = await asyncio.get_event_loop().run_in_executor(None, partial) - else: - # 不支持多轮对话的 - # 调用阿里云百炼 API - payload = { - "app_id": self.app_id, - "prompt": prompt, - "api_key": self.api_key, - "biz_params": payload_vars or None, - } - if self.rag_options: - payload["rag_options"] = self.rag_options - partial = functools.partial( - Application.call, - **payload, - ) - response = await asyncio.get_event_loop().run_in_executor(None, partial) - - assert isinstance(response, ApplicationResponse) - - logger.debug(f"dashscope resp: {response}") - - if response.status_code != 200: - logger.error( - f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", - ) - return LLMResponse( - role="err", - result_chain=MessageChain().message( - f"阿里云百炼请求失败: message={response.message} code={response.status_code}", - ), - ) - - output_text = response.output.get("text", "") or "" - # RAG 引用脚标格式化 - output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text) - if self.output_reference and response.output.get("doc_references", None): - ref_parts = [] - for ref in response.output.get("doc_references", []) or []: - ref_title = ( - ref.get("title", "") - if ref.get("title") - else ref.get("doc_name", "") - ) - ref_parts.append(f"{ref['index_id']}. {ref_title}\n") - ref_str = "".join(ref_parts) - output_text += f"\n\n回答来源:\n{ref_str}" - - llm_response = LLMResponse("assistant") - llm_response.result_chain = MessageChain().message(output_text) - - return llm_response - - async def text_chat_stream( - self, - prompt, - session_id=None, - image_urls=..., - func_tool=None, - contexts=..., - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response - - async def forget(self, session_id): - return True - - async def get_current_key(self): - return self.api_key - - async def set_key(self, key): - raise Exception("阿里云百炼 适配器不支持设置 API Key。") - - async def get_models(self): - return [self.get_model()] - - async def get_human_readable_context(self, session_id, page, page_size): - raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。") - - async def terminate(self): - pass From d560671d1fb366555fc39e7236bda2e911518f52 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Nov 2025 20:54:19 +0800 Subject: [PATCH 07/29] feat: agent runner config migration --- astrbot/core/core_lifecycle.py | 21 +++------- astrbot/core/utils/migra_helper.py | 61 ++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 16 deletions(-) create mode 100644 astrbot/core/utils/migra_helper.py diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 17fd52138..cd3f2f27b 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -16,13 +16,12 @@ import time import traceback from asyncio import Queue -from astrbot.core import LogBroker, logger, sp +from astrbot.api import logger, sp +from astrbot.core import LogBroker from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.db import BaseDatabase -from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 -from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.persona_mgr import PersonaManager from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler @@ -34,6 +33,7 @@ from astrbot.core.star.context import Context from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map from astrbot.core.umop_config_router import UmopConfigRouter from astrbot.core.updator import AstrBotUpdator +from astrbot.core.utils.migra_helper import migra from . import astrbot_config, html_renderer from .event_bus import EventBus @@ -97,19 +97,8 @@ class AstrBotCoreLifecycle: sp=sp, ) - # 4.5 to 4.6 migration for umop_config_router - try: - await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router) - except Exception as e: - logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}") - logger.error(traceback.format_exc()) - - # migration for webchat session - try: - await migrate_webchat_session(self.db) - except Exception as e: - logger.error(f"Migration for webchat session failed: {e!s}") - logger.error(traceback.format_exc()) + # apply migration + await migra(self.db, self.astrbot_config_mgr, self.umop_config_router) # 初始化事件队列 self.event_queue = Queue() diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py new file mode 100644 index 000000000..27358201a --- /dev/null +++ b/astrbot/core/utils/migra_helper.py @@ -0,0 +1,61 @@ +import traceback + +from astrbot.core import astrbot_config, logger +from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 +from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session + + +async def migra(db, astrbot_config_mgr, umop_config_router) -> None: + """ + Stores the migration logic here. + btw, i really don't like migration :( + """ + # 4.5 to 4.6 migration for umop_config_router + try: + await migrate_45_to_46(astrbot_config_mgr, umop_config_router) + except Exception as e: + logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}") + logger.error(traceback.format_exc()) + + # migration for webchat session + try: + await migrate_webchat_session(db) + except Exception as e: + logger.error(f"Migration for webchat session failed: {e!s}") + logger.error(traceback.format_exc()) + + # migra third party agent runner configs + try: + _c = False + providers = astrbot_config["provider"] + ids_map = {} + for prov in providers: + type_ = prov.get("type") + if type_ in ["dify", "coze", "dashscope"]: + prov["provider_type"] = "agent_runner" + ids_map[prov["id"]] = { + "type": type_, + "id": prov["id"], + } + _c = True + default_prov_id = astrbot_config["provider_settings"]["default_provider_id"] + if default_prov_id in ids_map: + astrbot_config["provider_settings"]["default_provider_id"] = "" + p = ids_map[default_prov_id] + if p["type"] == "dify": + astrbot_config["provider_settings"]["dify_agent_runner_provider_id"] = ( + p["id"] + ) + elif p["type"] == "coze": + astrbot_config["provider_settings"]["coze_agent_runner_provider_id"] = ( + p["id"] + ) + elif p["type"] == "dashscope": + astrbot_config["provider_settings"][ + "dashscope_agent_runner_provider_id" + ] = p["id"] + if _c: + astrbot_config.save_config() + except Exception as e: + logger.error(f"Migration for third party agent runner configs failed: {e!s}") + logger.error(traceback.format_exc()) From 9a5f507cbefda5af03eb99f67dff6f21c1bb722d Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Nov 2025 20:58:18 +0800 Subject: [PATCH 08/29] feat: enable agent runner providers in configuration --- astrbot/core/config/default.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 4e1101010..39c0aa3a7 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2217,6 +2217,7 @@ CONFIG_METADATA_3 = { "_special": "select_agent_runner_provider", "condition": { "provider_settings.agent_runner_type": "coze", + "provider_settings.enable": True, }, }, "provider_settings.dify_agent_runner_provider_id": { @@ -2225,6 +2226,7 @@ CONFIG_METADATA_3 = { "_special": "select_agent_runner_provider", "condition": { "provider_settings.agent_runner_type": "dify", + "provider_settings.enable": True, }, }, "provider_settings.dashscope_agent_runner_provider_id": { @@ -2233,6 +2235,7 @@ CONFIG_METADATA_3 = { "_special": "select_agent_runner_provider", "condition": { "provider_settings.agent_runner_type": "dashscope", + "provider_settings.enable": True, }, }, }, From 7ba98c1e916989f474f0a059f9c3660718a75765 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Nov 2025 21:06:16 +0800 Subject: [PATCH 09/29] feat: enhance provider display with grouped categorization and improved filtering --- dashboard/src/views/ProviderPage.vue | 142 +++++++++++++++++++++------ 1 file changed, 111 insertions(+), 31 deletions(-) diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index 2b7365851..321cb2d77 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -52,30 +52,62 @@ - - - mdi-api-off -

{{ getEmptyText() }}

-
-
- - - - - - - - - + +
@@ -368,6 +400,52 @@ export default { }; }, + groupedProviders() { + if (!this.config_data.provider) { + return []; + } + + const typeOrder = [ + 'chat_completion', + 'agent_runner', + 'speech_to_text', + 'text_to_speech', + 'embedding', + 'rerank', + ]; + + const assigned = new Set(); + const groups = typeOrder + .map((typeKey) => { + const items = this.config_data.provider.filter((provider) => { + const resolved = this.getProviderType(provider); + if (resolved === typeKey) { + assigned.add(provider.id); + return true; + } + return false; + }); + return { + typeKey, + label: this.messages.tabTypes[typeKey] || typeKey, + items, + }; + }) + .filter((group) => group.items.length > 0); + + const remaining = this.config_data.provider.filter( + (provider) => !assigned.has(provider.id), + ); + if (remaining.length > 0) { + groups.push({ + typeKey: 'others', + label: this.tm('providers.tabs.all'), + items: remaining, + }); + } + return groups; + }, + // 根据选择的标签过滤提供商列表 filteredProviders() { if (!this.config_data.provider || this.activeProviderTypeTab === 'all') { @@ -376,13 +454,7 @@ export default { return this.config_data.provider.filter(provider => { // 如果provider.provider_type已经存在,直接使用它 - if (provider.provider_type) { - return provider.provider_type === this.activeProviderTypeTab; - } - - // 否则使用映射关系 - const mappedType = this.oldVersionProviderTypeMapping[provider.type]; - return mappedType === this.activeProviderTypeTab; + return this.getProviderType(provider) === this.activeProviderTypeTab; }); } }, @@ -392,6 +464,14 @@ export default { }, methods: { + getProviderType(provider) { + if (!provider) return undefined; + if (provider.provider_type) { + return provider.provider_type; + } + return this.oldVersionProviderTypeMapping[provider.type]; + }, + getConfig() { axios.get('/api/config/get').then((res) => { this.config_data = res.data.data.config; From 1338cab61b2393aa82c67fb5256dd37b139e20a8 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 23 Nov 2025 21:53:56 +0800 Subject: [PATCH 10/29] feat: add configuration selector for session management and enhance session handling in chat components --- astrbot/core/umop_config_router.py | 19 ++ astrbot/dashboard/routes/chat.py | 14 +- dashboard/src/components/chat/Chat.vue | 2 + dashboard/src/components/chat/ChatInput.vue | 32 +- .../src/components/chat/ConfigSelector.vue | 311 ++++++++++++++++++ .../components/chat/ProviderModelSelector.vue | 1 + dashboard/src/composables/useSessions.ts | 8 +- dashboard/src/views/ProviderPage.vue | 11 +- 8 files changed, 388 insertions(+), 10 deletions(-) create mode 100644 dashboard/src/components/chat/ConfigSelector.vue diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index 07858da5f..27f6232aa 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -85,3 +85,22 @@ class UmopConfigRouter: self.umop_to_conf_id[umo] = conf_id await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) + + async def delete_route(self, umo: str): + """删除一条路由 + + Args: + umo (str): 需要删除的 UMO 字符串 + + Raises: + ValueError: 当 umo 格式不正确时抛出 + """ + + if not isinstance(umo, str) or len(umo.split(":")) != 3: + raise ValueError( + "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", + ) + + if umo in self.umop_to_conf_id: + del self.umop_to_conf_id[umo] + await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 1ad789563..5381b5649 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -56,6 +56,7 @@ class ChatRoute(Route): self.conv_mgr = core_lifecycle.conversation_manager self.platform_history_mgr = core_lifecycle.platform_message_history_manager self.db = db + self.umop_config_router = core_lifecycle.umop_config_router self.running_convs: dict[str, bool] = {} @@ -266,7 +267,8 @@ class ChatRoute(Route): return Response().error("Permission denied").__dict__ # 删除该会话下的所有对话 - unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}" + message_type = "GroupMessage" if session.is_group else "FriendMessage" + unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}" await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) # 删除消息历史 @@ -276,6 +278,16 @@ class ChatRoute(Route): offset_sec=99999999, ) + # 删除与会话关联的配置路由 + try: + await self.umop_config_router.delete_route(unified_msg_origin) + except ValueError as exc: + logger.warning( + "Failed to delete UMO route %s during session cleanup: %s", + unified_msg_origin, + exc, + ) + # 清理队列(仅对 webchat) if session.platform_id == "webchat": webchat_queue_mgr.remove_queues(session_id) diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 09acd1b7e..ff0349409 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -87,6 +87,8 @@ :disabled="isStreaming || isConvRunning" :enableStreaming="enableStreaming" :isRecording="isRecording" + :session-id="currSessionId || null" + :current-session="getCurrentSession" @send="handleSendMessage" @toggleStreaming="toggleStreaming" @removeImage="removeImage" diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue index 7ca0ec94a..cad1b7205 100644 --- a/dashboard/src/components/chat/ChatInput.vue +++ b/dashboard/src/components/chat/ChatInput.vue @@ -11,7 +11,13 @@ style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 8px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);">
- + + + + diff --git a/dashboard/src/components/chat/ProviderModelSelector.vue b/dashboard/src/components/chat/ProviderModelSelector.vue index 0b983e416..645ff5a55 100644 --- a/dashboard/src/components/chat/ProviderModelSelector.vue +++ b/dashboard/src/components/chat/ProviderModelSelector.vue @@ -3,6 +3,7 @@ + mdi-creation {{ selectedProviderId }} / {{ selectedModelName }} diff --git a/dashboard/src/composables/useSessions.ts b/dashboard/src/composables/useSessions.ts index f14e3aa11..b63caa42c 100644 --- a/dashboard/src/composables/useSessions.ts +++ b/dashboard/src/composables/useSessions.ts @@ -4,8 +4,12 @@ import { useRouter } from 'vue-router'; export interface Session { session_id: string; - display_name: string; - updated_at: string; + display_name?: string | null; + updated_at?: string; + platform_id?: string; + creator?: string; + is_group?: number; + created_at?: string; } export function useSessions(chatboxMode: boolean = false) { diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index 321cb2d77..32f5a60a8 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -63,12 +63,12 @@

{{ group.label }}

- + + :bglogo="getProviderIcon(provider.provider)" @delete="deleteProvider" @edit="configExistingProvider" + @copy="copyProvider" :show-copy-button="true">