diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index badf5d62b..15a6b71fb 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -1,9 +1,33 @@ +import codecs import json from astrbot.core import logger -from aiohttp import ClientSession +from aiohttp import ClientSession, ClientResponse from typing import Dict, List, Any, AsyncGenerator +async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]: + decoder = codecs.getincrementaldecoder("utf-8")() + buffer = "" + async for chunk in resp.content.iter_chunked(8192): + buffer += decoder.decode(chunk) + while "\n\n" in buffer: + block, buffer = buffer.split("\n\n", 1) + if block.strip().startswith("data:"): + try: + yield json.loads(block[5:]) + except json.JSONDecodeError: + logger.warning(f"Drop invalid dify json data: {block[5:]}") + continue + # flush any remaining text + buffer += decoder.decode(b"", final=True) + if buffer.strip().startswith("data:"): + try: + yield json.loads(buffer[5:]) + except json.JSONDecodeError: + logger.warning(f"Drop invalid dify json data: {buffer[5:]}") + pass + + class DifyAPIClient: def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"): self.api_key = api_key @@ -33,31 +57,11 @@ class DifyAPIClient: ) as resp: if resp.status != 200: text = await resp.text() - raise Exception(f"chat_messages 请求失败:{resp.status}. {text}") - - buffer = "" - while True: - # 保持原有的8192字节限制,防止数据过大导致高水位报错 - chunk = await resp.content.read(8192) - if not chunk: - break - - buffer += chunk.decode("utf-8") - blocks = buffer.split("\n\n") - - # 处理完整的数据块 - for block in blocks[:-1]: - if block.strip() and block.startswith("data:"): - try: - json_str = block[5:] # 移除 "data:" 前缀 - json_obj = json.loads(json_str) - yield json_obj - except json.JSONDecodeError as e: - logger.error(f"JSON解析错误: {str(e)}") - logger.error(f"原始数据块: {json_str}") - - # 保留最后一个可能不完整的块 - buffer = blocks[-1] if blocks else "" + raise Exception( + f"Dify /chat-messages 接口请求失败:{resp.status}. {text}" + ) + async for event in _stream_sse(resp): + yield event async def workflow_run( self, @@ -77,31 +81,11 @@ class DifyAPIClient: ) as resp: if resp.status != 200: text = await resp.text() - raise Exception(f"workflow_run 请求失败:{resp.status}. {text}") - - buffer = "" - while True: - # 保持原有的8192字节限制,防止数据过大导致高水位报错 - chunk = await resp.content.read(8192) - if not chunk: - break - - buffer += chunk.decode("utf-8") - blocks = buffer.split("\n\n") - - # 处理完整的数据块 - for block in blocks[:-1]: - if block.strip() and block.startswith("data:"): - try: - json_str = block[5:] # 移除 "data:" 前缀 - json_obj = json.loads(json_str) - yield json_obj - except json.JSONDecodeError as e: - logger.error(f"JSON解析错误: {str(e)}") - logger.error(f"原始数据块: {json_str}") - - # 保留最后一个可能不完整的块 - buffer = blocks[-1] if blocks else "" + raise Exception( + f"Dify /workflows/run 接口请求失败:{resp.status}. {text}" + ) + async for event in _stream_sse(resp): + yield event async def file_upload( self, @@ -109,12 +93,15 @@ class DifyAPIClient: user: str, ) -> Dict[str, Any]: url = f"{self.api_base}/files/upload" - payload = { - "user": user, - "file": open(file_path, "rb"), - } - async with self.session.post(url, data=payload, headers=self.headers) as resp: - return await resp.json() # {"id": "xxx", ...} + with open(file_path, "rb") as f: + payload = { + "user": user, + "file": f, + } + async with self.session.post( + url, data=payload, headers=self.headers + ) as resp: + return await resp.json() # {"id": "xxx", ...} async def close(self): await self.session.close()