diff --git a/astrbot/core/agent/runners/base.py b/astrbot/core/agent/runners/base.py index f7e0913b4..c53bdc0dc 100644 --- a/astrbot/core/agent/runners/base.py +++ b/astrbot/core/agent/runners/base.py @@ -2,13 +2,12 @@ import abc import typing as T from enum import Enum, auto -from astrbot.core.provider import Provider +from astrbot import logger from astrbot.core.provider.entities import LLMResponse from ..hooks import BaseAgentRunHooks from ..response import AgentResponse from ..run_context import ContextWrapper, TContext -from ..tool_executor import BaseFunctionToolExecutor class AgentState(Enum): @@ -24,9 +23,7 @@ class BaseAgentRunner(T.Generic[TContext]): @abc.abstractmethod async def reset( self, - provider: Provider, run_context: ContextWrapper[TContext], - tool_executor: BaseFunctionToolExecutor[TContext], agent_hooks: BaseAgentRunHooks[TContext], **kwargs: T.Any, ) -> None: @@ -60,3 +57,9 @@ class BaseAgentRunner(T.Generic[TContext]): This method should be called after the agent is done. """ ... 
+ + def _transition_state(self, new_state: AgentState) -> None: + """Transition the agent state.""" + if self._state != new_state: + logger.debug(f"Dify Agent state transition: {self._state} -> {new_state}") + self._state = new_state diff --git a/astrbot/core/agent/runners/coze_agent_runner.py b/astrbot/core/agent/runners/coze_agent_runner.py new file mode 100644 index 000000000..885eef699 --- /dev/null +++ b/astrbot/core/agent/runners/coze_agent_runner.py @@ -0,0 +1,367 @@ +import base64 +import json +import sys +import typing as T + +import astrbot.core.message.components as Comp +from astrbot import logger +from astrbot.core import sp +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import ( + LLMResponse, + ProviderRequest, +) +from astrbot.core.provider.sources.coze_api_client import CozeAPIClient + +from ..hooks import BaseAgentRunHooks +from ..response import AgentResponseData +from ..run_context import ContextWrapper, TContext +from .base import AgentResponse, AgentState, BaseAgentRunner + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class CozeAgentRunner(BaseAgentRunner[TContext]): + """Coze Agent Runner""" + + @override + async def reset( + self, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.final_llm_resp = None + self._state = AgentState.IDLE + self.agent_hooks = agent_hooks + self.run_context = run_context + + self.api_key = provider_config.get("coze_api_key", "") + if not self.api_key: + raise Exception("Coze API Key 不能为空。") + self.bot_id = provider_config.get("bot_id", "") + if not self.bot_id: + raise Exception("Coze Bot ID 不能为空。") + self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn") + + if not 
isinstance(self.api_base, str) or not self.api_base.startswith( + ("http://", "https://"), + ): + raise Exception( + "Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。", + ) + + self.timeout = provider_config.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + self.auto_save_history = provider_config.get("auto_save_history", True) + + # 创建 API 客户端 + self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base) + + # 会话相关缓存 + self.file_id_cache: dict[str, dict[str, str]] = {} + + @override + async def step(self): + """ + 执行 Coze Agent 的一个步骤 + """ + if not self.req: + raise ValueError("Request is not set. Please call reset() first.") + + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + + # 开始处理,转换到运行状态 + self._transition_state(AgentState.RUNNING) + + try: + # 执行 Coze 请求并处理结果 + async for response in self._execute_coze_request(): + yield response + except Exception as e: + logger.error(f"Coze 请求失败:{str(e)}") + self._transition_state(AgentState.ERROR) + self.final_llm_resp = LLMResponse( + role="err", completion_text=f"Coze 请求失败:{str(e)}" + ) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message(f"Coze 请求失败:{str(e)}") + ), + ) + finally: + await self.api_client.close() + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + while not self.done(): + async for resp in self.step(): + yield resp + + async def _execute_coze_request(self): + """执行 Coze 请求的核心逻辑""" + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + contexts = self.req.contexts or [] + system_prompt = self.req.system_prompt + + # 用户ID参数 + user_id = session_id + + # 获取或创建会话ID + conversation_id = await sp.get_async( + scope="umo", + 
scope_id=user_id, + key="coze_conversation_id", + default="", + ) + + # 构建消息 + additional_messages = [] + + if system_prompt: + if not self.auto_save_history or not conversation_id: + additional_messages.append( + { + "role": "system", + "content": system_prompt, + "content_type": "text", + }, + ) + + # 处理历史上下文 + if not self.auto_save_history and contexts: + for ctx in contexts: + if isinstance(ctx, dict) and "role" in ctx and "content" in ctx: + # 处理上下文中的图片 + content = ctx["content"] + if isinstance(content, list): + # 多模态内容,需要处理图片 + processed_content = [] + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + processed_content.append(item) + elif item.get("type") == "image_url": + # 处理图片上传 + try: + image_data = item.get("image_url", {}) + url = image_data.get("url", "") + if url: + file_id = ( + await self._download_and_upload_image( + url, session_id + ) + ) + processed_content.append( + { + "type": "file", + "file_id": file_id, + "file_url": url, + } + ) + except Exception as e: + logger.warning(f"处理上下文图片失败: {e}") + continue + + if processed_content: + additional_messages.append( + { + "role": ctx["role"], + "content": processed_content, + "content_type": "object_string", + } + ) + else: + # 纯文本内容 + additional_messages.append( + { + "role": ctx["role"], + "content": content, + "content_type": "text", + } + ) + + # 构建当前消息 + if prompt or image_urls: + if image_urls: + # 多模态 + object_string_content = [] + if prompt: + object_string_content.append({"type": "text", "text": prompt}) + + for url in image_urls: + # the url is a base64 string + try: + image_data = base64.b64decode(url) + file_id = await self.api_client.upload_file(image_data) + object_string_content.append( + { + "type": "image", + "file_id": file_id, + } + ) + except Exception as e: + logger.warning(f"处理图片失败 {url}: {e}") + continue + + if object_string_content: + content = json.dumps(object_string_content, ensure_ascii=False) + additional_messages.append( + { + "role": 
"user", + "content": content, + "content_type": "object_string", + } + ) + elif prompt: + # 纯文本 + additional_messages.append( + { + "role": "user", + "content": prompt, + "content_type": "text", + }, + ) + + # 执行 Coze API 请求 + accumulated_content = "" + message_started = False + + async for chunk in self.api_client.chat_messages( + bot_id=self.bot_id, + user_id=user_id, + additional_messages=additional_messages, + conversation_id=conversation_id, + auto_save_history=self.auto_save_history, + stream=True, + timeout=self.timeout, + ): + event_type = chunk.get("event") + data = chunk.get("data", {}) + + if event_type == "conversation.chat.created": + if isinstance(data, dict) and "conversation_id" in data: + await sp.put_async( + scope="umo", + scope_id=user_id, + key="coze_conversation_id", + value=data["conversation_id"], + ) + + if event_type == "conversation.message.delta": + # 增量消息 + content = data.get("content", "") + if not content and "delta" in data: + content = data["delta"].get("content", "") + if not content and "text" in data: + content = data.get("text", "") + + if content: + accumulated_content += content + message_started = True + + # 如果是流式响应,发送增量数据 + if self.streaming: + yield AgentResponse( + type="streaming_delta", + data=AgentResponseData( + chain=MessageChain().message(content) + ), + ) + + elif event_type == "conversation.message.completed": + # 消息完成 + logger.debug("Coze message completed") + message_started = True + + elif event_type == "conversation.chat.completed": + # 对话完成 + logger.debug("Coze chat completed") + break + + elif event_type == "error": + # 错误处理 + error_msg = data.get("msg", "未知错误") + error_code = data.get("code", "UNKNOWN") + logger.error(f"Coze 出现错误: {error_code} - {error_msg}") + raise Exception(f"Coze 出现错误: {error_code} - {error_msg}") + + if not message_started and not accumulated_content: + logger.warning("Coze 未返回任何内容") + accumulated_content = "" + + # 创建最终响应 + chain = MessageChain(chain=[Comp.Plain(accumulated_content)]) 
+ self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) + self._transition_state(AgentState.DONE) + + try: + await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + + # 返回最终结果 + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=chain), + ) + + async def _download_and_upload_image( + self, + image_url: str, + session_id: str | None = None, + ) -> str: + """下载图片并上传到 Coze,返回 file_id""" + import hashlib + + # 计算哈希实现缓存 + cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest() + + if session_id: + if session_id not in self.file_id_cache: + self.file_id_cache[session_id] = {} + + if cache_key in self.file_id_cache[session_id]: + file_id = self.file_id_cache[session_id][cache_key] + logger.debug(f"[Coze] 使用缓存的 file_id: {file_id}") + return file_id + + try: + image_data = await self.api_client.download_image(image_url) + file_id = await self.api_client.upload_file(image_data) + + if session_id: + self.file_id_cache[session_id][cache_key] = file_id + logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}") + + return file_id + + except Exception as e: + logger.error(f"处理图片失败 {image_url}: {e!s}") + raise Exception(f"处理图片失败: {e!s}") + + @override + def done(self) -> bool: + """检查 Agent 是否已完成工作""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + @override + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp diff --git a/astrbot/core/agent/runners/dashscope_agent_runner.py b/astrbot/core/agent/runners/dashscope_agent_runner.py new file mode 100644 index 000000000..6a7cada5f --- /dev/null +++ b/astrbot/core/agent/runners/dashscope_agent_runner.py @@ -0,0 +1,273 @@ +import asyncio +import functools +import re +import sys +import typing as T + +from dashscope import Application +from dashscope.app.application_response import ApplicationResponse + +import 
astrbot.core.message.components as Comp +from astrbot.core import logger, sp +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import ( + LLMResponse, + ProviderRequest, +) + +from ..hooks import BaseAgentRunHooks +from ..response import AgentResponseData +from ..run_context import ContextWrapper, TContext +from .base import AgentResponse, AgentState, BaseAgentRunner + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + + +class DashscopeAgentRunner(BaseAgentRunner[TContext]): + """Dashscope Agent Runner""" + + @override + async def reset( + self, + request: ProviderRequest, + run_context: ContextWrapper[TContext], + agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, + **kwargs: T.Any, + ) -> None: + self.req = request + self.streaming = kwargs.get("streaming", False) + self.final_llm_resp = None + self._state = AgentState.IDLE + self.agent_hooks = agent_hooks + self.run_context = run_context + + self.api_key = provider_config.get("dashscope_api_key", "") + if not self.api_key: + raise Exception("阿里云百炼 API Key 不能为空。") + self.app_id = provider_config.get("dashscope_app_id", "") + if not self.app_id: + raise Exception("阿里云百炼 APP ID 不能为空。") + self.dashscope_app_type = provider_config.get("dashscope_app_type", "") + if not self.dashscope_app_type: + raise Exception("阿里云百炼 APP 类型不能为空。") + + self.variables: dict = provider_config.get("variables", {}) or {} + self.rag_options: dict = provider_config.get("rag_options", {}) + self.output_reference = self.rag_options.get("output_reference", False) + self.rag_options = self.rag_options.copy() + self.rag_options.pop("output_reference", None) + + self.timeout = provider_config.get("timeout", 120) + if isinstance(self.timeout, str): + self.timeout = int(self.timeout) + + def has_rag_options(self): + """判断是否有 RAG 选项 + + Returns: + bool: 是否有 RAG 选项 + + """ + if self.rag_options and ( + 
len(self.rag_options.get("pipeline_ids", [])) > 0 + or len(self.rag_options.get("file_ids", [])) > 0 + ): + return True + return False + + @override + async def step(self): + """ + 执行 Dashscope Agent 的一个步骤 + """ + if not self.req: + raise ValueError("Request is not set. Please call reset() first.") + + if self._state == AgentState.IDLE: + try: + await self.agent_hooks.on_agent_begin(self.run_context) + except Exception as e: + logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True) + + # 开始处理,转换到运行状态 + self._transition_state(AgentState.RUNNING) + + try: + # 执行 Dashscope 请求并处理结果 + async for response in self._execute_dashscope_request(): + yield response + except Exception as e: + logger.error(f"阿里云百炼请求失败:{str(e)}") + self._transition_state(AgentState.ERROR) + self.final_llm_resp = LLMResponse( + role="err", completion_text=f"阿里云百炼请求失败:{str(e)}" + ) + yield AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}") + ), + ) + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + while not self.done(): + async for resp in self.step(): + yield resp + + async def _remove_image_from_context(self, contexts: list) -> list: + """移除上下文中的图片内容""" + result = [] + for ctx in contexts: + if isinstance(ctx, dict): + content = ctx.get("content", "") + if isinstance(content, list): + # 只保留文本内容 + text_parts = [ + item.get("text", "") + for item in content + if isinstance(item, dict) and item.get("type") == "text" + ] + if text_parts: + new_ctx = ctx.copy() + new_ctx["content"] = " ".join(text_parts) + result.append(new_ctx) + else: + result.append(ctx) + else: + result.append(ctx) + return result + + async def _execute_dashscope_request(self): + """执行 Dashscope 请求的核心逻辑""" + prompt = self.req.prompt or "" + session_id = self.req.session_id or "unknown" + image_urls = self.req.image_urls or [] + contexts = self.req.contexts or [] + system_prompt = 
self.req.system_prompt + + # 获得会话变量 + payload_vars = self.variables.copy() + # 动态变量 + session_var = await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_variables", + default={}, + ) + payload_vars.update(session_var) + + if ( + self.dashscope_app_type in ["agent", "dialog-workflow"] + and not self.has_rag_options() + ): + # 支持多轮对话的 + new_record = {"role": "user", "content": prompt} + if image_urls: + logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。") + contexts_no_img = await self._remove_image_from_context(contexts) + context_query = [*contexts_no_img, new_record] + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) + for part in context_query: + if "_no_save" in part: + del part["_no_save"] + # 调用阿里云百炼 API + payload = { + "app_id": self.app_id, + "api_key": self.api_key, + "messages": context_query, + "biz_params": payload_vars or None, + } + partial = functools.partial( + Application.call, + **payload, + ) + response = await asyncio.get_event_loop().run_in_executor(None, partial) + else: + # 不支持多轮对话的 + # 调用阿里云百炼 API + payload = { + "app_id": self.app_id, + "prompt": prompt, + "api_key": self.api_key, + "biz_params": payload_vars or None, + } + if self.rag_options: + payload["rag_options"] = self.rag_options + partial = functools.partial( + Application.call, + **payload, + ) + response = await asyncio.get_event_loop().run_in_executor(None, partial) + + assert isinstance(response, ApplicationResponse) + + logger.debug(f"dashscope resp: {response}") + + if response.status_code != 200: + logger.error( + f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code", + ) + self._transition_state(AgentState.ERROR) + self.final_llm_resp = LLMResponse( + role="err", + result_chain=MessageChain().message( + f"阿里云百炼请求失败: message={response.message} code={response.status_code}", + ), + ) + yield 
AgentResponse( + type="err", + data=AgentResponseData( + chain=MessageChain().message( + f"阿里云百炼请求失败: message={response.message} code={response.status_code}" + ) + ), + ) + return + + output_text = response.output.get("text", "") or "" + # RAG 引用脚标格式化 + output_text = re.sub(r"\[(\d+)\]", r"[\1]", output_text) + if self.output_reference and response.output.get("doc_references", None): + ref_parts = [] + for ref in response.output.get("doc_references", []) or []: + ref_title = ( + ref.get("title", "") + if ref.get("title") + else ref.get("doc_name", "") + ) + ref_parts.append(f"{ref['index_id']}. {ref_title}\n") + ref_str = "".join(ref_parts) + output_text += f"\n\n回答来源:\n{ref_str}" + + # 创建最终响应 + chain = MessageChain(chain=[Comp.Plain(output_text)]) + self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain) + self._transition_state(AgentState.DONE) + + try: + await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp) + except Exception as e: + logger.error(f"Error in on_agent_done hook: {e}", exc_info=True) + + # 返回最终结果 + yield AgentResponse( + type="llm_result", + data=AgentResponseData(chain=chain), + ) + + @override + def done(self) -> bool: + """检查 Agent 是否已完成工作""" + return self._state in (AgentState.DONE, AgentState.ERROR) + + @override + def get_final_llm_resp(self) -> LLMResponse | None: + return self.final_llm_resp diff --git a/astrbot/core/agent/runners/dify_agent_runner.py b/astrbot/core/agent/runners/dify_agent_runner.py index 145e4a627..1433cfdf4 100644 --- a/astrbot/core/agent/runners/dify_agent_runner.py +++ b/astrbot/core/agent/runners/dify_agent_runner.py @@ -1,22 +1,23 @@ -import sys +import base64 import os +import sys import typing as T -from .base import BaseAgentRunner, AgentResponse, AgentState -from ..hooks import BaseAgentRunHooks -from ..tool_executor import BaseFunctionToolExecutor -from ..run_context import ContextWrapper, TContext -from ..response import AgentResponseData -from 
astrbot.core.provider.provider import Provider + +import astrbot.core.message.components as Comp +from astrbot.core import logger, sp from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.entities import ( - ProviderRequest, LLMResponse, + ProviderRequest, ) -from astrbot.core.utils.dify_api_client import DifyAPIClient -from astrbot.core.utils.io import download_image_by_url, download_file from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core import logger, sp -import astrbot.core.message.components as Comp +from astrbot.core.utils.dify_api_client import DifyAPIClient +from astrbot.core.utils.io import download_file + +from ..hooks import BaseAgentRunHooks +from ..response import AgentResponseData +from ..run_context import ContextWrapper, TContext +from .base import AgentResponse, AgentState, BaseAgentRunner if sys.version_info >= (3, 12): from typing import override @@ -30,53 +31,36 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): @override async def reset( self, - provider: Provider, request: ProviderRequest, run_context: ContextWrapper[TContext], - tool_executor: BaseFunctionToolExecutor[TContext], agent_hooks: BaseAgentRunHooks[TContext], + provider_config: dict, **kwargs: T.Any, ) -> None: self.req = request self.streaming = kwargs.get("streaming", False) - self.provider = provider self.final_llm_resp = None self._state = AgentState.IDLE - self.tool_executor = tool_executor self.agent_hooks = agent_hooks self.run_context = run_context - # Dify 特定配置 - 从 provider 或 kwargs 中获取 - self.api_key = kwargs.get("dify_api_key", "") - api_base = kwargs.get("dify_api_base", "https://api.dify.ai/v1") - self.api_type = kwargs.get("dify_api_type", "") - - self.workflow_output_key = kwargs.get( - "dify_workflow_output_key", "astrbot_wf_output" + self.api_key = provider_config.get("dify_api_key", "") + self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") + self.api_type = 
provider_config.get("dify_api_type", "chat") + self.workflow_output_key = provider_config.get( + "dify_workflow_output_key", + "astrbot_wf_output", ) - self.dify_query_input_key = kwargs.get( - "dify_query_input_key", "astrbot_text_query" + self.dify_query_input_key = provider_config.get( + "dify_query_input_key", + "astrbot_text_query", ) - if not self.dify_query_input_key: - self.dify_query_input_key = "astrbot_text_query" - if not self.workflow_output_key: - self.workflow_output_key = "astrbot_wf_output" - - self.variables: dict = kwargs.get("variables", {}) - self.timeout = kwargs.get("timeout", 120) + self.variables: dict = provider_config.get("variables", {}) or {} + self.timeout = provider_config.get("timeout", 60) if isinstance(self.timeout, str): self.timeout = int(self.timeout) - self.conversation_ids = {} - """记录当前 session id 的对话 ID""" - - self.api_client = DifyAPIClient(self.api_key, api_base) - - def _transition_state(self, new_state: AgentState) -> None: - """转换 Agent 状态""" - if self._state != new_state: - logger.debug(f"Dify Agent state transition: {self._state} -> {new_state}") - self._state = new_state + self.api_client = DifyAPIClient(self.api_key, self.api_base) @override async def step(self): @@ -111,6 +95,16 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): chain=MessageChain().message(f"Dify 请求失败:{str(e)}") ), ) + finally: + await self.api_client.close() + + @override + async def step_until_done( + self, max_step: int = 30 + ) -> T.AsyncGenerator[AgentResponse, None]: + while not self.done(): + async for resp in self.step(): + yield resp async def _execute_dify_request(self): """执行 Dify 请求的核心逻辑""" @@ -119,20 +113,22 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): image_urls = self.req.image_urls or [] system_prompt = self.req.system_prompt - conversation_id = self.conversation_ids.get(session_id, "") + conversation_id = await sp.get_async( + scope="umo", + scope_id=session_id, + key="dify_conversation_id", + default="", + ) result = "" 
# 处理图片上传 files_payload = [] for image_url in image_urls: + # image_url is a base64 string try: - image_path = ( - await download_image_by_url(image_url) - if image_url.startswith("http") - else image_url - ) + image_data = base64.b64decode(image_url) file_response = await self.api_client.file_upload( - image_path, user=session_id + file_data=image_data, user=session_id ) logger.debug(f"Dify 上传图片响应:{file_response}") if "id" not in file_response: @@ -154,7 +150,12 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): # 获得会话变量 payload_vars = self.variables.copy() # 动态变量 - session_var = await sp.session_get(session_id, "session_variables", default={}) + session_var = await sp.get_async( + scope="umo", + scope_id=session_id, + key="session_variables", + default={}, + ) payload_vars.update(session_var) payload_vars["system_prompt"] = system_prompt @@ -178,7 +179,12 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): if chunk["event"] == "message" or chunk["event"] == "agent_message": result += chunk["answer"] if not conversation_id: - self.conversation_ids[session_id] = chunk["conversation_id"] + await sp.put_async( + scope="umo", + scope_id=session_id, + key="dify_conversation_id", + value=chunk["conversation_id"], + ) conversation_id = chunk["conversation_id"] # 如果是流式响应,发送增量数据 @@ -314,13 +320,3 @@ class DifyAgentRunner(BaseAgentRunner[TContext]): @override def get_final_llm_resp(self) -> LLMResponse | None: return self.final_llm_resp - - async def forget(self, session_id): - """忘记会话上下文""" - self.conversation_ids[session_id] = "" - return True - - async def terminate(self): - """终止并清理资源""" - if hasattr(self, "api_client"): - await self.api_client.close() diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index d74a45982..6f3c813eb 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -69,12 +69,6 @@ class 
ToolLoopAgentRunner(BaseAgentRunner[TContext]): ) self.run_context.messages = messages - def _transition_state(self, new_state: AgentState) -> None: - """转换 Agent 状态""" - if self._state != new_state: - logger.debug(f"Agent state transition: {self._state} -> {new_state}") - self._state = new_state - async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: """Yields chunks *and* a final LLMResponse.""" if self.streaming: diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index cf3658670..4e1101010 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -69,8 +69,9 @@ DEFAULT_CONFIG = { "streaming_response": False, "show_tool_use_status": False, "agent_runner_type": "local", - "dify_runner_provider_id": "", - "coze_runner_provider_id": "", + "dify_agent_runner_provider_id": "", + "coze_agent_runner_provider_id": "", + "dashscope_agent_runner_provider_id": "", "unsupported_streaming_strategy": "realtime_segmenting", "max_agent_step": 30, "tool_call_timeout": 60, @@ -1041,7 +1042,7 @@ CONFIG_METADATA_2 = { "id": "dashscope", "provider": "dashscope", "type": "dashscope", - "provider_type": "chat_completion", + "provider_type": "agent_runner", "enable": True, "dashscope_app_type": "agent", "dashscope_api_key": "", @@ -2042,10 +2043,13 @@ CONFIG_METADATA_2 = { "agent_runner_type": { "type": "string", }, - "dify_runner_provider_id": { + "dify_agent_runner_provider_id": { "type": "string", }, - "coze_runner_provider_id": { + "coze_agent_runner_provider_id": { + "type": "string", + }, + "dashscope_agent_runner_provider_id": { "type": "string", }, "max_agent_step": { @@ -2201,42 +2205,36 @@ CONFIG_METADATA_3 = { "provider_settings.agent_runner_type": { "description": "执行器", "type": "string", - "options": ["local", "dify", "coze"], - "labels": ["内置 Agent", "Dify", "Coze"], + "options": ["local", "dify", "coze", "dashscope"], + "labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"], "condition": { 
"provider_settings.enable": True, }, }, - }, - }, - "dify_runner": { - "description": "Dify", - "type": "object", - "items": { - "provider_settings.dify_runner_provider_id": { - "description": "Dify 执行器提供商 ID", + "provider_settings.coze_agent_runner_provider_id": { + "description": "Coze Agent 执行器提供商 ID", "type": "string", - "_special": "select_provider_dify_runner", + "_special": "select_agent_runner_provider", + "condition": { + "provider_settings.agent_runner_type": "coze", + }, }, - }, - "condition": { - "provider_settings.agent_runner_type": "dify", - "provider_settings.enable": True, - }, - }, - "coze_runner": { - "description": "Coze", - "type": "object", - "items": { - "provider_settings.coze_runner_provider_id": { - "description": "Coze 执行器提供商 ID", + "provider_settings.dify_agent_runner_provider_id": { + "description": "Dify Agent 执行器提供商 ID", "type": "string", - "_special": "select_provider_coze_runner", + "_special": "select_agent_runner_provider", + "condition": { + "provider_settings.agent_runner_type": "dify", + }, + }, + "provider_settings.dashscope_agent_runner_provider_id": { + "description": "阿里云百炼应用 Agent 执行器提供商 ID", + "type": "string", + "_special": "select_agent_runner_provider", + "condition": { + "provider_settings.agent_runner_type": "dashscope", + }, }, - }, - "condition": { - "provider_settings.agent_runner_type": "coze", - "provider_settings.enable": True, }, }, "ai": { @@ -2248,9 +2246,6 @@ CONFIG_METADATA_3 = { "type": "string", "_special": "select_provider", "hint": "留空时使用第一个模型", - "condition": { - "provider_settings.agent_runner_type": "local", - }, }, "provider_settings.default_image_caption_provider_id": { "description": "默认图片转述模型", diff --git a/astrbot/core/pipeline/process_stage/method/agent_request.py b/astrbot/core/pipeline/process_stage/method/agent_request.py new file mode 100644 index 000000000..f6f81631e --- /dev/null +++ b/astrbot/core/pipeline/process_stage/method/agent_request.py @@ -0,0 +1,48 @@ +from collections.abc 
import AsyncGenerator + +from astrbot.core import logger +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.star.session_llm_manager import SessionServiceManager + +from ...context import PipelineContext +from ..stage import Stage +from .agent_sub_stages.internal import InternalAgentSubStage +from .agent_sub_stages.third_party import ThirdPartyAgentSubStage + + +class AgentRequestSubStage(Stage): + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + self.config = ctx.astrbot_config + + self.bot_wake_prefixs: list[str] = self.config["wake_prefix"] + self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"] + for bwp in self.bot_wake_prefixs: + if self.prov_wake_prefix.startswith(bwp): + logger.info( + f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。", + ) + self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :] + + agent_runner_type = self.config["provider_settings"]["agent_runner_type"] + if agent_runner_type == "local": + self.agent_sub_stage = InternalAgentSubStage() + else: + self.agent_sub_stage = ThirdPartyAgentSubStage() + await self.agent_sub_stage.initialize(ctx) + + async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]: + if not self.ctx.astrbot_config["provider_settings"]["enable"]: + logger.debug( + "This pipeline does not enable AI capability, skip processing." + ) + return + + if not SessionServiceManager.should_process_llm_request(event): + logger.debug( + f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing." 
+ ) + return + + async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix): + yield resp diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py similarity index 91% rename from astrbot/core/pipeline/process_stage/method/llm_request.py rename to astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index bd9e4ce3b..aaada9a19 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -21,27 +21,24 @@ from astrbot.core.provider.entities import ( LLMResponse, ProviderRequest, ) -from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.utils.metrics import Metric from astrbot.core.utils.session_lock import session_lock_manager -from ....astr_agent_context import AgentContextWrapper -from ....astr_agent_hooks import MAIN_AGENT_HOOKS -from ....astr_agent_run_util import AgentRunner, run_agent -from ....astr_agent_tool_exec import FunctionToolExecutor -from ...context import PipelineContext, call_event_hook -from ..stage import Stage -from ..utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base +from .....astr_agent_context import AgentContextWrapper +from .....astr_agent_hooks import MAIN_AGENT_HOOKS +from .....astr_agent_run_util import AgentRunner, run_agent +from .....astr_agent_tool_exec import FunctionToolExecutor +from ....context import PipelineContext, call_event_hook +from ...stage import Stage +from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base -class LLMRequestSubStage(Stage): +class InternalAgentSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.ctx = ctx conf = ctx.astrbot_config settings = conf["provider_settings"] - self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list - 
self.provider_wake_prefix: str = settings["wake_prefix"] # str self.max_context_length = settings["max_context_length"] # int self.dequeue_context_length: int = min( max(1, settings["dequeue_context_length"]), @@ -59,13 +56,6 @@ class LLMRequestSubStage(Stage): self.show_reasoning = settings.get("display_reasoning_text", False) self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False) - for bwp in self.bot_wake_prefixs: - if self.provider_wake_prefix.startswith(bwp): - logger.info( - f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。", - ) - self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :] - self.conv_manager = ctx.plugin_manager.context.conversation_manager def _select_provider(self, event: AstrMessageEvent): @@ -304,21 +294,10 @@ class LLMRequestSubStage(Stage): return fixed_messages async def process( - self, - event: AstrMessageEvent, - _nested: bool = False, - ) -> None | AsyncGenerator[None, None]: + self, event: AstrMessageEvent, provider_wake_prefix: str + ) -> AsyncGenerator[None, None]: req: ProviderRequest | None = None - if not self.ctx.astrbot_config["provider_settings"]["enable"]: - logger.debug("未启用 LLM 能力,跳过处理。") - return - - # 检查会话级别的LLM启停状态 - if not SessionServiceManager.should_process_llm_request(event): - logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。") - return - provider = self._select_provider(event) if provider is None: return @@ -348,12 +327,12 @@ class LLMRequestSubStage(Stage): req.image_urls = [] if sel_model := event.get_extra("selected_model"): req.model = sel_model - if self.provider_wake_prefix and not event.message_str.startswith( - self.provider_wake_prefix + if provider_wake_prefix and not event.message_str.startswith( + provider_wake_prefix ): return - req.prompt = event.message_str[len(self.provider_wake_prefix) :] + req.prompt = event.message_str[len(provider_wake_prefix) :] # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。 # req.func_tool = 
self.ctx.plugin_manager.context.get_llm_tool_manager() for comp in event.message_obj.message: diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py new file mode 100644 index 000000000..d4d0709e7 --- /dev/null +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -0,0 +1,126 @@ +import asyncio +from collections.abc import AsyncGenerator + +from astrbot.core import logger +from astrbot.core.agent.runners.coze_agent_runner import CozeAgentRunner +from astrbot.core.agent.runners.dashscope_agent_runner import DashscopeAgentRunner +from astrbot.core.agent.runners.dify_agent_runner import DifyAgentRunner +from astrbot.core.message.components import Image +from astrbot.core.message.message_event_result import ( + MessageEventResult, + ResultContentType, +) +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider.entities import ( + ProviderRequest, +) +from astrbot.core.star.star_handler import EventType +from astrbot.core.utils.metrics import Metric + +from .....astr_agent_context import AgentContextWrapper, AstrAgentContext +from .....astr_agent_hooks import MAIN_AGENT_HOOKS +from ....context import PipelineContext, call_event_hook +from ...stage import Stage + +AGENT_RUNNER_TYPE_KEY = { + "dify": "dify_agent_runner_provider_id", + "coze": "coze_agent_runner_provider_id", + "dashscope": "dashscope_agent_runner_provider_id", +} + + +class ThirdPartyAgentSubStage(Stage): + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + self.conf = ctx.astrbot_config + self.runner_type = self.conf["provider_settings"]["agent_runner_type"] + self.prov_id = self.conf["provider_settings"].get( + AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""), + "", + ) + self.prov_cfg: dict = next( + (p for p in self.conf["provider"] if p["id"] == self.prov_id), + {}, + ) + + async def process( + 
self, event: AstrMessageEvent, provider_wake_prefix: str + ) -> AsyncGenerator[None, None]: + req: ProviderRequest | None = None + + if provider_wake_prefix and not event.message_str.startswith( + provider_wake_prefix + ): + return + if not self.prov_id or not self.prov_cfg: + logger.error( + "Third Party Agent Runner provider ID is not configured properly." + ) + return + + # make provider request + req = ProviderRequest() + req.session_id = event.unified_msg_origin + req.prompt = event.message_str[len(provider_wake_prefix) :] + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_path = await comp.convert_to_base64() + req.image_urls.append(image_path) + + if not req.prompt and not req.image_urls: + return + + # call event hook + if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + return + + if self.runner_type == "dify": + runner = DifyAgentRunner[AstrAgentContext]() + elif self.runner_type == "coze": + runner = CozeAgentRunner[AstrAgentContext]() + elif self.runner_type == "dashscope": + runner = DashscopeAgentRunner[AstrAgentContext]() + else: + raise ValueError( + f"Unsupported third party agent runner type: {self.runner_type}", + ) + + astr_agent_ctx = AstrAgentContext( + context=self.ctx.plugin_manager.context, + event=event, + ) + + await runner.reset( + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=60, + ), + agent_hooks=MAIN_AGENT_HOOKS, + provider_config=self.prov_cfg, + ) + + async for _ in runner.step_until_done(): + pass + + final_resp = runner.get_final_llm_resp() + + if not final_resp or not final_resp.result_chain: + logger.warning("Agent Runner 未返回最终结果。") + return + + event.set_result( + MessageEventResult( + chain=final_resp.result_chain.chain or [], + result_content_type=ResultContentType.LLM_RESULT, + ), + ) + yield + + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=self.runner_type, + provider_type=self.runner_type, + ), + ) diff --git 
a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index ff8120b16..56d305de4 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -24,7 +24,7 @@ class StarRequestSubStage(Stage): async def process( self, event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: + ) -> AsyncGenerator[None, None]: activated_handlers: list[StarHandlerMetadata] = event.get_extra( "activated_handlers", ) diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 9f0b5f92a..2eeefcf11 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -7,7 +7,7 @@ from astrbot.core.star.star_handler import StarHandlerMetadata from ..context import PipelineContext from ..stage import Stage, register_stage -from .method.llm_request import LLMRequestSubStage +from .method.agent_request import AgentRequestSubStage from .method.star_request import StarRequestSubStage @@ -17,9 +17,12 @@ class ProcessStage(Stage): self.ctx = ctx self.config = ctx.astrbot_config self.plugin_manager = ctx.plugin_manager - self.llm_request_sub_stage = LLMRequestSubStage() - await self.llm_request_sub_stage.initialize(ctx) + # initialize agent sub stage + self.agent_sub_stage = AgentRequestSubStage() + await self.agent_sub_stage.initialize(ctx) + + # initialize star request sub stage self.star_request_sub_stage = StarRequestSubStage() await self.star_request_sub_stage.initialize(ctx) @@ -39,7 +42,7 @@ class ProcessStage(Stage): # Handler 的 LLM 请求 event.set_extra("provider_request", resp) _t = False - async for _ in self.llm_request_sub_stage.process(event): + async for _ in self.agent_sub_stage.process(event): _t = True yield if not _t: @@ -67,5 +70,5 @@ class ProcessStage(Stage): logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。") return - async for _ in 
self.llm_request_sub_stage.process(event): + async for _ in self.agent_sub_stage.process(event): yield diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index ec2550415..d665b38ab 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -227,6 +227,8 @@ class ProviderManager: async def load_provider(self, provider_config: dict): if not provider_config["enable"]: return + if provider_config["provider_type"] == "agent_runner": + return logger.info( f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...", @@ -247,14 +249,6 @@ class ProviderManager: from .sources.anthropic_source import ( ProviderAnthropic as ProviderAnthropic, ) - case "dify": - from .sources.dify_source import ProviderDify as ProviderDify - case "coze": - from .sources.coze_source import ProviderCoze as ProviderCoze - case "dashscope": - from .sources.dashscope_source import ( - ProviderDashscope as ProviderDashscope, - ) case "googlegenai_chat_completion": from .sources.gemini_source import ( ProviderGoogleGenAI as ProviderGoogleGenAI, diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py deleted file mode 100644 index 7850a982c..000000000 --- a/astrbot/core/provider/sources/dify_source.py +++ /dev/null @@ -1,285 +0,0 @@ -import os - -import astrbot.core.message.components as Comp -from astrbot.core import logger, sp -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.utils.dify_api_client import DifyAPIClient -from astrbot.core.utils.io import download_file, download_image_by_url - -from .. 
import Provider -from ..entities import LLMResponse -from ..register import register_provider_adapter - - -@register_provider_adapter("dify", "Dify APP 适配器。") -class ProviderDify(Provider): - def __init__( - self, - provider_config, - provider_settings, - ) -> None: - super().__init__( - provider_config, - provider_settings, - ) - self.api_key = provider_config.get("dify_api_key", "") - if not self.api_key: - raise Exception("Dify API Key 不能为空。") - api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") - self.api_type = provider_config.get("dify_api_type", "") - if not self.api_type: - raise Exception("Dify API 类型不能为空。") - self.model_name = "dify" - self.workflow_output_key = provider_config.get( - "dify_workflow_output_key", - "astrbot_wf_output", - ) - self.dify_query_input_key = provider_config.get( - "dify_query_input_key", - "astrbot_text_query", - ) - if not self.dify_query_input_key: - self.dify_query_input_key = "astrbot_text_query" - if not self.workflow_output_key: - self.workflow_output_key = "astrbot_wf_output" - self.variables: dict = provider_config.get("variables", {}) - self.timeout = provider_config.get("timeout", 120) - if isinstance(self.timeout, str): - self.timeout = int(self.timeout) - self.conversation_ids = {} - """记录当前 session id 的对话 ID""" - - self.api_client = DifyAPIClient(self.api_key, api_base) - - async def text_chat( - self, - prompt: str, - session_id=None, - image_urls=None, - func_tool=None, - contexts=None, - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ) -> LLMResponse: - if image_urls is None: - image_urls = [] - result = "" - session_id = session_id or kwargs.get("user") or "unknown" # 1734 - conversation_id = self.conversation_ids.get(session_id, "") - - files_payload = [] - for image_url in image_urls: - image_path = ( - await download_image_by_url(image_url) - if image_url.startswith("http") - else image_url - ) - file_response = await self.api_client.file_upload( - image_path, 
- user=session_id, - ) - logger.debug(f"Dify 上传图片响应:{file_response}") - if "id" not in file_response: - logger.warning( - f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。", - ) - continue - files_payload.append( - { - "type": "image", - "transfer_method": "local_file", - "upload_file_id": file_response["id"], - }, - ) - - # 获得会话变量 - payload_vars = self.variables.copy() - # 动态变量 - session_var = await sp.session_get(session_id, "session_variables", default={}) - payload_vars.update(session_var) - payload_vars["system_prompt"] = system_prompt - - try: - match self.api_type: - case "chat" | "agent" | "chatflow": - if not prompt: - prompt = "请描述这张图片。" - - async for chunk in self.api_client.chat_messages( - inputs={ - **payload_vars, - }, - query=prompt, - user=session_id, - conversation_id=conversation_id, - files=files_payload, - timeout=self.timeout, - ): - logger.debug(f"dify resp chunk: {chunk}") - if ( - chunk["event"] == "message" - or chunk["event"] == "agent_message" - ): - result += chunk["answer"] - if not conversation_id: - self.conversation_ids[session_id] = chunk[ - "conversation_id" - ] - conversation_id = chunk["conversation_id"] - elif chunk["event"] == "message_end": - logger.debug("Dify message end") - break - elif chunk["event"] == "error": - logger.error(f"Dify 出现错误:{chunk}") - raise Exception( - f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}", - ) - - case "workflow": - async for chunk in self.api_client.workflow_run( - inputs={ - self.dify_query_input_key: prompt, - "astrbot_session_id": session_id, - **payload_vars, - }, - user=session_id, - files=files_payload, - timeout=self.timeout, - ): - match chunk["event"]: - case "workflow_started": - logger.info( - f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。", - ) - case "node_finished": - logger.debug( - f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。", - ) - case "workflow_finished": - logger.info( - f"Dify 工作流(ID: 
{chunk['workflow_run_id']})运行结束", - ) - logger.debug(f"Dify 工作流结果:{chunk}") - if chunk["data"]["error"]: - logger.error( - f"Dify 工作流出现错误:{chunk['data']['error']}", - ) - raise Exception( - f"Dify 工作流出现错误:{chunk['data']['error']}", - ) - if ( - self.workflow_output_key - not in chunk["data"]["outputs"] - ): - raise Exception( - f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}", - ) - result = chunk - case _: - raise Exception(f"未知的 Dify API 类型:{self.api_type}") - except Exception as e: - logger.error(f"Dify 请求失败:{e!s}") - return LLMResponse(role="err", completion_text=f"Dify 请求失败:{e!s}") - - if not result: - logger.warning("Dify 请求结果为空,请查看 Debug 日志。") - - chain = await self.parse_dify_result(result) - - return LLMResponse(role="assistant", result_chain=chain) - - async def text_chat_stream( - self, - prompt, - session_id=None, - image_urls=..., - func_tool=None, - contexts=..., - system_prompt=None, - tool_calls_result=None, - model=None, - **kwargs, - ): - # raise NotImplementedError("This method is not implemented yet.") - # 调用 text_chat 模拟流式 - llm_response = await self.text_chat( - prompt=prompt, - session_id=session_id, - image_urls=image_urls, - func_tool=func_tool, - contexts=contexts, - system_prompt=system_prompt, - tool_calls_result=tool_calls_result, - ) - llm_response.is_chunk = True - yield llm_response - llm_response.is_chunk = False - yield llm_response - - async def parse_dify_result(self, chunk: dict | str) -> MessageChain: - if isinstance(chunk, str): - # Chat - return MessageChain(chain=[Comp.Plain(chunk)]) - - async def parse_file(item: dict): - match item["type"]: - case "image": - return Comp.Image(file=item["url"], url=item["url"]) - case "audio": - # 仅支持 wav - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - path = os.path.join(temp_dir, f"{item['filename']}.wav") - await download_file(item["url"], path) - return Comp.Image(file=item["url"], url=item["url"]) - case "video": - return Comp.Video(file=item["url"]) - case _: - return 
Comp.File(name=item["filename"], file=item["url"]) - - output = chunk["data"]["outputs"][self.workflow_output_key] - chains = [] - if isinstance(output, str): - # 纯文本输出 - chains.append(Comp.Plain(output)) - elif isinstance(output, list): - # 主要适配 Dify 的 HTTP 请求结点的多模态输出 - for item in output: - # handle Array[File] - if ( - not isinstance(item, dict) - or item.get("dify_model_identity", "") != "__dify__file__" - ): - chains.append(Comp.Plain(str(output))) - break - else: - chains.append(Comp.Plain(str(output))) - - # scan file - files = chunk["data"].get("files", []) - for item in files: - comp = await parse_file(item) - chains.append(comp) - - return MessageChain(chain=chains) - - async def forget(self, session_id): - self.conversation_ids[session_id] = "" - return True - - async def get_current_key(self): - return self.api_key - - async def set_key(self, key): - raise Exception("Dify 适配器不支持设置 API Key。") - - async def get_models(self): - return [self.get_model()] - - async def get_human_readable_context(self, session_id, page, page_size): - raise Exception("暂不支持获得 Dify 的历史消息记录。") - - async def terminate(self): - await self.api_client.close() diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index ea8ff9dff..20efdbbd0 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -101,14 +101,16 @@ class DifyAPIClient: async def file_upload( self, - file_path: str, user: str, + file_path: str | None = None, + file_data: bytes | None = None, ) -> dict[str, Any]: url = f"{self.api_base}/files/upload" - with open(file_path, "rb") as f: + + if file_data is not None: payload = { "user": user, - "file": f, + "file": file_data, } async with self.session.post( url, @@ -116,6 +118,20 @@ class DifyAPIClient: headers=self.headers, ) as resp: return await resp.json() # {"id": "xxx", ...} + elif file_path is not None: + with open(file_path, "rb") as f: + payload = { + "user": user, + "file": f, + } + 
async with self.session.post( + url, + data=payload, + headers=self.headers, + ) as resp: + return await resp.json() # {"id": "xxx", ...} + else: + raise ValueError("file_path 和 file_data 不能同时为 None") async def close(self): await self.session.close() diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index c6b4c5ede..6b1f52a69 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -40,9 +40,6 @@ class SharedPreferences: else: ret = default return ret - raise ValueError( - "scope_id and key cannot be None when getting a specific preference.", - ) async def range_get_async( self, @@ -56,30 +53,6 @@ class SharedPreferences: ret = await self.db_helper.get_preferences(scope, scope_id, key) return ret - @overload - async def session_get( - self, - umo: None, - key: str, - default: Any = None, - ) -> list[Preference]: ... - - @overload - async def session_get( - self, - umo: str, - key: None, - default: Any = None, - ) -> list[Preference]: ... - - @overload - async def session_get( - self, - umo: None, - key: None, - default: Any = None, - ) -> list[Preference]: ... - async def session_get( self, umo: str | None, @@ -88,7 +61,7 @@ class SharedPreferences: ) -> _VT | list[Preference]: """获取会话范围的偏好设置 - Note: 当 scope_id 或者 key 为 None,时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。 + Note: 当 umo 或者 key 为 None 时,返回 Preference 列表,其中的 value 属性是一个 dict,value["val"] 为值。 """ if umo is None or key is None: return await self.range_get_async("umo", umo, key) diff --git a/dashboard/src/components/shared/AstrBotConfigV4.vue b/dashboard/src/components/shared/AstrBotConfigV4.vue index 77fe39ebc..57c7b36ee 100644 --- a/dashboard/src/components/shared/AstrBotConfigV4.vue +++ b/dashboard/src/components/shared/AstrBotConfigV4.vue @@ -230,11 +230,8 @@ function hasVisibleItemsAfter(items, currentIndex) {