diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index fa797eee6..9fc1372a0 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -6,8 +6,7 @@ from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter from astrbot.core.utils.dify_api_client import DifyAPIClient from astrbot.core.utils.io import download_image_by_url -from astrbot.core import logger - +from astrbot.core import logger, sp @register_provider_adapter("dify", "Dify APP 适配器。") class ProviderDify(Provider): @@ -67,10 +66,16 @@ class ProviderDify(Provider): logger.debug(files_payload) + # 获得会话变量 + session_vars = sp.get("session_variables", {}) + session_var = session_vars.get(session_id, {}) + match self.api_type: case "chat" | "agent": async for chunk in self.api_client.chat_messages( - inputs={}, + inputs={ + **session_var + }, query=prompt, user=session_id, conversation_id=conversation_id, @@ -88,7 +93,8 @@ class ProviderDify(Provider): async for chunk in self.api_client.workflow_run( inputs={ "astrbot_text_query": prompt, - "astrbot_session_id": session_id + "astrbot_session_id": session_id, + **session_var }, user=session_id, files=files_payload diff --git a/astrbot/core/utils/dify_api_client.py b/astrbot/core/utils/dify_api_client.py index de9dc5d66..5d41e2911 100644 --- a/astrbot/core/utils/dify_api_client.py +++ b/astrbot/core/utils/dify_api_client.py @@ -1,4 +1,5 @@ import json +from astrbot.core import logger from aiohttp import ClientSession from typing import Dict, List, Any, AsyncGenerator @@ -29,11 +30,18 @@ class DifyAPIClient: async with self.session.post( url, json=payload, headers=self.headers, timeout=timeout ) as resp: - async for data in resp.content: + while True: + data = await resp.content.read(8192) # 防止数据过大导致高水位报错 + if not data: + break if not data.strip(): continue - if data.startswith(b"data:"): - yield json.loads(data[5:]) + elif data.startswith(b"data:"): + try: + json_ = json.loads(data[5:]) + yield json_ + except BaseException: + pass async def workflow_run( self, @@ -50,11 +58,18 @@ class DifyAPIClient: async with self.session.post( url, json=payload, headers=self.headers, timeout=timeout ) as resp: - async for data in resp.content: + while True: + data = await resp.content.read(8192) # 防止数据过大导致高水位报错 + if not data: + break if not data.strip(): continue - if data.startswith(b"data:"): - yield json.loads(data[5:]) + elif data.startswith(b"data:"): + try: + json_ = json.loads(data[5:]) + yield json_ + except BaseException: + pass async def file_upload( self, @@ -70,9 +85,6 @@ class DifyAPIClient: url, data=payload, headers=self.headers ) as resp: return await resp.json() # {"id": "xxx", ...} - - - async def close(self): await self.session.close() \ No newline at end of file diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 49d02063b..1eca68f88 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -56,6 +56,10 @@ class Main(star.Star): /persona: 情境人格设置 /tool ls: 查看、激活、停用当前注册的函数工具 +[其他] +/set <变量名> <值>: 为当前会话定义一个变量。适用于 Dify 工作流输入。 +/unset <变量名>: 删除当前会话的变量。 + 提示:如果要查看插件指令,请输入 /plugin 查看具体信息。 {notice}""" @@ -365,12 +369,35 @@ UID: {user_id} 此 ID 可用于设置管理员。/op 授权管理员, /deo req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}" if provider.curr_personality['prompt']: req.system_prompt += f"\n{provider.curr_personality['prompt']}" - - @filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE) - async def other_message(self, event: AstrMessageEvent): - print("triggered") - event.stop_event() + @filter.command("set") + async def set_variable(self, event: AstrMessageEvent, key: str, value: str): + session_id = event.get_session_id() + session_vars = sp.get("session_variables", {}) + + session_var = session_vars.get(session_id, {}) + session_var[key] = value + + session_vars[session_id] = session_var + + sp.put("session_variables", session_vars) + + yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。") + + @filter.command("unset") + async def unset_variable(self, event: AstrMessageEvent, key: str): + session_id = event.get_session_id() + session_vars = sp.get("session_variables", {}) + + session_var = session_vars.get(session_id, {}) + + if key not in session_var: + yield event.plain_result("没有那个变量名。") + else: + del session_var[key] + sp.put("session_variables", session_vars) + yield event.plain_result(f"会话 {session_id} 变量 {key} 移除成功。") + @filter.command_group("kdb") def kdb(self): pass