diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index fd6080421..9c78c0220 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -231,10 +231,10 @@ CONFIG_METADATA_2 = { }, }, "provider_group": { - "name": "大语言模型", + "name": "服务提供商", "metadata": { "provider": { - "description": "大语言模型配置", + "description": "服务提供商配置", "type": "list", "config_template": { "openai": { @@ -296,6 +296,15 @@ CONFIG_METADATA_2 = { "llmtuner_template": "", "finetuning_type": "lora", "quantization_bit": 4, + }, + "dify": { + "id": "dify_app_default", + "type": "dify", + "enable": True, + "dify_api_type": "chat", + "dify_api_key": "", + "dify_api_base": "https://api.dify.ai/v1", + "dify_workflow_output_key": "", } }, "items": { @@ -367,6 +376,27 @@ CONFIG_METADATA_2 = { "top_p": {"description": "Top P值", "type": "float"}, }, }, + "dify_api_key": { + "description": "API Key", + "type": "string", + "hint": "Dify API Key。此项必填。", + }, + "dify_api_base": { + "description": "API Base URL", + "type": "string", + "hint": "Dify API Base URL。默认为 https://api.dify.ai/v1", + }, + "dify_api_type": { + "description": "Dify 应用类型", + "type": "string", + "hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型", + "options": ["chat", "agent", "workflow"], + }, + "dify_workflow_output_key": { + "description": "Dify Workflow 输出变量名", + "type": "string", + "hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。", + } }, }, "provider_settings": { diff --git a/astrbot/core/pipeline/process_stage/method/dify_request.py b/astrbot/core/pipeline/process_stage/method/dify_request.py new file mode 100644 index 000000000..9a52fbb03 --- /dev/null +++ b/astrbot/core/pipeline/process_stage/method/dify_request.py @@ -0,0 +1,60 @@ +''' +Dify 调用 Stage +''' +import traceback +from typing import Union, AsyncGenerator +from ...context import PipelineContext +from ..stage import Stage +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType +from astrbot.core.message.components import Image +from astrbot.core import logger +from astrbot.core.utils.metrics import Metric +from astrbot.core.provider.entites import ProviderRequest + +class DifyRequestSubStage(Stage): + + async def initialize(self, ctx: PipelineContext) -> None: + self.ctx = ctx + + async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: + req: ProviderRequest = None + + provider = self.ctx.plugin_manager.context.get_using_provider() + if provider.meta().type != "dify": + return + + if event.get_extra("provider_request"): + req = event.get_extra("provider_request") + assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。" + else: + req = ProviderRequest(prompt="", image_urls=[]) + if self.ctx.astrbot_config['provider_settings']['wake_prefix']: + if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']): + return + req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):] + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_url = comp.url if comp.url else comp.file + req.image_urls.append(image_url) + req.session_id = event.session_id + event.set_extra("provider_request", req) + + if not req.prompt: + return + + try: + logger.debug(f"Dify 请求 Payload: {req.__dict__}") + llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type) + + if llm_response.role == 'assistant': + # text completion + event.set_result(MessageEventResult().message(llm_response.completion_text) + .set_result_content_type(ResultContentType.LLM_RESULT)) + yield # rick roll + + except BaseException as e: + logger.error(traceback.format_exc()) + event.set_result(MessageEventResult().message("AstrBot 请求 Dify 失败:" + str(e))) + return \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 6afd6e89b..a498fd37f 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -1,3 +1,6 @@ +''' +本地 Agent 模式的 LLM 调用 Stage +''' import traceback from typing import Union, AsyncGenerator from ...context import PipelineContext @@ -41,6 +44,9 @@ class LLMRequestSubStage(Stage): session_provider_context = provider.session_memory.get(event.session_id) req.contexts = session_provider_context if session_provider_context else [] + if not req.prompt: + return + # 执行请求 LLM 前事件。 # 装饰 system_prompt 等功能 handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent) @@ -51,7 +57,7 @@ class LLMRequestSubStage(Stage): logger.error(traceback.format_exc()) try: - logger.debug(f"请求 LLM:{req.__dict__}") + logger.debug(f"提供商请求 Payload: {req.__dict__}") llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type) diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 763a663ee..6863df473 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -1,3 +1,6 @@ +''' +本地 Agent 模式的 AstrBot 插件调用 Stage +''' from ...context import PipelineContext from ..stage import Stage from typing import Dict, Any, List, AsyncGenerator, Union diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 837e349a7..a1ab3ddef 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -3,6 +3,7 @@ from ..stage import Stage, register_stage from ..context import PipelineContext from .method.llm_request import LLMRequestSubStage from .method.star_request import StarRequestSubStage +from .method.dify_request import DifyRequestSubStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.provider.entites import ProviderRequest @@ -20,11 +21,15 @@ class ProcessStage(Stage): self.star_request_sub_stage = StarRequestSubStage() await self.star_request_sub_stage.initialize(ctx) + + self.dify_request_sub_stage = DifyRequestSubStage() + await self.dify_request_sub_stage.initialize(ctx) async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: '''处理事件 ''' activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers") + # 有插件 Handler 被激活 if activated_handlers: async for resp in self.star_request_sub_stage.process(event): # 生成器返回值处理 @@ -36,10 +41,15 @@ class ProcessStage(Stage): yield else: yield - - if self.ctx.astrbot_config['provider_settings'].get('enable', True): - if not event._has_send_oper: - '''当没有发送操作''' - if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result(): - async for _ in self.llm_request_sub_stage.process(event): - yield \ No newline at end of file + + # 调用提供商相关请求 + if self.ctx.astrbot_config['provider_settings'].get('enable', True) and not event._has_send_oper: + if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result(): + provider = self.ctx.plugin_manager.context.get_using_provider() + match provider.meta().type: + case "dify": + async for _ in self.dify_request_sub_stage.process(event): + yield + case _: + async for _ in self.llm_request_sub_stage.process(event): + yield \ No newline at end of file diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 6863bb909..9b98c6030 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -22,4 +22,6 @@ class RespondStage(Stage): handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent) for handler in handlers: # TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。 - await handler.handler(event) \ No newline at end of file + await handler.handler(event) + + event.clear_result() \ No newline at end of file diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 4aee3eec8..07c672daf 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -104,6 +104,8 @@ class FuncCall: async def func_call(self, question: str, session_id: str, provider) -> tuple: _l = [] for f in self.func_list: + if not f.active: + continue _l.append( { "name": f["name"], @@ -169,3 +171,10 @@ class FuncCall: if ret: tool_call_result.append(str(ret)) return tool_call_result, True + + + def __str__(self): + return str(self.func_list) + + def __repr__(self): + return str(self.func_list) \ No newline at end of file diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 2ae8a2f12..20ac8a0d0 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -39,6 +39,8 @@ class ProviderManager(): case "llm_tuner": logger.info("加载 LLM Tuner 工具 ...") from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401 + case "dify": + from .sources.dify_source import ProviderDify # noqa: F401 async def initialize(self): diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py new file mode 100644 index 000000000..64755c968 --- /dev/null +++ b/astrbot/core/provider/sources/dify_source.py @@ -0,0 +1,129 @@ +import base64 +from typing import List +from .. import Provider +from ..entites import LLMResponse +from ..func_tool_manager import FuncCall +from astrbot.core.db import BaseDatabase +from ..register import register_provider_adapter +from astrbot.core.utils.dify_api_client import DifyAPIClient +from astrbot.core.utils.io import download_image_by_url +from astrbot.core import logger + + +@register_provider_adapter("dify", "Dify APP 适配器。") +class ProviderDify(Provider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + db_helper: BaseDatabase, + persistant_history=False, + ) -> None: + super().__init__( + provider_config, provider_settings, persistant_history, db_helper + ) + self.api_key = provider_config.get("dify_api_key", "") + if not self.api_key: + raise Exception("Dify API Key 不能为空。") + api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1") + self.api_client = DifyAPIClient(self.api_key, api_base) + self.api_type = provider_config.get("dify_api_type", "") + if not self.api_type: + raise Exception("Dify API 类型不能为空。") + self.model_name = "dify" + self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output") + + self.conversation_ids = {} + + + async def text_chat( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = None, + func_tool: FuncCall = None, + contexts: List = None, + system_prompt: str = None, + **kwargs, + ) -> LLMResponse: + result = "" + conversation_id = self.conversation_ids.get(session_id, "") + + files_payload = [] + for image_url in image_urls: + if image_url.startswith("http"): + image_path = await download_image_by_url(image_url) + file_response = await self.api_client.file_upload(image_path, user=session_id) + if 'id' not in file_response: + logger.warning(f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。") + continue + files_payload.append({ + "type": "image", + "transfer_method": "local_file", + "upload_file_id": file_response['id'], + }) + else: + # TODO: 处理更多情况 + logger.warning(f"未知的图片链接:{image_url},图片将忽略。") + + logger.debug(files_payload) + + match self.api_type: + case "chat" | "agent": + async for chunk in self.api_client.chat_messages( + inputs={}, + query=prompt, + user=session_id, + conversation_id=conversation_id, + files=files_payload + ): + logger.debug(f"dify resp chunk: {chunk}") + if chunk['event'] == "message" or \ + chunk['event'] == "agent_message": + result += chunk['answer'] + if not conversation_id: + self.conversation_ids[session_id] = chunk['conversation_id'] + conversation_id = chunk['conversation_id'] + + case "workflow": + async for chunk in self.api_client.workflow_run( + inputs={ + "astrbot_text_query": prompt, + "astrbot_session_id": session_id + }, + user=session_id, + files=files_payload + ): + match chunk['event']: + case "workflow_started": + logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。") + case "node_finished": + logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。") + case "workflow_finished": + logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。") + if chunk['data']['error']: + logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}") + raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}") + if self.workflow_output_key not in chunk['data']['outputs']: + raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}") + result = chunk['data']['outputs'][self.workflow_output_key] + case _: + raise Exception(f"未知的 Dify API 类型:{self.api_type}") + + return LLMResponse(role="assistant", completion_text=result) + + async def forget(self, session_id): + self.conversation_ids.pop(session_id, None) + return True + + async def get_current_key(self): + return self.api_key + + async def set_key(self, key): + raise Exception("Dify 适配器不支持设置 API Key。") + + async def get_models(self): + return [self.get_model()] + + async def get_human_readable_context(self, session_id, page, page_size): + raise Exception("暂不支持获得 Dify 的历史消息记录。") diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py new file mode 100644 index 000000000..de9dc5d66 --- /dev/null +++ b/astrbot/core/utils/dify_api_client.py @@ -0,0 +1,78 @@ +import json +from aiohttp import ClientSession +from typing import Dict, List, Any, AsyncGenerator + + +class DifyAPIClient: + def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"): + self.api_key = api_key + self.api_base = api_base + self.session = ClientSession() + self.headers = { + "Authorization": f"Bearer {self.api_key}", + } + + async def chat_messages( + self, + inputs: Dict, + query: str, + user: str, + response_mode: str = "streaming", + conversation_id: str = "", + files: List[Dict[str, Any]] = [], + timeout: float = 60, + ) -> AsyncGenerator[Dict[str, Any], None]: + url = f"{self.api_base}/chat-messages" + payload = locals() + payload.pop("self") + payload.pop("timeout") + async with self.session.post( + url, json=payload, headers=self.headers, timeout=timeout + ) as resp: + async for data in resp.content: + if not data.strip(): + continue + if data.startswith(b"data:"): + yield json.loads(data[5:]) + + async def workflow_run( + self, + inputs: Dict, + user: str, + response_mode: str = "streaming", + files: List[Dict[str, Any]] = [], + timeout: float = 60, + ): + url = f"{self.api_base}/workflows/run" + payload = locals() + payload.pop("self") + payload.pop("timeout") + async with self.session.post( + url, json=payload, headers=self.headers, timeout=timeout + ) as resp: + async for data in resp.content: + if not data.strip(): + continue + if data.startswith(b"data:"): + yield json.loads(data[5:]) + + async def file_upload( + self, + file_path: str, + user: str, + ) -> Dict[str, Any]: + url = f"{self.api_base}/files/upload" + payload = { + "user": user, + "file": open(file_path, "rb"), + } + async with self.session.post( + url, data=payload, headers=self.headers + ) as resp: + return await resp.json() # {"id": "xxx", ...} + + + + + async def close(self): + await self.session.close() \ No newline at end of file