feat: 支持动态设置会话变量以适用 Dify 输入变量
This commit is contained in:
@@ -6,8 +6,7 @@ from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import logger
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
class ProviderDify(Provider):
|
||||
@@ -67,10 +66,16 @@ class ProviderDify(Provider):
|
||||
|
||||
logger.debug(files_payload)
|
||||
|
||||
# 获得会话变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
|
||||
match self.api_type:
|
||||
case "chat" | "agent":
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={},
|
||||
inputs={
|
||||
**session_var
|
||||
},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
@@ -88,7 +93,8 @@ class ProviderDify(Provider):
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
"astrbot_text_query": prompt,
|
||||
"astrbot_session_id": session_id
|
||||
"astrbot_session_id": session_id,
|
||||
**session_var
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from astrbot.core import logger
|
||||
from aiohttp import ClientSession
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
|
||||
@@ -29,11 +30,18 @@ class DifyAPIClient:
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
async for data in resp.content:
|
||||
while True:
|
||||
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
|
||||
if not data:
|
||||
break
|
||||
if not data.strip():
|
||||
continue
|
||||
if data.startswith(b"data:"):
|
||||
yield json.loads(data[5:])
|
||||
elif data.startswith(b"data:"):
|
||||
try:
|
||||
json_ = json.loads(data[5:])
|
||||
yield json_
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
async def workflow_run(
|
||||
self,
|
||||
@@ -50,11 +58,18 @@ class DifyAPIClient:
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
async for data in resp.content:
|
||||
while True:
|
||||
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
|
||||
if not data:
|
||||
break
|
||||
if not data.strip():
|
||||
continue
|
||||
if data.startswith(b"data:"):
|
||||
yield json.loads(data[5:])
|
||||
elif data.startswith(b"data:"):
|
||||
try:
|
||||
json_ = json.loads(data[5:])
|
||||
yield json_
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
async def file_upload(
|
||||
self,
|
||||
@@ -70,9 +85,6 @@ class DifyAPIClient:
|
||||
url, data=payload, headers=self.headers
|
||||
) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
|
||||
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
@@ -56,6 +56,10 @@ class Main(star.Star):
|
||||
/persona: 情境人格设置
|
||||
/tool ls: 查看、激活、停用当前注册的函数工具
|
||||
|
||||
[其他]
|
||||
/set <变量名> <值>: 为当前会话定义一个变量。适用于 Dify 工作流输入。
|
||||
/unset <变量名>: 删除当前会话的变量。
|
||||
|
||||
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
|
||||
{notice}"""
|
||||
|
||||
@@ -365,12 +369,35 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
|
||||
if provider.curr_personality['prompt']:
|
||||
req.system_prompt += f"\n{provider.curr_personality['prompt']}"
|
||||
|
||||
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
|
||||
async def other_message(self, event: AstrMessageEvent):
|
||||
print("triggered")
|
||||
event.stop_event()
|
||||
|
||||
@filter.command("set")
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
|
||||
session_id = event.get_session_id()
|
||||
session_vars = sp.get("session_variables", {})
|
||||
|
||||
session_var = session_vars.get(session_id, {})
|
||||
session_var[key] = value
|
||||
|
||||
session_vars[session_id] = session_var
|
||||
|
||||
sp.put("session_variables", session_vars)
|
||||
|
||||
yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。")
|
||||
|
||||
@filter.command("unset")
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str):
|
||||
session_id = event.get_session_id()
|
||||
session_vars = sp.get("session_variables", {})
|
||||
|
||||
session_var = session_vars.get(session_id, {})
|
||||
|
||||
if key not in session_var:
|
||||
yield event.plain_result("没有那个变量名。")
|
||||
else:
|
||||
del session_var[key]
|
||||
sp.put("session_variables", session_vars)
|
||||
yield event.plain_result(f"会话 {session_id} 变量 {key} 移除成功。")
|
||||
|
||||
@filter.command_group("kdb")
|
||||
def kdb(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user