@@ -231,10 +231,10 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"provider_group": {
|
||||
"name": "大语言模型",
|
||||
"name": "服务提供商",
|
||||
"metadata": {
|
||||
"provider": {
|
||||
"description": "大语言模型配置",
|
||||
"description": "服务提供商配置",
|
||||
"type": "list",
|
||||
"config_template": {
|
||||
"openai": {
|
||||
@@ -296,6 +296,15 @@ CONFIG_METADATA_2 = {
|
||||
"llmtuner_template": "",
|
||||
"finetuning_type": "lora",
|
||||
"quantization_bit": 4,
|
||||
},
|
||||
"dify": {
|
||||
"id": "dify_app_default",
|
||||
"type": "dify",
|
||||
"enable": True,
|
||||
"dify_api_type": "chat",
|
||||
"dify_api_key": "",
|
||||
"dify_api_base": "https://api.dify.ai/v1",
|
||||
"dify_workflow_output_key": "",
|
||||
}
|
||||
},
|
||||
"items": {
|
||||
@@ -367,6 +376,27 @@ CONFIG_METADATA_2 = {
|
||||
"top_p": {"description": "Top P值", "type": "float"},
|
||||
},
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
"hint": "Dify API Key。此项必填。",
|
||||
},
|
||||
"dify_api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "Dify API Base URL。默认为 https://api.dify.ai/v1",
|
||||
},
|
||||
"dify_api_type": {
|
||||
"description": "Dify 应用类型",
|
||||
"type": "string",
|
||||
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
|
||||
"options": ["chat", "agent", "workflow"],
|
||||
},
|
||||
"dify_workflow_output_key": {
|
||||
"description": "Dify Workflow 输出变量名",
|
||||
"type": "string",
|
||||
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
|
||||
}
|
||||
},
|
||||
},
|
||||
"provider_settings": {
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
'''
|
||||
Dify 调用 Stage
|
||||
'''
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
|
||||
class DifyRequestSubStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
if provider.meta().type != "dify":
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
|
||||
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
|
||||
return
|
||||
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
req.image_urls.append(image_url)
|
||||
req.session_id = event.session_id
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
if not req.prompt:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.debug(f"Dify 请求 Payload: {req.__dict__}")
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
# text completion
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.LLM_RESULT))
|
||||
yield # rick roll
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(MessageEventResult().message("AstrBot 请求 Dify 失败:" + str(e)))
|
||||
return
|
||||
@@ -1,3 +1,6 @@
|
||||
'''
|
||||
本地 Agent 模式的 LLM 调用 Stage
|
||||
'''
|
||||
import traceback
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
@@ -41,6 +44,9 @@ class LLMRequestSubStage(Stage):
|
||||
session_provider_context = provider.session_memory.get(event.session_id)
|
||||
req.contexts = session_provider_context if session_provider_context else []
|
||||
|
||||
if not req.prompt:
|
||||
return
|
||||
|
||||
# 执行请求 LLM 前事件。
|
||||
# 装饰 system_prompt 等功能
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
|
||||
@@ -51,7 +57,7 @@ class LLMRequestSubStage(Stage):
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
logger.debug(f"请求 LLM:{req.__dict__}")
|
||||
logger.debug(f"提供商请求 Payload: {req.__dict__}")
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
'''
|
||||
本地 Agent 模式的 AstrBot 插件调用 Stage
|
||||
'''
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from typing import Dict, Any, List, AsyncGenerator, Union
|
||||
|
||||
@@ -3,6 +3,7 @@ from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from .method.llm_request import LLMRequestSubStage
|
||||
from .method.star_request import StarRequestSubStage
|
||||
from .method.dify_request import DifyRequestSubStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
@@ -20,11 +21,15 @@ class ProcessStage(Stage):
|
||||
|
||||
self.star_request_sub_stage = StarRequestSubStage()
|
||||
await self.star_request_sub_stage.initialize(ctx)
|
||||
|
||||
self.dify_request_sub_stage = DifyRequestSubStage()
|
||||
await self.dify_request_sub_stage.initialize(ctx)
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
'''处理事件
|
||||
'''
|
||||
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
|
||||
# 有插件 Handler 被激活
|
||||
if activated_handlers:
|
||||
async for resp in self.star_request_sub_stage.process(event):
|
||||
# 生成器返回值处理
|
||||
@@ -36,10 +41,15 @@ class ProcessStage(Stage):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
|
||||
if not event._has_send_oper:
|
||||
'''当没有发送操作'''
|
||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
|
||||
# 调用提供商相关请求
|
||||
if self.ctx.astrbot_config['provider_settings'].get('enable', True) and not event._has_send_oper:
|
||||
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
match provider.meta().type:
|
||||
case "dify":
|
||||
async for _ in self.dify_request_sub_stage.process(event):
|
||||
yield
|
||||
case _:
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
@@ -22,4 +22,6 @@ class RespondStage(Stage):
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
|
||||
for handler in handlers:
|
||||
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
|
||||
await handler.handler(event)
|
||||
await handler.handler(event)
|
||||
|
||||
event.clear_result()
|
||||
@@ -104,6 +104,8 @@ class FuncCall:
|
||||
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
_l.append(
|
||||
{
|
||||
"name": f["name"],
|
||||
@@ -169,3 +171,10 @@ class FuncCall:
|
||||
if ret:
|
||||
tool_call_result.append(str(ret))
|
||||
return tool_call_result, True
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return str(self.func_list)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.func_list)
|
||||
@@ -39,6 +39,8 @@ class ProviderManager():
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify # noqa: F401
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
import base64
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from ..entites import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
class ProviderDify(Provider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
db_helper: BaseDatabase,
|
||||
persistant_history=False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
provider_config, provider_settings, persistant_history, db_helper
|
||||
)
|
||||
self.api_key = provider_config.get("dify_api_key", "")
|
||||
if not self.api_key:
|
||||
raise Exception("Dify API Key 不能为空。")
|
||||
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
|
||||
self.api_client = DifyAPIClient(self.api_key, api_base)
|
||||
self.api_type = provider_config.get("dify_api_type", "")
|
||||
if not self.api_type:
|
||||
raise Exception("Dify API 类型不能为空。")
|
||||
self.model_name = "dify"
|
||||
self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output")
|
||||
|
||||
self.conversation_ids = {}
|
||||
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
result = ""
|
||||
conversation_id = self.conversation_ids.get(session_id, "")
|
||||
|
||||
files_payload = []
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
file_response = await self.api_client.file_upload(image_path, user=session_id)
|
||||
if 'id' not in file_response:
|
||||
logger.warning(f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。")
|
||||
continue
|
||||
files_payload.append({
|
||||
"type": "image",
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": file_response['id'],
|
||||
})
|
||||
else:
|
||||
# TODO: 处理更多情况
|
||||
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
|
||||
|
||||
logger.debug(files_payload)
|
||||
|
||||
match self.api_type:
|
||||
case "chat" | "agent":
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
files=files_payload
|
||||
):
|
||||
logger.debug(f"dify resp chunk: {chunk}")
|
||||
if chunk['event'] == "message" or \
|
||||
chunk['event'] == "agent_message":
|
||||
result += chunk['answer']
|
||||
if not conversation_id:
|
||||
self.conversation_ids[session_id] = chunk['conversation_id']
|
||||
conversation_id = chunk['conversation_id']
|
||||
|
||||
case "workflow":
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
"astrbot_text_query": prompt,
|
||||
"astrbot_session_id": session_id
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload
|
||||
):
|
||||
match chunk['event']:
|
||||
case "workflow_started":
|
||||
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
|
||||
case "node_finished":
|
||||
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
|
||||
case "workflow_finished":
|
||||
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
|
||||
if chunk['data']['error']:
|
||||
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
|
||||
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
|
||||
if self.workflow_output_key not in chunk['data']['outputs']:
|
||||
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
|
||||
result = chunk['data']['outputs'][self.workflow_output_key]
|
||||
case _:
|
||||
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
|
||||
|
||||
return LLMResponse(role="assistant", completion_text=result)
|
||||
|
||||
async def forget(self, session_id):
|
||||
self.conversation_ids.pop(session_id, None)
|
||||
return True
|
||||
|
||||
async def get_current_key(self):
|
||||
return self.api_key
|
||||
|
||||
async def set_key(self, key):
|
||||
raise Exception("Dify 适配器不支持设置 API Key。")
|
||||
|
||||
async def get_models(self):
|
||||
return [self.get_model()]
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
raise Exception("暂不支持获得 Dify 的历史消息记录。")
|
||||
@@ -0,0 +1,78 @@
|
||||
import json
|
||||
from aiohttp import ClientSession
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
|
||||
|
||||
class DifyAPIClient:
|
||||
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.session = ClientSession()
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
async def chat_messages(
|
||||
self,
|
||||
inputs: Dict,
|
||||
query: str,
|
||||
user: str,
|
||||
response_mode: str = "streaming",
|
||||
conversation_id: str = "",
|
||||
files: List[Dict[str, Any]] = [],
|
||||
timeout: float = 60,
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
url = f"{self.api_base}/chat-messages"
|
||||
payload = locals()
|
||||
payload.pop("self")
|
||||
payload.pop("timeout")
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
async for data in resp.content:
|
||||
if not data.strip():
|
||||
continue
|
||||
if data.startswith(b"data:"):
|
||||
yield json.loads(data[5:])
|
||||
|
||||
async def workflow_run(
|
||||
self,
|
||||
inputs: Dict,
|
||||
user: str,
|
||||
response_mode: str = "streaming",
|
||||
files: List[Dict[str, Any]] = [],
|
||||
timeout: float = 60,
|
||||
):
|
||||
url = f"{self.api_base}/workflows/run"
|
||||
payload = locals()
|
||||
payload.pop("self")
|
||||
payload.pop("timeout")
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
async for data in resp.content:
|
||||
if not data.strip():
|
||||
continue
|
||||
if data.startswith(b"data:"):
|
||||
yield json.loads(data[5:])
|
||||
|
||||
async def file_upload(
|
||||
self,
|
||||
file_path: str,
|
||||
user: str,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{self.api_base}/files/upload"
|
||||
payload = {
|
||||
"user": user,
|
||||
"file": open(file_path, "rb"),
|
||||
}
|
||||
async with self.session.post(
|
||||
url, data=payload, headers=self.headers
|
||||
) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
|
||||
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
Reference in New Issue
Block a user