# File: AstrBot/astrbot/core/utils/dify_api_client.py
# Timestamp: 2025-02-07 22:57:49 +08:00
# Size: 122 lines, 4.5 KiB (Python)

import json
from typing import Any, AsyncGenerator, Dict, List, Optional

from aiohttp import ClientSession

from astrbot.core import logger
class DifyAPIClient:
    """Small asynchronous client for the Dify HTTP API.

    Wraps three endpoints relative to ``api_base``: ``/chat-messages``,
    ``/workflows/run`` and ``/files/upload``. One ``aiohttp.ClientSession``
    is created per client instance and must be released with ``close()``.
    """

    # Number of bytes requested per read from the response stream. The
    # original author kept reads at 8192 bytes to avoid high-water-mark errors
    # when a lot of data arrives at once.
    _READ_CHUNK_SIZE = 8192

    def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
        """Store the credentials and open the shared HTTP session.

        Args:
            api_key: Key sent as a ``Bearer`` token on every request.
            api_base: Base URL that endpoint paths are appended to.
        """
        self.api_key = api_key
        self.api_base = api_base
        self.session = ClientSession()
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
        }

    @staticmethod
    def _take_sse_events(buffer: bytearray, final: bool = False) -> List[Any]:
        """Remove complete event blocks from ``buffer`` and decode them.

        Blocks are separated by a blank line (``\\n\\n``). Only blocks whose
        text starts with ``data:`` are decoded; anything else is skipped.
        ``buffer`` is modified in place so that it keeps only the trailing,
        possibly incomplete, block.

        Args:
            buffer: Raw bytes received so far.
            final: When true, the stream has ended and whatever is left in
                ``buffer`` is treated as the last block.

        Returns:
            The decoded JSON payloads, in stream order. Blocks whose payload
            is not valid JSON are logged and omitted.
        """
        if final:
            complete = bytes(buffer)
            buffer.clear()
        else:
            end = buffer.rfind(b"\n\n")
            if end < 0:
                return []
            complete = bytes(buffer[:end])
            del buffer[: end + 2]

        events: List[Any] = []
        for block in complete.split(b"\n\n"):
            # Decoding happens per complete block, never per network chunk:
            # the separator is pure ASCII, so a multi-byte character split
            # across two reads is always whole again by the time we get here.
            text = block.decode("utf-8", errors="replace").strip()
            if not text.startswith("data:"):
                continue
            json_str = text[5:]  # drop the "data:" prefix
            try:
                events.append(json.loads(json_str))
            except json.JSONDecodeError as e:
                logger.error(f"JSON解析错误: {str(e)}")
                logger.error(f"原始数据块: {json_str}")
        return events

    @staticmethod
    async def _iter_sse_events(resp) -> AsyncGenerator[Any, None]:
        """Yield the JSON payload of every ``data:`` block read from ``resp``.

        Reads ``resp.content`` until it is exhausted, then flushes whatever
        remains so that a last event without a trailing blank line is not lost.
        """
        buffer = bytearray()
        while True:
            chunk = await resp.content.read(DifyAPIClient._READ_CHUNK_SIZE)
            if not chunk:
                break
            buffer.extend(chunk)
            for event in DifyAPIClient._take_sse_events(buffer):
                yield event
        for event in DifyAPIClient._take_sse_events(buffer, final=True):
            yield event

    async def chat_messages(
        self,
        inputs: Dict,
        query: str,
        user: str,
        response_mode: str = "streaming",
        conversation_id: str = "",
        files: Optional[List[Dict[str, Any]]] = None,
        timeout: float = 60,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """POST to ``/chat-messages`` and yield each decoded event.

        Args:
            inputs: Sent as the ``inputs`` field of the request body.
            query: Sent as the ``query`` field.
            user: Sent as the ``user`` field.
            response_mode: Sent as the ``response_mode`` field.
            conversation_id: Sent as the ``conversation_id`` field.
            files: Sent as the ``files`` field; ``None`` means an empty list.
            timeout: Passed to the HTTP request as its timeout.

        Yields:
            The parsed JSON payload of each ``data:`` block in the response.

        Raises:
            Exception: If the response status is not 200.

        NOTE(review): the body is always parsed as blank-line separated
        ``data:`` blocks; a response that is not framed that way yields
        nothing. Confirm whether non-streaming modes need to be supported.
        """
        url = f"{self.api_base}/chat-messages"
        # Built explicitly rather than from locals(), which also captured the
        # local ``url`` and sent it as part of the request body.
        payload = {
            "inputs": inputs,
            "query": query,
            "user": user,
            "response_mode": response_mode,
            "conversation_id": conversation_id,
            "files": [] if files is None else files,
        }
        logger.info(f"chat_messages payload: {payload}")
        async with self.session.post(
            url, json=payload, headers=self.headers, timeout=timeout
        ) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise Exception(f"chat_messages 请求失败:{resp.status}. {text}")
            async for event in self._iter_sse_events(resp):
                yield event

    async def workflow_run(
        self,
        inputs: Dict,
        user: str,
        response_mode: str = "streaming",
        files: Optional[List[Dict[str, Any]]] = None,
        timeout: float = 60,
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """POST to ``/workflows/run`` and yield each decoded event.

        Args:
            inputs: Sent as the ``inputs`` field of the request body.
            user: Sent as the ``user`` field.
            response_mode: Sent as the ``response_mode`` field.
            files: Sent as the ``files`` field; ``None`` means an empty list.
            timeout: Passed to the HTTP request as its timeout.

        Yields:
            The parsed JSON payload of each ``data:`` block in the response.

        Raises:
            Exception: If the response status is not 200.
        """
        url = f"{self.api_base}/workflows/run"
        # Built explicitly rather than from locals(), which also captured the
        # local ``url`` and sent it as part of the request body.
        payload = {
            "inputs": inputs,
            "user": user,
            "response_mode": response_mode,
            "files": [] if files is None else files,
        }
        logger.info(f"workflow_run payload: {payload}")
        async with self.session.post(
            url, json=payload, headers=self.headers, timeout=timeout
        ) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise Exception(f"workflow_run 请求失败:{resp.status}. {text}")
            async for event in self._iter_sse_events(resp):
                yield event

    async def file_upload(
        self,
        file_path: str,
        user: str,
    ) -> Dict[str, Any]:
        """Upload a local file to ``/files/upload`` and return the JSON reply.

        Args:
            file_path: Path of the file to send; opened in binary mode.
            user: Sent as the ``user`` form field.

        Returns:
            The response body decoded as JSON, e.g. ``{"id": "xxx", ...}``.
        """
        url = f"{self.api_base}/files/upload"
        # The context manager guarantees the handle is released even when the
        # request fails before the body has been sent.
        with open(file_path, "rb") as file_obj:
            payload = {
                "user": user,
                "file": file_obj,
            }
            async with self.session.post(
                url, data=payload, headers=self.headers
            ) as resp:
                return await resp.json()  # {"id": "xxx", ...}

    async def close(self):
        """Close the underlying HTTP session."""
        await self.session.close()