feat: 支持动态设置会话变量以适用 Dify 输入变量

This commit is contained in:
Soulter
2025-01-10 12:31:49 +08:00
parent 5929a8d42b
commit 3e2b4bc727
3 changed files with 63 additions and 18 deletions
+10 -4
View File
@@ -6,8 +6,7 @@ from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import logger
from astrbot.core import logger, sp
@register_provider_adapter("dify", "Dify APP 适配器。")
class ProviderDify(Provider):
@@ -67,10 +66,16 @@ class ProviderDify(Provider):
logger.debug(files_payload)
# 获得会话变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
match self.api_type:
case "chat" | "agent":
async for chunk in self.api_client.chat_messages(
inputs={},
inputs={
**session_var
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
@@ -88,7 +93,8 @@ class ProviderDify(Provider):
async for chunk in self.api_client.workflow_run(
inputs={
"astrbot_text_query": prompt,
"astrbot_session_id": session_id
"astrbot_session_id": session_id,
**session_var
},
user=session_id,
files=files_payload
+21 -9
View File
@@ -1,4 +1,5 @@
import json
from astrbot.core import logger
from aiohttp import ClientSession
from typing import Dict, List, Any, AsyncGenerator
@@ -29,11 +30,18 @@ class DifyAPIClient:
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
async for data in resp.content:
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip():
continue
if data.startswith(b"data:"):
yield json.loads(data[5:])
elif data.startswith(b"data:"):
try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def workflow_run(
self,
@@ -50,11 +58,18 @@ class DifyAPIClient:
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
async for data in resp.content:
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip():
continue
if data.startswith(b"data:"):
yield json.loads(data[5:])
elif data.startswith(b"data:"):
try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def file_upload(
self,
@@ -70,9 +85,6 @@ class DifyAPIClient:
url, data=payload, headers=self.headers
) as resp:
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()
+32 -5
View File
@@ -56,6 +56,10 @@ class Main(star.Star):
/persona: 情境人格设置
/tool ls: 查看、激活、停用当前注册的函数工具
[其他]
/set <变量名> <值>: 为当前会话定义一个变量。适用于 Dify 工作流输入。
/unset <变量名>: 删除当前会话的变量。
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
{notice}"""
@@ -365,12 +369,35 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
if provider.curr_personality['prompt']:
req.system_prompt += f"\n{provider.curr_personality['prompt']}"
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
async def other_message(self, event: AstrMessageEvent):
print("triggered")
event.stop_event()
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
session_id = event.get_session_id()
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
session_var[key] = value
session_vars[session_id] = session_var
sp.put("session_variables", session_vars)
yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。")
@filter.command("unset")
async def unset_variable(self, event: AstrMessageEvent, key: str):
session_id = event.get_session_id()
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
if key not in session_var:
yield event.plain_result("没有那个变量名。")
else:
del session_var[key]
sp.put("session_variables", session_vars)
yield event.plain_result(f"会话 {session_id} 变量 {key} 移除成功。")
@filter.command_group("kdb")
def kdb(self):
pass