diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index f7e0913b4..21e796433 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -2,13 +2,12 @@ import abc import typing as T from enum import Enum, auto -from astrbot.core.provider import Provider +from astrbot import logger from astrbot.core.provider.entities import LLMResponse from ..hooks import BaseAgentRunHooks from ..response import AgentResponse from ..run_context import ContextWrapper, TContext -from ..tool_executor import BaseFunctionToolExecutor class AgentState(Enum): @@ -24,9 +23,7 @@ class BaseAgentRunner(T.Generic[TContext]): @abc.abstractmethod async def reset( self, - provider: Provider, run_context: ContextWrapper[TContext], - tool_executor: BaseFunctionToolExecutor[TContext], agent_hooks: BaseAgentRunHooks[TContext], **kwargs: T.Any, ) -> None: @@ -60,3 +57,9 @@ class BaseAgentRunner(T.Generic[TContext]): This method should be called after the agent is done. """ ... + + def _transition_state(self, new_state: AgentState) -> None: + """Transition the agent state.""" + if self._state != new_state: + logger.debug(f"Agent state transition: {self._state} -> {new_state}") + self._state = new_state diff --git a/astrbot/core/agent/runners/coze/coze_agent_runner.py b/astrbot/core/agent/runners/coze/coze_agent_runner.py new file mode 100644 index 000000000..a8300bb71 --- /dev/null +++ b/astrbot/core/agent/runners/coze/coze_agent_runner.py @@ -0,0 +1,367 @@ +import base64 +import json +import sys +import typing as T + +import astrbot.core.message.components as Comp +from astrbot import logger +from astrbot.core import sp +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import ( + LLMResponse, + ProviderRequest, +) + +from ...hooks import BaseAgentRunHooks +from ...response import AgentResponseData +from ...run_context import ContextWrapper, TContext +from ..base import AgentResponse, AgentState, BaseAgentRunner +from .coze_api_client import CozeAPIClient + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class CozeAgentRunner(BaseAgentRunner[TContext]): + """Coze Agent Runner""" + + @override + async def reset( + self, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.final_llm_resp = None + self._state = AgentState.IDLE + self.agent_hooks = agent_hooks + self.run_context = run_context + + self.api_key = provider_config.get("coze_api_key", "") + if not self.api_key: + raise Exception("Coze API Key 不能为空。") + self.bot_id = provider_config.get("bot_id", "") + if not self.bot_id: + raise Exception("Coze Bot ID 不能为空。") + self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn") + + if not isinstance(self.api_base, str) or not self.api_base.startswith( + ("http://", "https://"), + ): + raise Exception( + "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。", + ) + + self.timeout = provider_config.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + self.auto_save_history = provider_config.get("auto_save_history", True) + + # 创建 API 客户端 + self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base) + + # 会话相关缓存 + self.file_id_cache: dict[str, dict[str, str]] = {} + + @override + async def step(self): + """ + 执行 Coze Agent 的一个步骤 + """ + if not self.req: + raise ValueError("Request is not set. Please call reset() first.") + + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + + # 开始处理,转换到运行状态 + self._transition_state(AgentState.RUNNING) + + try: + # 执行 Coze 请求并处理结果 + async for response in self._execute_coze_request(): + yield response + except Exception as e: + logger.error(f"Coze 请求失败:{str(e)}") + self._transition_state(AgentState.ERROR) + self.final_llm_resp = LLMResponse( + role="err", completion_text=f"Coze 请求失败:{str(e)}" + ) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message(f"Coze 请求失败:{str(e)}") + ), + ) + finally: + await self.api_client.close() + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + while not self.done(): + async for resp in self.step(): + yield resp + + async def _execute_coze_request(self): + """执行 Coze 请求的核心逻辑""" + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + contexts = self.req.contexts or [] + system_prompt = self.req.system_prompt + + # 用户ID参数 + user_id = session_id + + # 获取或创建会话ID + conversation_id = await sp.get_async( + scope="umo", + scope_id=user_id, + key="coze_conversation_id", + default="", + ) + + # 构建消息 + additional_messages = [] + + if system_prompt: + if not self.auto_save_history or not conversation_id: + additional_messages.append( + { + "role": "system", + "content": system_prompt, + "content_type": "text", + }, + ) + + # 处理历史上下文 + if not self.auto_save_history and contexts: + for ctx in contexts: + if isinstance(ctx, dict) and "role" in ctx and "content" in ctx: + # 处理上下文中的图片 + content = ctx["content"] + if isinstance(content, list): + # 多模态内容,需要处理图片 + processed_content = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + processed_content.append(item) + elif item.get("type") == "image_url": + # 处理图片上传 + try: + image_data = item.get("image_url", {}) + url = image_data.get("url", "") + if url: + file_id = ( + await self._download_and_upload_image( + url, session_id + ) + ) + processed_content.append( + { + "type": "file", + "file_id": file_id, + "file_url": url, + } + ) + except Exception as e: + logger.warning(f"处理上下文图片失败: {e}") + continue + + if processed_content: + additional_messages.append( + { + "role": ctx["role"], + "content": processed_content, + "content_type": "object_string", + } + ) + else: + # 纯文本内容 + additional_messages.append( + { + "role": ctx["role"], + "content": content, + "content_type": "text", + } + ) + + # 构建当前消息 + if prompt or image_urls: + if image_urls: + # 多模态 + object_string_content = [] + if prompt: + object_string_content.append({"type": "text", "text": prompt}) + + for url in image_urls: + # the url is a base64 string + try: + image_data = base64.b64decode(url) + file_id = await self.api_client.upload_file(image_data) + object_string_content.append( + { + "type": "image", + "file_id": file_id, + } + ) + except Exception as e: + logger.warning(f"处理图片失败 {url}: {e}") + continue + + if object_string_content: + content = json.dumps(object_string_content, ensure_ascii=False) + additional_messages.append( + { + "role": "user", + "content": content, + "content_type": "object_string", + } + ) + elif prompt: + # 纯文本 + additional_messages.append( + { + "role": "user", + "content": prompt, + "content_type": "text", + }, + ) + + # 执行 Coze API 请求 + accumulated_content = "" + message_started = False + + async for chunk in self.api_client.chat_messages( + bot_id=self.bot_id, + user_id=user_id, + additional_messages=additional_messages, + conversation_id=conversation_id, + auto_save_history=self.auto_save_history, + stream=True, + timeout=self.timeout, + ): + event_type = chunk.get("event") + data = chunk.get("data", {}) + + if event_type == "conversation.chat.created": + if isinstance(data, dict) and "conversation_id" in data: + await sp.put_async( + scope="umo", + scope_id=user_id, + key="coze_conversation_id", + value=data["conversation_id"], + ) + + if event_type == "conversation.message.delta": + # 增量消息 + content = data.get("content", "") + if not content and "delta" in data: + content = data["delta"].get("content", "") + if not content and "text" in data: + content = data.get("text", "") + + if content: + accumulated_content += content + message_started = True + + # 如果是流式响应,发送增量数据 + if self.streaming: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message(content) + ), + ) + + elif event_type == "conversation.message.completed": + # 消息完成 + logger.debug("Coze message completed") + message_started = True + + elif event_type == "conversation.chat.completed": + # 对话完成 + logger.debug("Coze chat completed") + break + + elif event_type == "error": + # 错误处理 + error_msg = data.get("msg", "未知错误") + error_code = data.get("code", "UNKNOWN") + logger.error(f"Coze 出现错误: {error_code} - {error_msg}") + raise Exception(f"Coze 出现错误: {error_code} - {error_msg}") + + if not message_started and not accumulated_content: + logger.warning("Coze 未返回任何内容") + accumulated_content = "" + + # 创建最终响应 + chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) + self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) + self._transition_state(AgentState.DONE) + + try: + await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + + # 返回最终结果 + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=chain), + ) + + async def _download_and_upload_image( + self, + image_url: str, + session_id: str | None = None, + ) -> str: + """下载图片并上传到 Coze,返回 file_id""" + import hashlib + + # 计算哈希实现缓存 + cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest() + + if session_id: + if session_id not in self.file_id_cache: + self.file_id_cache[session_id] = {} + + if cache_key in self.file_id_cache[session_id]: + file_id = self.file_id_cache[session_id][cache_key] + logger.debug(f"[Coze] 使用缓存的 file_id: {file_id}") + return file_id + + try: + image_data = await self.api_client.download_image(image_url) + file_id = await self.api_client.upload_file(image_data) + + if session_id: + self.file_id_cache[session_id][cache_key] = file_id + logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}") + + return file_id + + except Exception as e: + logger.error(f"处理图片失败 {image_url}: {e!s}") + raise Exception(f"处理图片失败: {e!s}") + + @override + def done(self) -> bool: + """检查 Agent 是否已完成工作""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + @override + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp diff --git a/astrbot/core/provider/sources/coze_api_client.py b/astrbot/core/agent/runners/coze/coze_api_client.py similarity index 100% rename from astrbot/core/provider/sources/coze_api_client.py rename to astrbot/core/agent/runners/coze/coze_api_client.py diff --git a/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py new file mode 100644 index 000000000..7a095a60b --- /dev/null +++ b/astrbot/core/agent/runners/dashscope/dashscope_agent_runner.py @@ -0,0 +1,403 @@ +import asyncio +import functools +import queue +import re +import sys +import threading +import typing as T + +from dashscope import Application +from dashscope.app.application_response import ApplicationResponse + +import astrbot.core.message.components as Comp +from astrbot.core import logger, sp +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import ( + LLMResponse, + ProviderRequest, +) + +from ...hooks import BaseAgentRunHooks +from ...response import AgentResponseData +from ...run_context import ContextWrapper, TContext +from ..base import AgentResponse, AgentState, BaseAgentRunner + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class DashscopeAgentRunner(BaseAgentRunner[TContext]): + """Dashscope Agent Runner""" + + @override + async def reset( + self, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.final_llm_resp = None + self._state = AgentState.IDLE + self.agent_hooks = agent_hooks + self.run_context = run_context + + self.api_key = provider_config.get("dashscope_api_key", "") + if not self.api_key: + raise Exception("阿里云百炼 API Key 不能为空。") + self.app_id = provider_config.get("dashscope_app_id", "") + if not self.app_id: + raise Exception("阿里云百炼 APP ID 不能为空。") + self.dashscope_app_type = provider_config.get("dashscope_app_type", "") + if not self.dashscope_app_type: + raise Exception("阿里云百炼 APP 类型不能为空。") + + self.variables: dict = provider_config.get("variables", {}) or {} + self.rag_options: dict = provider_config.get("rag_options", {}) + self.output_reference = self.rag_options.get("output_reference", False) + self.rag_options = self.rag_options.copy() + self.rag_options.pop("output_reference", None) + + self.timeout = provider_config.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + + def has_rag_options(self): + """判断是否有 RAG 选项 + + Returns: + bool: 是否有 RAG 选项 + + """ + if self.rag_options and ( + len(self.rag_options.get("pipeline_ids", [])) > 0 + or len(self.rag_options.get("file_ids", [])) > 0 + ): + return True + return False + + @override + async def step(self): + """ + 执行 Dashscope Agent 的一个步骤 + """ + if not self.req: + raise ValueError("Request is not set. Please call reset() first.") + + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + + # 开始处理,转换到运行状态 + self._transition_state(AgentState.RUNNING) + + try: + # 执行 Dashscope 请求并处理结果 + async for response in self._execute_dashscope_request(): + yield response + except Exception as e: + logger.error(f"阿里云百炼请求失败:{str(e)}") + self._transition_state(AgentState.ERROR) + self.final_llm_resp = LLMResponse( + role="err", completion_text=f"阿里云百炼请求失败:{str(e)}" + ) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}") + ), + ) + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + while not self.done(): + async for resp in self.step(): + yield resp + + def _consume_sync_generator( + self, response: T.Any, response_queue: queue.Queue + ) -> None: + """在线程中消费同步generator,将结果放入队列 + + Args: + response: 同步generator对象 + response_queue: 用于传递数据的队列 + + """ + try: + if self.streaming: + for chunk in response: + response_queue.put(("data", chunk)) + else: + response_queue.put(("data", response)) + except Exception as e: + response_queue.put(("error", e)) + finally: + response_queue.put(("done", None)) + + async def _process_stream_chunk( + self, chunk: ApplicationResponse, output_text: str + ) -> tuple[str, list | None, AgentResponse | None]: + """处理流式响应的单个chunk + + Args: + chunk: Dashscope响应chunk + output_text: 当前累积的输出文本 + + Returns: + (更新后的output_text, doc_references, AgentResponse或None) + + """ + logger.debug(f"dashscope stream chunk: {chunk}") + + if chunk.status_code != 200: + logger.error( + f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", + ) + self._transition_state(AgentState.ERROR) + error_msg = ( + f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}" + ) + self.final_llm_resp = LLMResponse( + role="err", + result_chain=MessageChain().message(error_msg), + ) + return ( + output_text, + None, + AgentResponse( + type="err", + data=AgentResponseData(chain=MessageChain().message(error_msg)), + ), + ) + + chunk_text = chunk.output.get("text", "") or "" + # RAG 引用脚标格式化 + chunk_text = re.sub(r"\[(\d+)\]", r"[\1]", chunk_text) + + response = None + if chunk_text: + output_text += chunk_text + response = AgentResponse( + type="streaming_delta", + data=AgentResponseData(chain=MessageChain().message(chunk_text)), + ) + + # 获取文档引用 + doc_references = chunk.output.get("doc_references", None) + + return output_text, doc_references, response + + def _format_doc_references(self, doc_references: list) -> str: + """格式化文档引用为文本 + + Args: + doc_references: 文档引用列表 + + Returns: + 格式化后的引用文本 + + """ + ref_parts = [] + for ref in doc_references: + ref_title = ( + ref.get("title", "") if ref.get("title") else ref.get("doc_name", "") + ) + ref_parts.append(f"{ref['index_id']}. {ref_title}\n") + ref_str = "".join(ref_parts) + return f"\n\n回答来源:\n{ref_str}" + + async def _build_request_payload( + self, prompt: str, session_id: str, contexts: list, system_prompt: str + ) -> dict: + """构建请求payload + + Args: + prompt: 用户输入 + session_id: 会话ID + contexts: 上下文列表 + system_prompt: 系统提示词 + + Returns: + 请求payload字典 + + """ + conversation_id = await sp.get_async( + scope="umo", + scope_id=session_id, + key="dashscope_conversation_id", + default="", + ) + # 获得会话变量 + payload_vars = self.variables.copy() + session_var = await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_variables", + default={}, + ) + payload_vars.update(session_var) + + if ( + self.dashscope_app_type in ["agent", "dialog-workflow"] + and not self.has_rag_options() + ): + # 支持多轮对话的 + p = { + "app_id": self.app_id, + "api_key": self.api_key, + "prompt": prompt, + "biz_params": payload_vars or None, + "stream": self.streaming, + "incremental_output": True, + } + if conversation_id: + p["session_id"] = conversation_id + return p + else: + # 不支持多轮对话的 + payload = { + "app_id": self.app_id, + "prompt": prompt, + "api_key": self.api_key, + "biz_params": payload_vars or None, + "stream": self.streaming, + "incremental_output": True, + } + if self.rag_options: + payload["rag_options"] = self.rag_options + return payload + + async def _handle_streaming_response( + self, response: T.Any, session_id: str + ) -> T.AsyncGenerator[AgentResponse, None]: + """处理流式响应 + + Args: + response: Dashscope 流式响应 generator + + Yields: + AgentResponse 对象 + + """ + response_queue = queue.Queue() + consumer_thread = threading.Thread( + target=self._consume_sync_generator, + args=(response, response_queue), + daemon=True, + ) + consumer_thread.start() + + output_text = "" + doc_references = None + + while True: + try: + item_type, item_data = await asyncio.get_event_loop().run_in_executor( + None, response_queue.get, True, 1 + ) + except queue.Empty: + continue + + if item_type == "done": + break + elif item_type == "error": + raise item_data + elif item_type == "data": + chunk = item_data + assert isinstance(chunk, ApplicationResponse) + + ( + output_text, + chunk_doc_refs, + response, + ) = await self._process_stream_chunk(chunk, output_text) + + if response: + if response.type == "err": + yield response + return + yield response + + if chunk_doc_refs: + doc_references = chunk_doc_refs + + if chunk.output.session_id: + await sp.put_async( + scope="umo", + scope_id=session_id, + key="dashscope_conversation_id", + value=chunk.output.session_id, + ) + + # 添加 RAG 引用 + if self.output_reference and doc_references: + ref_text = self._format_doc_references(doc_references) + output_text += ref_text + + if self.streaming: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData(chain=MessageChain().message(ref_text)), + ) + + # 创建最终响应 + chain = MessageChain(chain=[Comp.Plain(output_text)]) + self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) + self._transition_state(AgentState.DONE) + + try: + await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + + # 返回最终结果 + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=chain), + ) + + async def _execute_dashscope_request(self): + """执行 Dashscope 请求的核心逻辑""" + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + contexts = self.req.contexts or [] + system_prompt = self.req.system_prompt + + # 检查图片输入 + if image_urls: + logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。") + + # 构建请求payload + payload = await self._build_request_payload( + prompt, session_id, contexts, system_prompt + ) + + if not self.streaming: + payload["incremental_output"] = False + + # 发起请求 + partial = functools.partial(Application.call, **payload) + response = await asyncio.get_event_loop().run_in_executor(None, partial) + + async for resp in self._handle_streaming_response(response, session_id): + yield resp + + @override + def done(self) -> bool: + """检查 Agent 是否已完成工作""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + @override + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp diff --git a/astrbot/core/agent/runners/dify/dify_agent_runner.py b/astrbot/core/agent/runners/dify/dify_agent_runner.py new file mode 100644 index 000000000..d9a8b7cd6 --- /dev/null +++ b/astrbot/core/agent/runners/dify/dify_agent_runner.py @@ -0,0 +1,336 @@ +import base64 +import os +import sys +import typing as T + +import astrbot.core.message.components as Comp +from astrbot.core import logger, sp +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import ( + LLMResponse, + ProviderRequest, +) +from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import download_file + +from ...hooks import BaseAgentRunHooks +from ...response import AgentResponseData +from ...run_context import ContextWrapper, TContext +from ..base import AgentResponse, AgentState, BaseAgentRunner +from .dify_api_client import DifyAPIClient + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class DifyAgentRunner(BaseAgentRunner[TContext]): + """Dify Agent Runner""" + + @override + async def reset( + self, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.final_llm_resp = None + self._state = AgentState.IDLE + self.agent_hooks = agent_hooks + self.run_context = run_context + + self.api_key = provider_config.get("dify_api_key", "") + self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") + self.api_type = provider_config.get("dify_api_type", "chat") + self.workflow_output_key = provider_config.get( + "dify_workflow_output_key", + "astrbot_wf_output", + ) + self.dify_query_input_key = provider_config.get( + "dify_query_input_key", + "astrbot_text_query", + ) + self.variables: dict = provider_config.get("variables", {}) or {} + self.timeout = provider_config.get("timeout", 60) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + + self.api_client = DifyAPIClient(self.api_key, self.api_base) + + @override + async def step(self): + """ + 执行 Dify Agent 的一个步骤 + """ + if not self.req: + raise ValueError("Request is not set. Please call reset() first.") + + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + + # 开始处理,转换到运行状态 + self._transition_state(AgentState.RUNNING) + + try: + # 执行 Dify 请求并处理结果 + async for response in self._execute_dify_request(): + yield response + except Exception as e: + logger.error(f"Dify 请求失败:{str(e)}") + self._transition_state(AgentState.ERROR) + self.final_llm_resp = LLMResponse( + role="err", completion_text=f"Dify 请求失败:{str(e)}" + ) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message(f"Dify 请求失败:{str(e)}") + ), + ) + finally: + await self.api_client.close() + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + while not self.done(): + async for resp in self.step(): + yield resp + + async def _execute_dify_request(self): + """执行 Dify 请求的核心逻辑""" + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + system_prompt = self.req.system_prompt + + conversation_id = await sp.get_async( + scope="umo", + scope_id=session_id, + key="dify_conversation_id", + default="", + ) + result = "" + + # 处理图片上传 + files_payload = [] + for image_url in image_urls: + # image_url is a base64 string + try: + image_data = base64.b64decode(image_url) + file_response = await self.api_client.file_upload( + file_data=image_data, + user=session_id, + mime_type="image/png", + file_name="image.png", + ) + logger.debug(f"Dify 上传图片响应:{file_response}") + if "id" not in file_response: + logger.warning( + f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。" + ) + continue + files_payload.append( + { + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response["id"], + } + ) + except Exception as e: + logger.warning(f"上传图片失败:{e}") + continue + + # 获得会话变量 + payload_vars = self.variables.copy() + # 动态变量 + session_var = await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_variables", + default={}, + ) + payload_vars.update(session_var) + payload_vars["system_prompt"] = system_prompt + + # 处理不同的 API 类型 + match self.api_type: + case "chat" | "agent" | "chatflow": + if not prompt: + prompt = "请描述这张图片。" + + async for chunk in self.api_client.chat_messages( + inputs={ + **payload_vars, + }, + query=prompt, + user=session_id, + conversation_id=conversation_id, + files=files_payload, + timeout=self.timeout, + ): + logger.debug(f"dify resp chunk: {chunk}") + if chunk["event"] == "message" or chunk["event"] == "agent_message": + result += chunk["answer"] + if not conversation_id: + await sp.put_async( + scope="umo", + scope_id=session_id, + key="dify_conversation_id", + value=chunk["conversation_id"], + ) + conversation_id = chunk["conversation_id"] + + # 如果是流式响应,发送增量数据 + if self.streaming and chunk["answer"]: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message(chunk["answer"]) + ), + ) + elif chunk["event"] == "message_end": + logger.debug("Dify message end") + break + elif chunk["event"] == "error": + logger.error(f"Dify 出现错误:{chunk}") + raise Exception( + f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}" + ) + + case "workflow": + async for chunk in self.api_client.workflow_run( + inputs={ + self.dify_query_input_key: prompt, + "astrbot_session_id": session_id, + **payload_vars, + }, + user=session_id, + files=files_payload, + timeout=self.timeout, + ): + logger.debug(f"dify workflow resp chunk: {chunk}") + match chunk["event"]: + case "workflow_started": + logger.info( + f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。" + ) + case "node_finished": + logger.debug( + f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。" + ) + case "text_chunk": + if self.streaming and chunk["data"]["text"]: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message( + chunk["data"]["text"] + ) + ), + ) + case "workflow_finished": + logger.info( + f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束" + ) + logger.debug(f"Dify 工作流结果:{chunk}") + if chunk["data"]["error"]: + logger.error( + f"Dify 工作流出现错误:{chunk['data']['error']}" + ) + raise Exception( + f"Dify 工作流出现错误:{chunk['data']['error']}" + ) + if self.workflow_output_key not in chunk["data"]["outputs"]: + raise Exception( + f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}" + ) + result = chunk + case _: + raise Exception(f"未知的 Dify API 类型:{self.api_type}") + + if not result: + logger.warning("Dify 请求结果为空,请查看 Debug 日志。") + + # 解析结果 + chain = await self.parse_dify_result(result) + + # 创建最终响应 + self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) + self._transition_state(AgentState.DONE) + + try: + await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + + # 返回最终结果 + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=chain), + ) + + async def parse_dify_result(self, chunk: dict | str) -> MessageChain: + """解析 Dify 的响应结果""" + if isinstance(chunk, str): + # Chat + return MessageChain(chain=[Comp.Plain(chunk)]) + + async def parse_file(item: dict): + match item["type"]: + case "image": + return Comp.Image(file=item["url"], url=item["url"]) + case "audio": + # 仅支持 wav + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, f"{item['filename']}.wav") + await download_file(item["url"], path) + return Comp.Image(file=item["url"], url=item["url"]) + case "video": + return Comp.Video(file=item["url"]) + case _: + return Comp.File(name=item["filename"], file=item["url"]) + + output = chunk["data"]["outputs"][self.workflow_output_key] + chains = [] + if isinstance(output, str): + # 纯文本输出 + chains.append(Comp.Plain(output)) + elif isinstance(output, list): + # 主要适配 Dify 的 HTTP 请求结点的多模态输出 + for item in output: + # handle Array[File] + if ( + not isinstance(item, dict) + or item.get("dify_model_identity", "") != "__dify__file__" + ): + chains.append(Comp.Plain(str(output))) + break + else: + chains.append(Comp.Plain(str(output))) + + # scan file + files = chunk["data"].get("files", []) + for item in files: + comp = await parse_file(item) + chains.append(comp) + + return MessageChain(chain=chains) + + @override + def done(self) -> bool: + """检查 Agent 是否已完成工作""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + @override + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/agent/runners/dify/dify_api_client.py similarity index 71% rename from astrbot/core/utils/dify_api_client.py rename to astrbot/core/agent/runners/dify/dify_api_client.py index ea8ff9dff..d9c6556cf 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/agent/runners/dify/dify_api_client.py @@ -3,7 +3,7 @@ import json from collections.abc import AsyncGenerator from typing import Any -from aiohttp import ClientResponse, ClientSession +from aiohttp import ClientResponse, ClientSession, FormData from astrbot.core import logger @@ -101,21 +101,59 @@ class DifyAPIClient: async def file_upload( self, - file_path: str, user: str, + file_path: str | None = None, + file_data: bytes | None = None, + file_name: str | None = None, + mime_type: str | None = None, ) -> dict[str, Any]: + """Upload a file to Dify. Must provide either file_path or file_data. + + Args: + user: The user ID. + file_path: The path to the file to upload. + file_data: The file data in bytes. + file_name: Optional file name when using file_data. + Returns: + A dictionary containing the uploaded file information. + """ url = f"{self.api_base}/files/upload" - with open(file_path, "rb") as f: - payload = { - "user": user, - "file": f, - } - async with self.session.post( - url, - data=payload, - headers=self.headers, - ) as resp: - return await resp.json() # {"id": "xxx", ...} + + form = FormData() + form.add_field("user", user) + + if file_data is not None: + # 使用 bytes 数据 + form.add_field( + "file", + file_data, + filename=file_name or "uploaded_file", + content_type=mime_type or "application/octet-stream", + ) + elif file_path is not None: + # 使用文件路径 + import os + + with open(file_path, "rb") as f: + file_content = f.read() + form.add_field( + "file", + file_content, + filename=os.path.basename(file_path), + content_type=mime_type or "application/octet-stream", + ) + else: + raise ValueError("file_path 和 file_data 不能同时为 None") + + async with self.session.post( + url, + data=form, + headers=self.headers, # 不包含 Content-Type,让 aiohttp 自动设置 + ) as resp: + if resp.status != 200 and resp.status != 201: + text = await resp.text() + raise Exception(f"Dify 文件上传失败:{resp.status}. {text}") + return await resp.json() # {"id": "xxx", ...} async def close(self): await self.session.close() diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index d74a45982..6f3c813eb 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -69,12 +69,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): ) self.run_context.messages = messages - def _transition_state(self, new_state: AgentState) -> None: - """转换 Agent 状态""" - if self._state != new_state: - logger.debug(f"Agent state transition: {self._state} -> {new_state}") - self._state = new_state - async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: """Yields chunks *and* a final LLMResponse.""" if self.streaming: diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index bd26aca13..5a23a8b6a 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -68,6 +68,10 @@ DEFAULT_CONFIG = { "dequeue_context_length": 1, "streaming_response": False, "show_tool_use_status": False, + "agent_runner_type": "local", + "dify_agent_runner_provider_id": "", + "coze_agent_runner_provider_id": "", + "dashscope_agent_runner_provider_id": "", "unsupported_streaming_strategy": "realtime_segmenting", "max_agent_step": 30, "tool_call_timeout": 60, @@ -1011,7 +1015,7 @@ CONFIG_METADATA_2 = { "id": "dify_app_default", "provider": "dify", "type": "dify", - "provider_type": "chat_completion", + "provider_type": "agent_runner", "enable": True, "dify_api_type": "chat", "dify_api_key": "", @@ -1025,20 +1029,20 @@ CONFIG_METADATA_2 = { "Coze": { "id": "coze", "provider": "coze", - "provider_type": "chat_completion", + "provider_type": "agent_runner", "type": "coze", "enable": True, "coze_api_key": "", "bot_id": "", "coze_api_base": "https://api.coze.cn", "timeout": 60, - "auto_save_history": True, + # "auto_save_history": True, }, "阿里云百炼应用": { "id": "dashscope", "provider": "dashscope", "type": "dashscope", - "provider_type": "chat_completion", + "provider_type": "agent_runner", "enable": True, "dashscope_app_type": "agent", "dashscope_api_key": "", @@ -1907,7 +1911,6 @@ CONFIG_METADATA_2 = { "enable": { "description": "启用", "type": "bool", - "hint": "是否启用。", }, "key": { "description": "API Key", @@ -2037,12 +2040,22 @@ CONFIG_METADATA_2 = { "unsupported_streaming_strategy": { "type": "string", }, + "agent_runner_type": { + "type": "string", + }, + "dify_agent_runner_provider_id": { + "type": "string", + }, + "coze_agent_runner_provider_id": { + "type": "string", + }, + "dashscope_agent_runner_provider_id": { + "type": "string", + }, "max_agent_step": { - "description": "工具调用轮数上限", "type": "int", }, "tool_call_timeout": { - "description": "工具调用超时时间(秒)", "type": "int", }, }, @@ -2180,30 +2193,75 @@ CONFIG_METADATA_3 = { "ai_group": { "name": "AI 配置", "metadata": { - "ai": { - "description": "模型", + "agent_runner": { + "description": "Agent 执行方式", + "hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。", "type": "object", "items": { "provider_settings.enable": { - "description": "启用大语言模型聊天", + "description": "启用", "type": "bool", + "hint": "AI 对话总开关", }, + "provider_settings.agent_runner_type": { + "description": "执行器", + "type": "string", + "options": ["local", "dify", "coze", "dashscope"], + "labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"], + "condition": { + "provider_settings.enable": True, + }, + }, + "provider_settings.coze_agent_runner_provider_id": { + "description": "Coze Agent 执行器提供商 ID", + "type": "string", + "_special": "select_agent_runner_provider:coze", + "condition": { + "provider_settings.agent_runner_type": "coze", + "provider_settings.enable": True, + }, + }, + "provider_settings.dify_agent_runner_provider_id": { + "description": "Dify Agent 执行器提供商 ID", + "type": "string", + "_special": "select_agent_runner_provider:dify", + "condition": { + "provider_settings.agent_runner_type": "dify", + "provider_settings.enable": True, + }, + }, + "provider_settings.dashscope_agent_runner_provider_id": { + "description": "阿里云百炼应用 Agent 执行器提供商 ID", + "type": "string", + "_special": "select_agent_runner_provider:dashscope", + "condition": { + "provider_settings.agent_runner_type": "dashscope", + "provider_settings.enable": True, + }, + }, + }, + }, + "ai": { + "description": "模型", + "hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。", + "type": "object", + "items": { "provider_settings.default_provider_id": { "description": "默认聊天模型", "type": "string", "_special": "select_provider", - "hint": "留空时使用第一个模型。", + "hint": "留空时使用第一个模型", }, "provider_settings.default_image_caption_provider_id": { "description": "默认图片转述模型", "type": "string", "_special": "select_provider", - "hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。", + "hint": "留空代表不使用,可用于非多模态模型", }, "provider_stt_settings.enable": { "description": "启用语音转文本", "type": "bool", - "hint": "STT 总开关。", + "hint": "STT 总开关", }, "provider_stt_settings.provider_id": { "description": "默认语音转文本模型", @@ -2217,12 +2275,11 @@ CONFIG_METADATA_3 = { "provider_tts_settings.enable": { "description": "启用文本转语音", "type": "bool", - "hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。", + "hint": "TTS 总开关", }, "provider_tts_settings.provider_id": { "description": "默认文本转语音模型", "type": "string", - "hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。", "_special": "select_provider_tts", "condition": { "provider_tts_settings.enable": True, @@ -2233,6 +2290,9 @@ CONFIG_METADATA_3 = { "type": "text", }, }, + "condition": { + "provider_settings.enable": True, + }, }, "persona": { "description": "人格", @@ -2244,6 +2304,10 @@ CONFIG_METADATA_3 = { "_special": "select_persona", }, }, + "condition": { + "provider_settings.agent_runner_type": "local", + "provider_settings.enable": True, + }, }, "knowledgebase": { "description": "知识库", @@ -2272,6 +2336,10 @@ CONFIG_METADATA_3 = { "hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。", }, }, + "condition": { + "provider_settings.agent_runner_type": "local", + "provider_settings.enable": True, + }, }, "websearch": { "description": "网页搜索", @@ -2308,6 +2376,10 @@ CONFIG_METADATA_3 = { "type": "bool", }, }, + "condition": { + "provider_settings.agent_runner_type": "local", + "provider_settings.enable": True, + }, }, "others": { "description": "其他配置", @@ -2316,34 +2388,51 @@ CONFIG_METADATA_3 = { "provider_settings.display_reasoning_text": { "description": "显示思考内容", "type": "bool", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.identifier": { "description": "用户识别", "type": "bool", + "hint": "启用后,会在提示词前包含用户 ID 信息。", }, "provider_settings.group_name_display": { "description": "显示群名称", "type": "bool", - "hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。", + "hint": "启用后,在支持的平台(OneBot v11)上会在提示词前包含群名称信息。", }, "provider_settings.datetime_system_prompt": { "description": "现实世界时间感知", "type": "bool", + "hint": "启用后,会在系统提示词中附带当前时间信息。", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.show_tool_use_status": { "description": "输出函数调用状态", "type": "bool", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.max_agent_step": { "description": "工具调用轮数上限", "type": "int", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.tool_call_timeout": { "description": "工具调用超时时间(秒)", "type": "int", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.streaming_response": { - "description": "流式回复", + "description": "流式输出", "type": "bool", }, "provider_settings.unsupported_streaming_strategy": { @@ -2359,17 +2448,23 @@ CONFIG_METADATA_3 = { "provider_settings.max_context_length": { "description": "最多携带对话轮数", "type": "int", - "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条。-1 为不限制。", + "hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.dequeue_context_length": { "description": "丢弃对话轮数", "type": "int", - "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数。", + "hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数", + "condition": { + "provider_settings.agent_runner_type": "local", + }, }, "provider_settings.wake_prefix": { "description": "LLM 聊天额外唤醒前缀 ", "type": "string", - "hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。", + "hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求", }, "provider_settings.prompt_prefix": { "description": "用户提示词", @@ -2381,6 +2476,9 @@ CONFIG_METADATA_3 = { "type": "bool", }, }, + "condition": { + "provider_settings.enable": True, + }, }, }, }, diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 17fd52138..e8241f85a 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -16,13 +16,12 @@ import time import traceback from asyncio import Queue -from astrbot.core import LogBroker, logger, sp +from astrbot.api import logger, sp +from astrbot.core import LogBroker from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.db import BaseDatabase -from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 -from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.persona_mgr import PersonaManager from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler @@ -34,6 +33,7 @@ from astrbot.core.star.context import Context from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map from astrbot.core.umop_config_router import UmopConfigRouter from astrbot.core.updator import AstrBotUpdator +from astrbot.core.utils.migra_helper import migra from . import astrbot_config, html_renderer from .event_bus import EventBus @@ -97,18 +97,16 @@ class AstrBotCoreLifecycle: sp=sp, ) - # 4.5 to 4.6 migration for umop_config_router + # apply migration try: - await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router) + await migra( + self.db, + self.astrbot_config_mgr, + self.umop_config_router, + self.astrbot_config_mgr, + ) except Exception as e: - logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}") - logger.error(traceback.format_exc()) - - # migration for webchat session - try: - await migrate_webchat_session(self.db) - except Exception as e: - logger.error(f"Migration for webchat session failed: {e!s}") + logger.error(f"AstrBot migration failed: {e!s}") logger.error(traceback.format_exc()) # 初始化事件队列 diff --git a/astrbot/core/pipeline/process_stage/method/agent_request.py b/astrbot/core/pipeline/process_stage/method/agent_request.py new file mode 100644 index 000000000..f6f81631e --- /dev/null +++ b/astrbot/core/pipeline/process_stage/method/agent_request.py @@ -0,0 +1,48 @@ +from collections.abc import AsyncGenerator + +from astrbot.core import logger +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.star.session_llm_manager import SessionServiceManager + +from ...context import PipelineContext +from ..stage import Stage +from .agent_sub_stages.internal import InternalAgentSubStage +from .agent_sub_stages.third_party import ThirdPartyAgentSubStage + + +class AgentRequestSubStage(Stage): + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + self.config = ctx.astrbot_config + + self.bot_wake_prefixs: list[str] = self.config["wake_prefix"] + self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"] + for bwp in self.bot_wake_prefixs: + if self.prov_wake_prefix.startswith(bwp): + logger.info( + f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。", + ) + self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :] + + agent_runner_type = self.config["provider_settings"]["agent_runner_type"] + if agent_runner_type == "local": + self.agent_sub_stage = InternalAgentSubStage() + else: + self.agent_sub_stage = ThirdPartyAgentSubStage() + await self.agent_sub_stage.initialize(ctx) + + async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]: + if not self.ctx.astrbot_config["provider_settings"]["enable"]: + logger.debug( + "This pipeline does not enable AI capability, skip processing." + ) + return + + if not SessionServiceManager.should_process_llm_request(event): + logger.debug( + f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing." + ) + return + + async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix): + yield resp diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py similarity index 91% rename from astrbot/core/pipeline/process_stage/method/llm_request.py rename to astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index bd9e4ce3b..aaada9a19 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -21,27 +21,24 @@ from astrbot.core.provider.entities import ( LLMResponse, ProviderRequest, ) -from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager -from ....astr_agent_context import AgentContextWrapper -from ....astr_agent_hooks import MAIN_AGENT_HOOKS -from ....astr_agent_run_util import AgentRunner, run_agent -from ....astr_agent_tool_exec import FunctionToolExecutor -from ...context import PipelineContext, call_event_hook -from ..stage import Stage -from ..utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base +from .....astr_agent_context import AgentContextWrapper +from .....astr_agent_hooks import MAIN_AGENT_HOOKS +from .....astr_agent_run_util import AgentRunner, run_agent +from .....astr_agent_tool_exec import FunctionToolExecutor +from ....context import PipelineContext, call_event_hook +from ...stage import Stage +from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base -class LLMRequestSubStage(Stage): +class InternalAgentSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx conf = ctx.astrbot_config settings = conf["provider_settings"] - self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list - self.provider_wake_prefix: str = settings["wake_prefix"] # str self.max_context_length = settings["max_context_length"] # int self.dequeue_context_length: int = min( max(1, settings["dequeue_context_length"]), @@ -59,13 +56,6 @@ class LLMRequestSubStage(Stage): self.show_reasoning = settings.get("display_reasoning_text", False) self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False) - for bwp in self.bot_wake_prefixs: - if self.provider_wake_prefix.startswith(bwp): - logger.info( - f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。", - ) - self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :] - self.conv_manager = ctx.plugin_manager.context.conversation_manager def _select_provider(self, event: AstrMessageEvent): @@ -304,21 +294,10 @@ class LLMRequestSubStage(Stage): return fixed_messages async def process( - self, - event: AstrMessageEvent, - _nested: bool = False, - ) -> None | AsyncGenerator[None, None]: + self, event: AstrMessageEvent, provider_wake_prefix: str + ) -> AsyncGenerator[None, None]: req: ProviderRequest | None = None - if not self.ctx.astrbot_config["provider_settings"]["enable"]: - logger.debug("未启用 LLM 能力,跳过处理。") - return - - # 检查会话级别的LLM启停状态 - if not SessionServiceManager.should_process_llm_request(event): - logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。") - return - provider = self._select_provider(event) if provider is None: return @@ -348,12 +327,12 @@ class LLMRequestSubStage(Stage): req.image_urls = [] if sel_model := event.get_extra("selected_model"): req.model = sel_model - if self.provider_wake_prefix and not event.message_str.startswith( - self.provider_wake_prefix + if provider_wake_prefix and not event.message_str.startswith( + provider_wake_prefix ): return - req.prompt = event.message_str[len(self.provider_wake_prefix) :] + req.prompt = event.message_str[len(provider_wake_prefix) :] # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。 # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() for comp in event.message_obj.message: diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py new file mode 100644 index 000000000..a54aa52d3 --- /dev/null +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -0,0 +1,202 @@ +import asyncio +from collections.abc import AsyncGenerator +from typing import TYPE_CHECKING + +from astrbot.core import logger +from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner +from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import ( + DashscopeAgentRunner, +) +from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner +from astrbot.core.message.components import Image +from astrbot.core.message.message_event_result import ( + MessageChain, + MessageEventResult, + ResultContentType, +) + +if TYPE_CHECKING: + from astrbot.core.agent.runners.base import BaseAgentRunner +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider.entities import ( + ProviderRequest, +) +from astrbot.core.star.star_handler import EventType +from astrbot.core.utils.metrics import Metric + +from .....astr_agent_context import AgentContextWrapper, AstrAgentContext +from .....astr_agent_hooks import MAIN_AGENT_HOOKS +from ....context import PipelineContext, call_event_hook +from ...stage import Stage + +AGENT_RUNNER_TYPE_KEY = { + "dify": "dify_agent_runner_provider_id", + "coze": "coze_agent_runner_provider_id", + "dashscope": "dashscope_agent_runner_provider_id", +} + + +async def run_third_party_agent( + runner: "BaseAgentRunner", + stream_to_general: bool = False, +) -> AsyncGenerator[MessageChain | None, None]: + """ + 运行第三方 agent runner 并转换响应格式 + 类似于 run_agent 函数,但专门处理第三方 agent runner + """ + try: + async for resp in runner.step_until_done(max_step=30): # type: ignore[misc] + if resp.type == "streaming_delta": + if stream_to_general: + continue + yield resp.data["chain"] + elif resp.type == "llm_result": + if stream_to_general: + yield resp.data["chain"] + except Exception as e: + logger.error(f"Third party agent runner error: {e}") + err_msg = ( + f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n" + f"错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n" + ) + yield MessageChain().message(err_msg) + + +class ThirdPartyAgentSubStage(Stage): + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + self.conf = ctx.astrbot_config + self.runner_type = self.conf["provider_settings"]["agent_runner_type"] + self.prov_id = self.conf["provider_settings"].get( + AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""), + "", + ) + settings = ctx.astrbot_config["provider_settings"] + self.streaming_response: bool = settings["streaming_response"] + self.unsupported_streaming_strategy: str = settings[ + "unsupported_streaming_strategy" + ] + + async def process( + self, event: AstrMessageEvent, provider_wake_prefix: str + ) -> AsyncGenerator[None, None]: + req: ProviderRequest | None = None + + if provider_wake_prefix and not event.message_str.startswith( + provider_wake_prefix + ): + return + + self.prov_cfg: dict = next( + (p for p in self.conf["provider"] if p["id"] == self.prov_id), + {}, + ) + if not self.prov_id or not self.prov_cfg: + logger.error( + "Third Party Agent Runner provider ID is not configured properly." + ) + return + + # make provider request + req = ProviderRequest() + req.session_id = event.unified_msg_origin + req.prompt = event.message_str[len(provider_wake_prefix) :] + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_path = await comp.convert_to_base64() + req.image_urls.append(image_path) + + if not req.prompt and not req.image_urls: + return + + # call event hook + if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + return + + if self.runner_type == "dify": + runner = DifyAgentRunner[AstrAgentContext]() + elif self.runner_type == "coze": + runner = CozeAgentRunner[AstrAgentContext]() + elif self.runner_type == "dashscope": + runner = DashscopeAgentRunner[AstrAgentContext]() + else: + raise ValueError( + f"Unsupported third party agent runner type: {self.runner_type}", + ) + + astr_agent_ctx = AstrAgentContext( + context=self.ctx.plugin_manager.context, + event=event, + ) + + streaming_response = self.streaming_response + if (enable_streaming := event.get_extra("enable_streaming")) is not None: + streaming_response = bool(enable_streaming) + + stream_to_general = ( + self.unsupported_streaming_strategy == "turn_off" + and not event.platform_meta.support_streaming_message + ) + + await runner.reset( + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=60, + ), + agent_hooks=MAIN_AGENT_HOOKS, + provider_config=self.prov_cfg, + streaming=streaming_response, + ) + + if streaming_response and not stream_to_general: + # 流式响应 + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream( + run_third_party_agent( + runner, + stream_to_general=False, + ), + ), + ) + yield + if runner.done(): + final_resp = runner.get_final_llm_resp() + if final_resp and final_resp.result_chain: + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) + else: + # 非流式响应或转换为普通响应 + async for _ in run_third_party_agent( + runner, + stream_to_general=stream_to_general, + ): + yield + + final_resp = runner.get_final_llm_resp() + + if not final_resp or not final_resp.result_chain: + logger.warning("Agent Runner 未返回最终结果。") + return + + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=ResultContentType.LLM_RESULT, + ), + ) + yield + + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=self.runner_type, + provider_type=self.runner_type, + ), + ) diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index ff8120b16..56d305de4 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -24,7 +24,7 @@ class StarRequestSubStage(Stage): async def process( self, event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: + ) -> AsyncGenerator[None, None]: activated_handlers: list[StarHandlerMetadata] = event.get_extra( "activated_handlers", ) diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 9f0b5f92a..2eeefcf11 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -7,7 +7,7 @@ from astrbot.core.star.star_handler import StarHandlerMetadata from ..context import PipelineContext from ..stage import Stage, register_stage -from .method.llm_request import LLMRequestSubStage +from .method.agent_request import AgentRequestSubStage from .method.star_request import StarRequestSubStage @@ -17,9 +17,12 @@ class ProcessStage(Stage): self.ctx = ctx self.config = ctx.astrbot_config self.plugin_manager = ctx.plugin_manager - self.llm_request_sub_stage = LLMRequestSubStage() - await self.llm_request_sub_stage.initialize(ctx) + # initialize agent sub stage + self.agent_sub_stage = AgentRequestSubStage() + await self.agent_sub_stage.initialize(ctx) + + # initialize star request sub stage self.star_request_sub_stage = StarRequestSubStage() await self.star_request_sub_stage.initialize(ctx) @@ -39,7 +42,7 @@ class ProcessStage(Stage): # Handler 的 LLM 请求 event.set_extra("provider_request", resp) _t = False - async for _ in self.llm_request_sub_stage.process(event): + async for _ in self.agent_sub_stage.process(event): _t = True yield if not _t: @@ -67,5 +70,5 @@ class ProcessStage(Stage): logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。") return - async for _ in self.llm_request_sub_stage.process(event): + async for _ in self.agent_sub_stage.process(event): yield diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index ec2550415..115a48463 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -227,6 +227,8 @@ class ProviderManager: async def load_provider(self, provider_config: dict): if not provider_config["enable"]: return + if provider_config.get("provider_type", "") == "agent_runner": + return logger.info( f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...", @@ -247,14 +249,6 @@ class ProviderManager: from .sources.anthropic_source import ( ProviderAnthropic as ProviderAnthropic, ) - case "dify": - from .sources.dify_source import ProviderDify as ProviderDify - case "coze": - from .sources.coze_source import ProviderCoze as ProviderCoze - case "dashscope": - from .sources.dashscope_source import ( - ProviderDashscope as ProviderDashscope, - ) case "googlegenai_chat_completion": from .sources.gemini_source import ( ProviderGoogleGenAI as ProviderGoogleGenAI, diff --git a/astrbot/core/provider/sources/coze_source.py b/astrbot/core/provider/sources/coze_source.py deleted file mode 100644 index 6f1355bf7..000000000 --- a/astrbot/core/provider/sources/coze_source.py +++ /dev/null @@ -1,650 +0,0 @@ -import base64 -import hashlib -import json -import os -from collections.abc import AsyncGenerator - -import astrbot.core.message.components as Comp -from astrbot import logger -from astrbot.api.provider import Provider -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.provider.entities import LLMResponse - -from ..register import register_provider_adapter -from .coze_api_client import CozeAPIClient - - -@register_provider_adapter("coze", "Coze (扣子) 智能体适配器") -class ProviderCoze(Provider): - def __init__( - self, - provider_config, - provider_settings, - ) -> None: - super().__init__( - provider_config, - provider_settings, - ) - self.api_key = provider_config.get("coze_api_key", "") - if not self.api_key: - raise Exception("Coze API Key 不能为空。") - self.bot_id = provider_config.get("bot_id", "") - if not self.bot_id: - raise Exception("Coze Bot ID 不能为空。") - self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn") - - if not isinstance(self.api_base, str) or not self.api_base.startswith( - ("http://", "https://"), - ): - raise Exception( - "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。", - ) - - self.timeout = provider_config.get("timeout", 120) - if isinstance(self.timeout, str): - self.timeout = int(self.timeout) - self.auto_save_history = provider_config.get("auto_save_history", True) - self.conversation_ids: dict[str, str] = {} - self.file_id_cache: dict[str, dict[str, str]] = {} - - # 创建 API 客户端 - self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base) - - def _generate_cache_key(self, data: str, is_base64: bool = False) -> str: - """生成统一的缓存键 - - Args: - data: 图片数据或路径 - is_base64: 是否是 base64 数据 - - Returns: - str: 缓存键 - - """ - try: - if is_base64 and data.startswith("data:image/"): - try: - header, encoded = data.split(",", 1) - image_bytes = base64.b64decode(encoded) - cache_key = hashlib.md5(image_bytes).hexdigest() - return cache_key - except Exception: - cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest() - return cache_key - elif data.startswith(("http://", "https://")): - # URL图片,使用URL作为缓存键 - cache_key = hashlib.md5(data.encode("utf-8")).hexdigest() - return cache_key - else: - clean_path = ( - data.split("_")[0] - if "_" in data and len(data.split("_")) >= 3 - else data - ) - - if os.path.exists(clean_path): - with open(clean_path, "rb") as f: - file_content = f.read() - cache_key = hashlib.md5(file_content).hexdigest() - return cache_key - cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest() - return cache_key - - except Exception as e: - cache_key = hashlib.md5(data.encode("utf-8")).hexdigest() - logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}") - return cache_key - - async def _upload_file( - self, - file_data: bytes, - session_id: str | None = None, - cache_key: str | None = None, - ) -> str: - """上传文件到 Coze 并返回 file_id""" - # 使用 API 客户端上传文件 - file_id = await self.api_client.upload_file(file_data) - - # 缓存 file_id - if session_id and cache_key: - if session_id not in self.file_id_cache: - self.file_id_cache[session_id] = {} - self.file_id_cache[session_id][cache_key] = file_id - logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}") - - return file_id - - async def _download_and_upload_image( - self, - image_url: str, - session_id: str | None = None, - ) -> str: - """下载图片并上传到 Coze,返回 file_id""" - # 计算哈希实现缓存 - cache_key = self._generate_cache_key(image_url) if session_id else None - - if session_id and cache_key: - if session_id not in self.file_id_cache: - self.file_id_cache[session_id] = {} - - if cache_key in self.file_id_cache[session_id]: - file_id = self.file_id_cache[session_id][cache_key] - return file_id - - try: - image_data = await self.api_client.download_image(image_url) - - file_id = await self._upload_file(image_data, session_id, cache_key) - - if session_id and cache_key: - self.file_id_cache[session_id][cache_key] = file_id - - return file_id - - except Exception as e: - logger.error(f"处理图片失败 {image_url}: {e!s}") - raise Exception(f"处理图片失败: {e!s}") - - async def _process_context_images( - self, - content: str | list, - session_id: str, - ) -> str: - """处理上下文中的图片内容,将 base64 图片上传并替换为 file_id""" - try: - if isinstance(content, str): - return content - - processed_content = [] - if session_id not in self.file_id_cache: - self.file_id_cache[session_id] = {} - - for item in content: - if not isinstance(item, dict): - processed_content.append(item) - continue - if item.get("type") == "text": - processed_content.append(item) - elif item.get("type") == "image_url": - # 处理图片逻辑 - if "file_id" in item: - # 已经有 file_id - logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}") - processed_content.append(item) - else: - # 获取图片数据 - image_data = "" - if "image_url" in item and isinstance(item["image_url"], dict): - image_data = item["image_url"].get("url", "") - elif "data" in item: - image_data = item.get("data", "") - elif "url" in item: - image_data = item.get("url", "") - - if not image_data: - continue - # 计算哈希用于缓存 - cache_key = self._generate_cache_key( - image_data, - is_base64=image_data.startswith("data:image/"), - ) - - # 检查缓存 - if cache_key in self.file_id_cache[session_id]: - file_id = self.file_id_cache[session_id][cache_key] - processed_content.append( - {"type": "image", "file_id": file_id}, - ) - else: - # 上传图片并缓存 - if image_data.startswith("data:image/"): - # base64 处理 - _, encoded = image_data.split(",", 1) - image_bytes = base64.b64decode(encoded) - file_id = await self._upload_file( - image_bytes, - session_id, - cache_key, - ) - elif image_data.startswith(("http://", "https://")): - # URL 图片 - file_id = await self._download_and_upload_image( - image_data, - session_id, - ) - # 为URL图片也添加缓存 - self.file_id_cache[session_id][cache_key] = file_id - elif os.path.exists(image_data): - # 本地文件 - with open(image_data, "rb") as f: - image_bytes = f.read() - file_id = await self._upload_file( - image_bytes, - session_id, - cache_key, - ) - else: - logger.warning( - f"无法处理的图片格式: {image_data[:50]}...", - ) - continue - - processed_content.append( - {"type": "image", "file_id": file_id}, - ) - - result = json.dumps(processed_content, ensure_ascii=False) - return result - except Exception as e: - logger.error(f"处理上下文图片失败: {e!s}") - if isinstance(content, str): - return content - return json.dumps(content, ensure_ascii=False) - - async def text_chat( - self, - prompt: str, - session_id=None, - image_urls=None, - func_tool=None, - contexts=None, - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ) -> LLMResponse: - """文本对话, 内部使用流式接口实现非流式 - - Args: - prompt (str): 用户提示词 - session_id (str): 会话ID - image_urls (List[str]): 图片URL列表 - func_tool (FuncCall): 函数调用工具(不支持) - contexts (List): 上下文列表 - system_prompt (str): 系统提示语 - tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持) - model (str): 模型名称(不支持) - - Returns: - LLMResponse: LLM响应对象 - - """ - accumulated_content = "" - final_response = None - - async for llm_response in self.text_chat_stream( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - model=model, - **kwargs, - ): - if llm_response.is_chunk: - if llm_response.completion_text: - accumulated_content += llm_response.completion_text - else: - final_response = llm_response - - if final_response: - return final_response - - if accumulated_content: - chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) - return LLMResponse(role="assistant", result_chain=chain) - return LLMResponse(role="assistant", completion_text="") - - async def text_chat_stream( - self, - prompt: str, - session_id=None, - image_urls=None, - func_tool=None, - contexts=None, - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ) -> AsyncGenerator[LLMResponse, None]: - """流式对话接口""" - # 用户ID参数(参考文档, 可以自定义) - user_id = session_id or kwargs.get("user", "default_user") - - # 获取或创建会话ID - conversation_id = self.conversation_ids.get(user_id) - - # 构建消息 - additional_messages = [] - - if system_prompt: - if not self.auto_save_history or not conversation_id: - additional_messages.append( - { - "role": "system", - "content": system_prompt, - "content_type": "text", - }, - ) - - contexts = self._ensure_message_to_dicts(contexts) - if not self.auto_save_history and contexts: - # 如果关闭了自动保存历史,传入上下文 - for ctx in contexts: - if isinstance(ctx, dict) and "role" in ctx and "content" in ctx: - content = ctx["content"] - content_type = ctx.get("content_type", "text") - - # 处理可能包含图片的上下文 - if ( - content_type == "object_string" - or (isinstance(content, str) and content.startswith("[")) - or ( - isinstance(content, list) - and any( - isinstance(item, dict) - and item.get("type") == "image_url" - for item in content - ) - ) - ): - processed_content = await self._process_context_images( - content, - user_id, - ) - additional_messages.append( - { - "role": ctx["role"], - "content": processed_content, - "content_type": "object_string", - }, - ) - else: - # 纯文本 - additional_messages.append( - { - "role": ctx["role"], - "content": ( - content - if isinstance(content, str) - else json.dumps(content, ensure_ascii=False) - ), - "content_type": "text", - }, - ) - else: - logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}") - - if prompt or image_urls: - if image_urls: - # 多模态 - object_string_content = [] - if prompt: - object_string_content.append({"type": "text", "text": prompt}) - - for url in image_urls: - try: - if url.startswith(("http://", "https://")): - # 网络图片 - file_id = await self._download_and_upload_image( - url, - user_id, - ) - else: - # 本地文件或 base64 - if url.startswith("data:image/"): - # base64 - _, encoded = url.split(",", 1) - image_data = base64.b64decode(encoded) - cache_key = self._generate_cache_key( - url, - is_base64=True, - ) - file_id = await self._upload_file( - image_data, - user_id, - cache_key, - ) - # 本地文件 - elif os.path.exists(url): - with open(url, "rb") as f: - image_data = f.read() - # 用文件路径和修改时间来缓存 - file_stat = os.stat(url) - cache_key = self._generate_cache_key( - f"{url}_{file_stat.st_mtime}_{file_stat.st_size}", - is_base64=False, - ) - file_id = await self._upload_file( - image_data, - user_id, - cache_key, - ) - else: - logger.warning(f"图片文件不存在: {url}") - continue - - object_string_content.append( - { - "type": "image", - "file_id": file_id, - }, - ) - except Exception as e: - logger.error(f"处理图片失败 {url}: {e!s}") - continue - - if object_string_content: - content = json.dumps(object_string_content, ensure_ascii=False) - additional_messages.append( - { - "role": "user", - "content": content, - "content_type": "object_string", - }, - ) - # 纯文本 - elif prompt: - additional_messages.append( - { - "role": "user", - "content": prompt, - "content_type": "text", - }, - ) - - try: - accumulated_content = "" - message_started = False - - async for chunk in self.api_client.chat_messages( - bot_id=self.bot_id, - user_id=user_id, - additional_messages=additional_messages, - conversation_id=conversation_id, - auto_save_history=self.auto_save_history, - stream=True, - timeout=self.timeout, - ): - event_type = chunk.get("event") - data = chunk.get("data", {}) - - if event_type == "conversation.chat.created": - if isinstance(data, dict) and "conversation_id" in data: - self.conversation_ids[user_id] = data["conversation_id"] - - elif event_type == "conversation.message.delta": - if isinstance(data, dict): - content = data.get("content", "") - if not content and "delta" in data: - content = data["delta"].get("content", "") - if not content and "text" in data: - content = data.get("text", "") - - if content: - message_started = True - accumulated_content += content - yield LLMResponse( - role="assistant", - completion_text=content, - is_chunk=True, - ) - - elif event_type == "conversation.message.completed": - if isinstance(data, dict): - msg_type = data.get("type") - if msg_type == "answer" and data.get("role") == "assistant": - final_content = data.get("content", "") - if not accumulated_content and final_content: - chain = MessageChain(chain=[Comp.Plain(final_content)]) - yield LLMResponse( - role="assistant", - result_chain=chain, - is_chunk=False, - ) - - elif event_type == "conversation.chat.completed": - if accumulated_content: - chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) - yield LLMResponse( - role="assistant", - result_chain=chain, - is_chunk=False, - ) - break - - elif event_type == "done": - break - - elif event_type == "error": - error_msg = ( - data.get("message", "未知错误") - if isinstance(data, dict) - else str(data) - ) - logger.error(f"Coze 流式响应错误: {error_msg}") - yield LLMResponse( - role="err", - completion_text=f"Coze 错误: {error_msg}", - is_chunk=False, - ) - break - - if not message_started and not accumulated_content: - yield LLMResponse( - role="assistant", - completion_text="LLM 未响应任何内容。", - is_chunk=False, - ) - elif message_started and accumulated_content: - chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) - yield LLMResponse( - role="assistant", - result_chain=chain, - is_chunk=False, - ) - - except Exception as e: - logger.error(f"Coze 流式请求失败: {e!s}") - yield LLMResponse( - role="err", - completion_text=f"Coze 流式请求失败: {e!s}", - is_chunk=False, - ) - - async def forget(self, session_id: str): - """清空指定会话的上下文""" - user_id = session_id - conversation_id = self.conversation_ids.get(user_id) - - if user_id in self.file_id_cache: - self.file_id_cache.pop(user_id, None) - - if not conversation_id: - return True - - try: - response = await self.api_client.clear_context(conversation_id) - - if "code" in response and response["code"] == 0: - self.conversation_ids.pop(user_id, None) - return True - logger.warning(f"清空 Coze 会话上下文失败: {response}") - return False - - except Exception as e: - logger.error(f"清空 Coze 会话失败: {e!s}") - return False - - async def get_current_key(self): - """获取当前API Key""" - return self.api_key - - async def set_key(self, key: str): - """设置新的API Key""" - raise NotImplementedError("Coze 适配器不支持设置 API Key。") - - async def get_models(self): - """获取可用模型列表""" - return [f"bot_{self.bot_id}"] - - def get_model(self): - """获取当前模型""" - return f"bot_{self.bot_id}" - - def set_model(self, model: str): - """设置模型(在Coze中是Bot ID)""" - if model.startswith("bot_"): - self.bot_id = model[4:] - else: - self.bot_id = model - - async def get_human_readable_context( - self, - session_id: str, - page: int = 1, - page_size: int = 10, - ): - """获取人类可读的上下文历史""" - user_id = session_id - conversation_id = self.conversation_ids.get(user_id) - - if not conversation_id: - return [] - - try: - data = await self.api_client.get_message_list( - conversation_id=conversation_id, - order="desc", - limit=page_size, - offset=(page - 1) * page_size, - ) - - if data.get("code") != 0: - logger.warning(f"获取 Coze 消息历史失败: {data}") - return [] - - messages = data.get("data", {}).get("messages", []) - - readable_history = [] - for msg in messages: - role = msg.get("role", "unknown") - content = msg.get("content", "") - msg_type = msg.get("type", "") - - if role == "user": - readable_history.append(f"用户: {content}") - elif role == "assistant" and msg_type == "answer": - readable_history.append(f"助手: {content}") - - return readable_history - - except Exception as e: - logger.error(f"获取 Coze 消息历史失败: {e!s}") - return [] - - async def terminate(self): - """清理资源""" - await self.api_client.close() diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py deleted file mode 100644 index 7c690e048..000000000 --- a/astrbot/core/provider/sources/dashscope_source.py +++ /dev/null @@ -1,207 +0,0 @@ -import asyncio -import functools -import re - -from dashscope import Application -from dashscope.app.application_response import ApplicationResponse - -from astrbot.core import logger, sp -from astrbot.core.message.message_event_result import MessageChain - -from .. import Provider -from ..entities import LLMResponse -from ..register import register_provider_adapter -from .openai_source import ProviderOpenAIOfficial - - -@register_provider_adapter("dashscope", "Dashscope APP 适配器。") -class ProviderDashscope(ProviderOpenAIOfficial): - def __init__( - self, - provider_config: dict, - provider_settings: dict, - ) -> None: - Provider.__init__( - self, - provider_config, - provider_settings, - ) - self.api_key = provider_config.get("dashscope_api_key", "") - if not self.api_key: - raise Exception("阿里云百炼 API Key 不能为空。") - self.app_id = provider_config.get("dashscope_app_id", "") - if not self.app_id: - raise Exception("阿里云百炼 APP ID 不能为空。") - self.dashscope_app_type = provider_config.get("dashscope_app_type", "") - if not self.dashscope_app_type: - raise Exception("阿里云百炼 APP 类型不能为空。") - self.model_name = "dashscope" - self.variables: dict = provider_config.get("variables", {}) - self.rag_options: dict = provider_config.get("rag_options", {}) - self.output_reference = self.rag_options.get("output_reference", False) - self.rag_options = self.rag_options.copy() - self.rag_options.pop("output_reference", None) - - self.timeout = provider_config.get("timeout", 120) - if isinstance(self.timeout, str): - self.timeout = int(self.timeout) - - def has_rag_options(self): - """判断是否有 RAG 选项 - - Returns: - bool: 是否有 RAG 选项 - - """ - if self.rag_options and ( - len(self.rag_options.get("pipeline_ids", [])) > 0 - or len(self.rag_options.get("file_ids", [])) > 0 - ): - return True - return False - - async def text_chat( - self, - prompt: str, - session_id=None, - image_urls=None, - func_tool=None, - contexts=None, - system_prompt=None, - model=None, - **kwargs, - ) -> LLMResponse: - if image_urls is None: - image_urls = [] - if contexts is None: - contexts = [] - # 获得会话变量 - payload_vars = self.variables.copy() - # 动态变量 - session_var = await sp.session_get(session_id, "session_variables", default={}) - payload_vars.update(session_var) - - if ( - self.dashscope_app_type in ["agent", "dialog-workflow"] - and not self.has_rag_options() - ): - # 支持多轮对话的 - new_record = {"role": "user", "content": prompt} - if image_urls: - logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。") - contexts_no_img = await self._remove_image_from_context(contexts) - context_query = [*contexts_no_img, new_record] - if system_prompt: - context_query.insert(0, {"role": "system", "content": system_prompt}) - for part in context_query: - if "_no_save" in part: - del part["_no_save"] - # 调用阿里云百炼 API - payload = { - "app_id": self.app_id, - "api_key": self.api_key, - "messages": context_query, - "biz_params": payload_vars or None, - } - partial = functools.partial( - Application.call, - **payload, - ) - response = await asyncio.get_event_loop().run_in_executor(None, partial) - else: - # 不支持多轮对话的 - # 调用阿里云百炼 API - payload = { - "app_id": self.app_id, - "prompt": prompt, - "api_key": self.api_key, - "biz_params": payload_vars or None, - } - if self.rag_options: - payload["rag_options"] = self.rag_options - partial = functools.partial( - Application.call, - **payload, - ) - response = await asyncio.get_event_loop().run_in_executor(None, partial) - - assert isinstance(response, ApplicationResponse) - - logger.debug(f"dashscope resp: {response}") - - if response.status_code != 200: - logger.error( - f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", - ) - return LLMResponse( - role="err", - result_chain=MessageChain().message( - f"阿里云百炼请求失败: message={response.message} code={response.status_code}", - ), - ) - - output_text = response.output.get("text", "") or "" - # RAG 引用脚标格式化 - output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text) - if self.output_reference and response.output.get("doc_references", None): - ref_parts = [] - for ref in response.output.get("doc_references", []) or []: - ref_title = ( - ref.get("title", "") - if ref.get("title") - else ref.get("doc_name", "") - ) - ref_parts.append(f"{ref['index_id']}. {ref_title}\n") - ref_str = "".join(ref_parts) - output_text += f"\n\n回答来源:\n{ref_str}" - - llm_response = LLMResponse("assistant") - llm_response.result_chain = MessageChain().message(output_text) - - return llm_response - - async def text_chat_stream( - self, - prompt, - session_id=None, - image_urls=..., - func_tool=None, - contexts=..., - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response - - async def forget(self, session_id): - return True - - async def get_current_key(self): - return self.api_key - - async def set_key(self, key): - raise Exception("阿里云百炼 适配器不支持设置 API Key。") - - async def get_models(self): - return [self.get_model()] - - async def get_human_readable_context(self, session_id, page, page_size): - raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。") - - async def terminate(self): - pass diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py deleted file mode 100644 index 7850a982c..000000000 --- a/astrbot/core/provider/sources/dify_source.py +++ /dev/null @@ -1,285 +0,0 @@ -import os - -import astrbot.core.message.components as Comp -from astrbot.core import logger, sp -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.utils.dify_api_client import DifyAPIClient -from astrbot.core.utils.io import download_file, download_image_by_url - -from .. import Provider -from ..entities import LLMResponse -from ..register import register_provider_adapter - - -@register_provider_adapter("dify", "Dify APP 适配器。") -class ProviderDify(Provider): - def __init__( - self, - provider_config, - provider_settings, - ) -> None: - super().__init__( - provider_config, - provider_settings, - ) - self.api_key = provider_config.get("dify_api_key", "") - if not self.api_key: - raise Exception("Dify API Key 不能为空。") - api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") - self.api_type = provider_config.get("dify_api_type", "") - if not self.api_type: - raise Exception("Dify API 类型不能为空。") - self.model_name = "dify" - self.workflow_output_key = provider_config.get( - "dify_workflow_output_key", - "astrbot_wf_output", - ) - self.dify_query_input_key = provider_config.get( - "dify_query_input_key", - "astrbot_text_query", - ) - if not self.dify_query_input_key: - self.dify_query_input_key = "astrbot_text_query" - if not self.workflow_output_key: - self.workflow_output_key = "astrbot_wf_output" - self.variables: dict = provider_config.get("variables", {}) - self.timeout = provider_config.get("timeout", 120) - if isinstance(self.timeout, str): - self.timeout = int(self.timeout) - self.conversation_ids = {} - """记录当前 session id 的对话 ID""" - - self.api_client = DifyAPIClient(self.api_key, api_base) - - async def text_chat( - self, - prompt: str, - session_id=None, - image_urls=None, - func_tool=None, - contexts=None, - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ) -> LLMResponse: - if image_urls is None: - image_urls = [] - result = "" - session_id = session_id or kwargs.get("user") or "unknown" # 1734 - conversation_id = self.conversation_ids.get(session_id, "") - - files_payload = [] - for image_url in image_urls: - image_path = ( - await download_image_by_url(image_url) - if image_url.startswith("http") - else image_url - ) - file_response = await self.api_client.file_upload( - image_path, - user=session_id, - ) - logger.debug(f"Dify 上传图片响应:{file_response}") - if "id" not in file_response: - logger.warning( - f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。", - ) - continue - files_payload.append( - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response["id"], - }, - ) - - # 获得会话变量 - payload_vars = self.variables.copy() - # 动态变量 - session_var = await sp.session_get(session_id, "session_variables", default={}) - payload_vars.update(session_var) - payload_vars["system_prompt"] = system_prompt - - try: - match self.api_type: - case "chat" | "agent" | "chatflow": - if not prompt: - prompt = "请描述这张图片。" - - async for chunk in self.api_client.chat_messages( - inputs={ - **payload_vars, - }, - query=prompt, - user=session_id, - conversation_id=conversation_id, - files=files_payload, - timeout=self.timeout, - ): - logger.debug(f"dify resp chunk: {chunk}") - if ( - chunk["event"] == "message" - or chunk["event"] == "agent_message" - ): - result += chunk["answer"] - if not conversation_id: - self.conversation_ids[session_id] = chunk[ - "conversation_id" - ] - conversation_id = chunk["conversation_id"] - elif chunk["event"] == "message_end": - logger.debug("Dify message end") - break - elif chunk["event"] == "error": - logger.error(f"Dify 出现错误:{chunk}") - raise Exception( - f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}", - ) - - case "workflow": - async for chunk in self.api_client.workflow_run( - inputs={ - self.dify_query_input_key: prompt, - "astrbot_session_id": session_id, - **payload_vars, - }, - user=session_id, - files=files_payload, - timeout=self.timeout, - ): - match chunk["event"]: - case "workflow_started": - logger.info( - f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。", - ) - case "node_finished": - logger.debug( - f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。", - ) - case "workflow_finished": - logger.info( - f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束", - ) - logger.debug(f"Dify 工作流结果:{chunk}") - if chunk["data"]["error"]: - logger.error( - f"Dify 工作流出现错误:{chunk['data']['error']}", - ) - raise Exception( - f"Dify 工作流出现错误:{chunk['data']['error']}", - ) - if ( - self.workflow_output_key - not in chunk["data"]["outputs"] - ): - raise Exception( - f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}", - ) - result = chunk - case _: - raise Exception(f"未知的 Dify API 类型:{self.api_type}") - except Exception as e: - logger.error(f"Dify 请求失败:{e!s}") - return LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}") - - if not result: - logger.warning("Dify 请求结果为空,请查看 Debug 日志。") - - chain = await self.parse_dify_result(result) - - return LLMResponse(role="assistant", result_chain=chain) - - async def text_chat_stream( - self, - prompt, - session_id=None, - image_urls=..., - func_tool=None, - contexts=..., - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response - - async def parse_dify_result(self, chunk: dict | str) -> MessageChain: - if isinstance(chunk, str): - # Chat - return MessageChain(chain=[Comp.Plain(chunk)]) - - async def parse_file(item: dict): - match item["type"]: - case "image": - return Comp.Image(file=item["url"], url=item["url"]) - case "audio": - # 仅支持 wav - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - path = os.path.join(temp_dir, f"{item['filename']}.wav") - await download_file(item["url"], path) - return Comp.Image(file=item["url"], url=item["url"]) - case "video": - return Comp.Video(file=item["url"]) - case _: - return Comp.File(name=item["filename"], file=item["url"]) - - output = chunk["data"]["outputs"][self.workflow_output_key] - chains = [] - if isinstance(output, str): - # 纯文本输出 - chains.append(Comp.Plain(output)) - elif isinstance(output, list): - # 主要适配 Dify 的 HTTP 请求结点的多模态输出 - for item in output: - # handle Array[File] - if ( - not isinstance(item, dict) - or item.get("dify_model_identity", "") != "__dify__file__" - ): - chains.append(Comp.Plain(str(output))) - break - else: - chains.append(Comp.Plain(str(output))) - - # scan file - files = chunk["data"].get("files", []) - for item in files: - comp = await parse_file(item) - chains.append(comp) - - return MessageChain(chain=chains) - - async def forget(self, session_id): - self.conversation_ids[session_id] = "" - return True - - async def get_current_key(self): - return self.api_key - - async def set_key(self, key): - raise Exception("Dify 适配器不支持设置 API Key。") - - async def get_models(self): - return [self.get_model()] - - async def get_human_readable_context(self, session_id, page, page_size): - raise Exception("暂不支持获得 Dify 的历史消息记录。") - - async def terminate(self): - await self.api_client.close() diff --git a/astrbot/core/umop_config_router.py b/astrbot/core/umop_config_router.py index 07858da5f..27f6232aa 100644 --- a/astrbot/core/umop_config_router.py +++ b/astrbot/core/umop_config_router.py @@ -85,3 +85,22 @@ class UmopConfigRouter: self.umop_to_conf_id[umo] = conf_id await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) + + async def delete_route(self, umo: str): + """删除一条路由 + + Args: + umo (str): 需要删除的 UMO 字符串 + + Raises: + ValueError: 当 umo 格式不正确时抛出 + """ + + if not isinstance(umo, str) or len(umo.split(":")) != 3: + raise ValueError( + "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all", + ) + + if umo in self.umop_to_conf_id: + del self.umop_to_conf_id[umo] + await self.sp.global_put("umop_config_routing", self.umop_to_conf_id) diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py new file mode 100644 index 000000000..5642d606e --- /dev/null +++ b/astrbot/core/utils/migra_helper.py @@ -0,0 +1,73 @@ +import traceback + +from astrbot.core import astrbot_config, logger +from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager +from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 +from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session + + +def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None: + """ + Migra agent runner configs from provider configs. + """ + try: + default_prov_id = conf["provider_settings"]["default_provider_id"] + if default_prov_id in ids_map: + conf["provider_settings"]["default_provider_id"] = "" + p = ids_map[default_prov_id] + if p["type"] == "dify": + conf["provider_settings"]["dify_agent_runner_provider_id"] = p["id"] + conf["provider_settings"]["agent_runner_type"] = "dify" + elif p["type"] == "coze": + conf["provider_settings"]["coze_agent_runner_provider_id"] = p["id"] + conf["provider_settings"]["agent_runner_type"] = "coze" + elif p["type"] == "dashscope": + conf["provider_settings"]["dashscope_agent_runner_provider_id"] = p[ + "id" + ] + conf["provider_settings"]["agent_runner_type"] = "dashscope" + conf.save_config() + except Exception as e: + logger.error(f"Migration for third party agent runner configs failed: {e!s}") + logger.error(traceback.format_exc()) + + +async def migra( + db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager +) -> None: + """ + Stores the migration logic here. + btw, i really don't like migration :( + """ + # 4.5 to 4.6 migration for umop_config_router + try: + await migrate_45_to_46(astrbot_config_mgr, umop_config_router) + except Exception as e: + logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}") + logger.error(traceback.format_exc()) + + # migration for webchat session + try: + await migrate_webchat_session(db) + except Exception as e: + logger.error(f"Migration for webchat session failed: {e!s}") + logger.error(traceback.format_exc()) + + # migra third party agent runner configs + _c = False + providers = astrbot_config["provider"] + ids_map = {} + for prov in providers: + type_ = prov.get("type") + if type_ in ["dify", "coze", "dashscope"]: + prov["provider_type"] = "agent_runner" + ids_map[prov["id"]] = { + "type": type_, + "id": prov["id"], + } + _c = True + if _c: + astrbot_config.save_config() + + for conf in acm.confs.values(): + _migra_agent_runner_configs(conf, ids_map) diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index c6b4c5ede..6b1f52a69 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -40,9 +40,6 @@ class SharedPreferences: else: ret = default return ret - raise ValueError( - "scope_id and key cannot be None when getting a specific preference.", - ) async def range_get_async( self, @@ -56,30 +53,6 @@ class SharedPreferences: ret = await self.db_helper.get_preferences(scope, scope_id, key) return ret - @overload - async def session_get( - self, - umo: None, - key: str, - default: Any = None, - ) -> list[Preference]: ... - - @overload - async def session_get( - self, - umo: str, - key: None, - default: Any = None, - ) -> list[Preference]: ... - - @overload - async def session_get( - self, - umo: None, - key: None, - default: Any = None, - ) -> list[Preference]: ... - async def session_get( self, umo: str | None, @@ -88,7 +61,7 @@ class SharedPreferences: ) -> _VT | list[Preference]: """获取会话范围的偏好设置 - Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。 + Note: 当 umo 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。 """ if umo is None or key is None: return await self.range_get_async("umo", umo, key) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 1ad789563..5381b5649 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -56,6 +56,7 @@ class ChatRoute(Route): self.conv_mgr = core_lifecycle.conversation_manager self.platform_history_mgr = core_lifecycle.platform_message_history_manager self.db = db + self.umop_config_router = core_lifecycle.umop_config_router self.running_convs: dict[str, bool] = {} @@ -266,7 +267,8 @@ class ChatRoute(Route): return Response().error("Permission denied").__dict__ # 删除该会话下的所有对话 - unified_msg_origin = f"{session.platform_id}:FriendMessage:{session.platform_id}!{username}!{session_id}" + message_type = "GroupMessage" if session.is_group else "FriendMessage" + unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}" await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) # 删除消息历史 @@ -276,6 +278,16 @@ class ChatRoute(Route): offset_sec=99999999, ) + # 删除与会话关联的配置路由 + try: + await self.umop_config_router.delete_route(unified_msg_origin) + except ValueError as exc: + logger.warning( + "Failed to delete UMO route %s during session cleanup: %s", + unified_msg_origin, + exc, + ) + # 清理队列(仅对 webchat) if session.platform_id == "webchat": webchat_queue_mgr.remove_queues(session_id) diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 09acd1b7e..ff0349409 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -87,6 +87,8 @@ :disabled="isStreaming || isConvRunning" :enableStreaming="enableStreaming" :isRecording="isRecording" + :session-id="currSessionId || null" + :current-session="getCurrentSession" @send="handleSendMessage" @toggleStreaming="toggleStreaming" @removeImage="removeImage" diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue index 7ca0ec94a..cad1b7205 100644 --- a/dashboard/src/components/chat/ChatInput.vue +++ b/dashboard/src/components/chat/ChatInput.vue @@ -11,7 +11,13 @@ style="width: 100%; resize: none; outline: none; border: 1px solid var(--v-theme-border); border-radius: 12px; padding: 8px 16px; min-height: 40px; font-family: inherit; font-size: 16px; background-color: var(--v-theme-surface);">
- + + + + diff --git a/dashboard/src/components/chat/ProviderModelSelector.vue b/dashboard/src/components/chat/ProviderModelSelector.vue index 0b983e416..645ff5a55 100644 --- a/dashboard/src/components/chat/ProviderModelSelector.vue +++ b/dashboard/src/components/chat/ProviderModelSelector.vue @@ -3,6 +3,7 @@ + mdi-creation {{ selectedProviderId }} / {{ selectedModelName }} diff --git a/dashboard/src/components/provider/AddNewProvider.vue b/dashboard/src/components/provider/AddNewProvider.vue index b4cd1eb92..f876779fb 100644 --- a/dashboard/src/components/provider/AddNewProvider.vue +++ b/dashboard/src/components/provider/AddNewProvider.vue @@ -7,6 +7,10 @@ mdi-message-text {{ tm('dialogs.addProvider.tabs.basic') }} + + mdi-cogs + {{ tm('dialogs.addProvider.tabs.agentRunner') }} + mdi-microphone-message {{ tm('dialogs.addProvider.tabs.speechToText') }} @@ -27,7 +31,7 @@
- 接入 {{ name }} + {{ name }} {{ getProviderDescription(template, name) }} @@ -54,7 +58,7 @@ - {{ tm('dialogs.addProvider.noTemplates', { type: getTabTypeName(tabType) }) }} + {{ t('dialogs.addProvider.noTemplates') }} @@ -104,19 +108,6 @@ export default { this.$emit('update:show', value); } }, - - // 翻译消息的计算属性 - messages() { - return { - tabTypes: { - 'chat_completion': this.tm('providers.tabs.chatCompletion'), - 'speech_to_text': this.tm('providers.tabs.speechToText'), - 'text_to_speech': this.tm('providers.tabs.textToSpeech'), - 'embedding': this.tm('providers.tabs.embedding'), - 'rerank': this.tm('providers.tabs.rerank') - } - }; - } }, methods: { closeDialog() { @@ -140,11 +131,6 @@ export default { // 从工具函数导入 getProviderIcon, - // 获取Tab类型的中文名称 - getTabTypeName(tabType) { - return this.messages.tabTypes[tabType] || tabType; - }, - // 获取提供商简介 getProviderDescription(template, name) { return getProviderDescription(template, name, this.tm); diff --git a/dashboard/src/components/shared/AstrBotConfigV4.vue b/dashboard/src/components/shared/AstrBotConfigV4.vue index bfe99667b..e55f6f998 100644 --- a/dashboard/src/components/shared/AstrBotConfigV4.vue +++ b/dashboard/src/components/shared/AstrBotConfigV4.vue @@ -101,6 +101,21 @@ function shouldShowItem(itemMeta, itemKey) { return true } +// 检查最外层的 object 是否应该显示 +function shouldShowSection() { + const sectionMeta = props.metadata[props.metadataKey] + if (!sectionMeta?.condition) { + return true + } + for (const [conditionKey, expectedValue] of Object.entries(sectionMeta.condition)) { + const actualValue = getValueBySelector(props.iterable, conditionKey) + if (actualValue !== expectedValue) { + return false + } + } + return true +} + function hasVisibleItemsAfter(items, currentIndex) { const itemEntries = Object.entries(items) @@ -114,12 +129,33 @@ function hasVisibleItemsAfter(items, currentIndex) { return false } + +function parseSpecialValue(value) { + if (!value || typeof value !== 'string') { + return { name: '', subtype: '' } + } + const [name, ...rest] = value.split(':') + return { + name, + subtype: rest.join(':') || '' + } +} + +function getSpecialName(value) { + return parseSpecialValue(value).name +} + +function getSpecialSubtype(value) { + return parseSpecialValue(value).subtype +} + - +
diff --git a/dashboard/src/components/shared/ProviderSelector.vue b/dashboard/src/components/shared/ProviderSelector.vue index a1945f4fa..5a7994ca5 100644 --- a/dashboard/src/components/shared/ProviderSelector.vue +++ b/dashboard/src/components/shared/ProviderSelector.vue @@ -94,6 +94,10 @@ const props = defineProps({ type: String, default: 'chat_completion' }, + providerSubtype: { + type: String, + default: '' + }, buttonText: { type: String, default: '选择提供商...' @@ -127,7 +131,10 @@ async function loadProviders() { } }) if (response.data.status === 'ok') { - providerList.value = response.data.data || [] + const providers = response.data.data || [] + providerList.value = props.providerSubtype + ? providers.filter((provider) => matchesProviderSubtype(provider, props.providerSubtype)) + : providers } } catch (error) { console.error('加载提供商列表失败:', error) @@ -137,6 +144,17 @@ async function loadProviders() { } } +function matchesProviderSubtype(provider, subtype) { + if (!subtype) { + return true + } + const normalized = String(subtype).toLowerCase() + const candidates = [provider.type, provider.provider, provider.id] + .filter(Boolean) + .map((value) => String(value).toLowerCase()) + return candidates.includes(normalized) +} + function selectProvider(provider) { selectedProvider.value = provider.id } diff --git a/dashboard/src/composables/useSessions.ts b/dashboard/src/composables/useSessions.ts index f14e3aa11..c6000c784 100644 --- a/dashboard/src/composables/useSessions.ts +++ b/dashboard/src/composables/useSessions.ts @@ -4,8 +4,12 @@ import { useRouter } from 'vue-router'; export interface Session { session_id: string; - display_name: string; + display_name: string | null; updated_at: string; + platform_id: string; + creator: string; + is_group: number; + created_at: string; } export function useSessions(chatboxMode: boolean = false) { diff --git a/dashboard/src/i18n/locales/en-US/core/common.json b/dashboard/src/i18n/locales/en-US/core/common.json index 37b384199..931b62cfe 100644 --- a/dashboard/src/i18n/locales/en-US/core/common.json +++ b/dashboard/src/i18n/locales/en-US/core/common.json @@ -74,6 +74,7 @@ "delete": "Delete", "copy": "Copy", "edit": "Edit", + "copy": "Copy", "noData": "No data available" } } diff --git a/dashboard/src/i18n/locales/en-US/features/provider.json b/dashboard/src/i18n/locales/en-US/features/provider.json index 7888c9e0b..e08177d3d 100644 --- a/dashboard/src/i18n/locales/en-US/features/provider.json +++ b/dashboard/src/i18n/locales/en-US/features/provider.json @@ -9,6 +9,7 @@ "tabs": { "all": "All", "chatCompletion": "Chat Completion", + "agentRunner": "Agent Runner", "speechToText": "Speech to Text", "textToSpeech": "Text to Speech", "embedding": "Embedding", @@ -44,12 +45,13 @@ "title": "Service Provider", "tabs": { "basic": "Basic", + "agentRunner": "Agent Runner", "speechToText": "Speech to Text", "textToSpeech": "Text to Speech", "embedding": "Embedding", "rerank": "Rerank" }, - "noTemplates": "No {type} type provider templates available" + "noTemplates": "No this type provider templates available" }, "config": { "addTitle": "Add", diff --git a/dashboard/src/i18n/locales/zh-CN/features/provider.json b/dashboard/src/i18n/locales/zh-CN/features/provider.json index 913d74c30..234018829 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/provider.json +++ b/dashboard/src/i18n/locales/zh-CN/features/provider.json @@ -8,7 +8,8 @@ "providerType": "提供商类型", "tabs": { "all": "全部", - "chatCompletion": "基本对话", + "chatCompletion": "对话", + "agentRunner": "Agent 执行器", "speechToText": "语音转文字", "textToSpeech": "文字转语音", "embedding": "嵌入(Embedding)", @@ -44,13 +45,14 @@ "addProvider": { "title": "模型提供商", "tabs": { - "basic": "基本", + "basic": "对话", + "agentRunner": "Agent 执行器", "speechToText": "语音转文字", "textToSpeech": "文字转语音", "embedding": "嵌入(Embedding)", "rerank": "重排序(Rerank)" }, - "noTemplates": "暂无{type}类型的提供商模板" + "noTemplates": "暂无该类型的提供商模板" }, "config": { "addTitle": "新增", diff --git a/dashboard/src/views/ProviderPage.vue b/dashboard/src/views/ProviderPage.vue index e973ef624..32f5a60a8 100644 --- a/dashboard/src/views/ProviderPage.vue +++ b/dashboard/src/views/ProviderPage.vue @@ -30,6 +30,10 @@ mdi-message-text {{ tm('providers.tabs.chatCompletion') }}
+ + mdi-message-text + {{ tm('providers.tabs.agentRunner') }} + mdi-microphone-message {{ tm('providers.tabs.speechToText') }} @@ -48,30 +52,62 @@ - - - mdi-api-off -

{{ getEmptyText() }}

-
-
- - - - - - - - - + +
@@ -289,8 +325,8 @@ export default { "anthropic_chat_completion": "chat_completion", "googlegenai_chat_completion": "chat_completion", "zhipu_chat_completion": "chat_completion", - "dify": "chat_completion", - "coze": "chat_completion", + "dify": "agent_runner", + "coze": "agent_runner", "dashscope": "chat_completion", "openai_whisper_api": "speech_to_text", "openai_whisper_selfhost": "speech_to_text", @@ -334,6 +370,7 @@ export default { }, tabTypes: { 'chat_completion': this.tm('providers.tabs.chatCompletion'), + 'agent_runner': this.tm('providers.tabs.agentRunner'), 'speech_to_text': this.tm('providers.tabs.speechToText'), 'text_to_speech': this.tm('providers.tabs.textToSpeech'), 'embedding': this.tm('providers.tabs.embedding'), @@ -363,6 +400,52 @@ export default { }; }, + groupedProviders() { + if (!this.config_data.provider) { + return []; + } + + const typeOrder = [ + 'chat_completion', + 'agent_runner', + 'speech_to_text', + 'text_to_speech', + 'embedding', + 'rerank', + ]; + + const assigned = new Set(); + const groups = typeOrder + .map((typeKey) => { + const items = this.config_data.provider.filter((provider) => { + const resolved = this.getProviderType(provider); + if (resolved === typeKey) { + assigned.add(provider.id); + return true; + } + return false; + }); + return { + typeKey, + label: this.messages.tabTypes[typeKey] || typeKey, + items, + }; + }) + .filter((group) => group.items.length > 0); + + const remaining = this.config_data.provider.filter( + (provider) => !assigned.has(provider.id), + ); + if (remaining.length > 0) { + groups.push({ + typeKey: 'others', + label: this.tm('providers.tabs.all'), + items: remaining, + }); + } + return groups; + }, + // 根据选择的标签过滤提供商列表 filteredProviders() { if (!this.config_data.provider || this.activeProviderTypeTab === 'all') { @@ -371,13 +454,7 @@ export default { return this.config_data.provider.filter(provider => { // 如果provider.provider_type已经存在,直接使用它 - if (provider.provider_type) { - return provider.provider_type === this.activeProviderTypeTab; - } - - // 否则使用映射关系 - const mappedType = this.oldVersionProviderTypeMapping[provider.type]; - return mappedType === this.activeProviderTypeTab; + return this.getProviderType(provider) === this.activeProviderTypeTab; }); } }, @@ -387,6 +464,14 @@ export default { }, methods: { + getProviderType(provider) { + if (!provider) return undefined; + if (provider.provider_type) { + return provider.provider_type; + } + return this.oldVersionProviderTypeMapping[provider.type]; + }, + getConfig() { axios.get('/api/config/get').then((res) => { this.config_data = res.data.data.config; @@ -690,6 +775,9 @@ export default { if (!provider.enable) { throw new Error('该提供商未被用户启用'); } + if (provider.provider_type === 'agent_runner') { + throw new Error('暂时无法测试 Agent Runner 类型的提供商'); + } const res = await axios.get(`/api/config/provider/check_one?id=${provider.id}`); if (res.data && res.data.status === 'ok') { diff --git a/packages/astrbot/commands/conversation.py b/packages/astrbot/commands/conversation.py index 67402c660..cdffd3597 100644 --- a/packages/astrbot/commands/conversation.py +++ b/packages/astrbot/commands/conversation.py @@ -2,14 +2,18 @@ import datetime from astrbot.api import logger, sp, star from astrbot.api.event import AstrMessageEvent, MessageEventResult -from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.platform.astr_message_event import MessageSession from astrbot.core.platform.message_type import MessageType -from astrbot.core.provider.sources.coze_source import ProviderCoze -from astrbot.core.provider.sources.dify_source import ProviderDify from ..long_term_memory import LongTermMemory from .utils.rst_scene import RstScene +THIRD_PARTY_AGENT_RUNNER_KEY = { + "dify": "dify_conversation_id", + "coze": "coze_conversation_id", +} +THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys()) + class ConversationCommands: def __init__(self, context: star.Context, ltm: LongTermMemory | None = None): @@ -38,6 +42,7 @@ class ConversationCommands: async def reset(self, message: AstrMessageEvent): """重置 LLM 会话""" + umo = message.unified_msg_origin cfg = self.context.get_config(umo=message.unified_msg_origin) is_unique_session = cfg["platform_settings"]["unique_session"] is_group = bool(message.get_group_id()) @@ -62,28 +67,23 @@ class ConversationCommands: ) return - if not self.context.get_using_provider(message.unified_msg_origin): + agent_runner_type = cfg["provider_settings"]["agent_runner_type"] + if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: + await sp.remove_async( + scope="umo", + scope_id=umo, + key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type], + ) + message.set_result(MessageEventResult().message("重置对话成功。")) + return + + if not self.context.get_using_provider(umo): message.set_result( MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"), ) return - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type in ["dify", "coze"]: - assert isinstance(provider, (ProviderDify, ProviderCoze)), ( - "provider type is not dify or coze" - ) - await provider.forget(message.unified_msg_origin) - message.set_result( - MessageEventResult().message( - "已重置当前 Dify / Coze 会话,新聊天将更换到新的会话。", - ), - ) - return - - cid = await self.context.conversation_manager.get_curr_conversation_id( - message.unified_msg_origin, - ) + cid = await self.context.conversation_manager.get_curr_conversation_id(umo) if not cid: message.set_result( @@ -94,7 +94,7 @@ class ConversationCommands: return await self.context.conversation_manager.update_conversation( - message.unified_msg_origin, + umo, cid, [], ) @@ -151,29 +151,14 @@ class ConversationCommands: async def convs(self, message: AstrMessageEvent, page: int = 1): """查看对话列表""" - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - """原有的Dify处理逻辑保持不变""" - parts = ["Dify 对话列表:\n"] - assert isinstance(provider, ProviderDify) - data = await provider.api_client.get_chat_convs(message.unified_msg_origin) - idx = 1 - for conv in data["data"]: - ts_h = datetime.datetime.fromtimestamp(conv["updated_at"]).strftime( - "%m-%d %H:%M", - ) - parts.append( - f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n" - ) - idx += 1 - if idx == 1: - parts.append("没有找到任何对话。") - dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None) - parts.append( - f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。" + cfg = self.context.get_config(umo=message.unified_msg_origin) + agent_runner_type = cfg["provider_settings"]["agent_runner_type"] + if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: + message.set_result( + MessageEventResult().message( + f"{THIRD_PARTY_AGENT_RUNNER_STR} 对话列表功能暂不支持。", + ), ) - ret = "".join(parts) - message.set_result(MessageEventResult().message(ret)) return size_per_page = 6 @@ -241,15 +226,15 @@ class ConversationCommands: async def new_conv(self, message: AstrMessageEvent): """创建新对话""" - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type in ["dify", "coze"]: - assert isinstance(provider, (ProviderDify, ProviderCoze)), ( - "provider type is not dify or coze" - ) - await provider.forget(message.unified_msg_origin) - message.set_result( - MessageEventResult().message("成功,下次聊天将是新对话。"), + cfg = self.context.get_config(umo=message.unified_msg_origin) + agent_runner_type = cfg["provider_settings"]["agent_runner_type"] + if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: + await sp.remove_async( + scope="umo", + scope_id=message.unified_msg_origin, + key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type], ) + message.set_result(MessageEventResult().message("已创建新对话。")) return cpersona = await self._get_current_persona_id(message.unified_msg_origin) @@ -272,19 +257,9 @@ class ConversationCommands: async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""): """创建新群聊对话""" - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type in ["dify", "coze"]: - assert isinstance(provider, (ProviderDify, ProviderCoze)), ( - "provider type is not dify or coze" - ) - await provider.forget(message.unified_msg_origin) - message.set_result( - MessageEventResult().message("成功,下次聊天将是新对话。"), - ) - return if sid: session = str( - MessageSesion( + MessageSession( platform_name=message.platform_meta.id, message_type=MessageType("GroupMessage"), session_id=sid, @@ -319,31 +294,6 @@ class ConversationCommands: ) return - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify), "provider type is not dify" - data = await provider.api_client.get_chat_convs(message.unified_msg_origin) - if not data["data"]: - message.set_result(MessageEventResult().message("未找到任何对话。")) - return - selected_conv = None - if index is not None: - try: - selected_conv = data["data"][index - 1] - except IndexError: - message.set_result( - MessageEventResult().message("对话序号错误,请使用 /ls 查看"), - ) - return - else: - selected_conv = data["data"][0] - ret = ( - f"Dify 切换到对话: {selected_conv['name']}({selected_conv['id'][:4]})。" - ) - provider.conversation_ids[message.unified_msg_origin] = selected_conv["id"] - message.set_result(MessageEventResult().message(ret)) - return - if index is None: message.set_result( MessageEventResult().message( @@ -376,19 +326,6 @@ class ConversationCommands: if not new_name: message.set_result(MessageEventResult().message("请输入新的对话名称。")) return - - provider = self.context.get_using_provider(message.unified_msg_origin) - - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify) - cid = provider.conversation_ids.get(message.unified_msg_origin, None) - if not cid: - message.set_result(MessageEventResult().message("未找到当前对话。")) - return - await provider.api_client.rename(cid, new_name, message.unified_msg_origin) - message.set_result(MessageEventResult().message("重命名对话成功。")) - return - await self.context.conversation_manager.update_conversation_title( message.unified_msg_origin, new_name, @@ -408,20 +345,14 @@ class ConversationCommands: ) return - provider = self.context.get_using_provider(message.unified_msg_origin) - if provider and provider.meta().type == "dify": - assert isinstance(provider, ProviderDify) - dify_cid = provider.conversation_ids.pop(message.unified_msg_origin, None) - if dify_cid: - await provider.api_client.delete_chat_conv( - message.unified_msg_origin, - dify_cid, - ) - message.set_result( - MessageEventResult().message( - "删除当前对话成功。不再处于对话状态,使用 /switch 序号 切换到其他对话或 /new 创建。", - ), + agent_runner_type = cfg["provider_settings"]["agent_runner_type"] + if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY: + await sp.remove_async( + scope="umo", + scope_id=message.unified_msg_origin, + key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type], ) + message.set_result(MessageEventResult().message("重置对话成功。")) return session_curr_cid = ( diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 969343fe0..c1466ef3c 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -5,7 +5,6 @@ from astrbot.api.event import AstrMessageEvent, filter from astrbot.api.message_components import Image, Plain from astrbot.api.provider import LLMResponse, ProviderRequest from astrbot.core import logger -from astrbot.core.provider.sources.dify_source import ProviderDify from .commands import ( AdminCommands, @@ -279,33 +278,20 @@ class Main(star.Star): return try: conv = None - if provider.meta().type != "dify": - session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id( - event.unified_msg_origin, - ) + session_curr_cid = await self.context.conversation_manager.get_curr_conversation_id( + event.unified_msg_origin, + ) - if not session_curr_cid: - logger.error( - "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。", - ) - return + if not session_curr_cid: + logger.error( + "当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。", + ) + return - conv = await self.context.conversation_manager.get_conversation( - event.unified_msg_origin, - session_curr_cid, - ) - else: - # Dify 自己有维护对话,不需要 bot 端维护。 - assert isinstance(provider, ProviderDify) - cid = provider.conversation_ids.get( - event.unified_msg_origin, - None, - ) - if cid is None: - logger.error( - "[Dify] 当前未处于对话状态,无法主动回复,请确保 平台设置->会话隔离(unique_session) 未开启,并使用 /switch 序号 切换或者 /new 创建一个会话。", - ) - return + conv = await self.context.conversation_manager.get_conversation( + event.unified_msg_origin, + session_curr_cid, + ) prompt = event.message_str