feat: 支持为 dify 和 dashscope 提供商设置默认固定变量 #552

This commit is contained in:
Soulter
2025-02-22 14:48:18 +08:00
parent 466c80b94d
commit 8beb7acdb1
3 changed files with 54 additions and 34 deletions
+8
View File
@@ -527,6 +527,7 @@ CONFIG_METADATA_2 = {
"dify_api_base": "https://api.dify.ai/v1",
"dify_workflow_output_key": "",
"dify_query_input_key": "astrbot_text_query",
"variables": {},
"timeout": 60,
},
"dashscope": {
@@ -536,6 +537,7 @@ CONFIG_METADATA_2 = {
"dashscope_app_type": "agent",
"dashscope_api_key": "",
"dashscope_app_id": "",
"variables": {},
"timeout": 60,
},
"whisper(API)": {
@@ -574,6 +576,12 @@ CONFIG_METADATA_2 = {
},
},
"items": {
"variables": {
"description": "工作流固定输入变量",
"type": "object",
"obvious_hint": True,
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
},
"dashscope_app_type": {
"description": "应用类型",
"type": "string",
@@ -10,6 +10,7 @@ from .openai_source import ProviderOpenAIOfficial
from astrbot.core import logger, sp
from dashscope import Application
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
class ProviderDashscope(ProviderOpenAIOfficial):
def __init__(
@@ -18,10 +19,15 @@ class ProviderDashscope(ProviderOpenAIOfficial):
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=False,
default_persona: Personality=None
default_persona: Personality = None,
) -> None:
Provider.__init__(
self, provider_config, provider_settings, persistant_history, db_helper, default_persona
self,
provider_config,
provider_settings,
persistant_history,
db_helper,
default_persona,
)
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:
@@ -33,7 +39,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
if not self.dashscope_app_type:
raise Exception("阿里云百炼 APP 类型不能为空。")
self.model_name = "dashscope"
self.variables: dict = provider_config.get("variables", {})
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
@@ -48,13 +55,15 @@ class ProviderDashscope(ProviderOpenAIOfficial):
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
if self.dashscope_app_type in ["agent", "dialog-workflow"]:
# 支持多轮对话的
# 支持多轮对话的
new_record = {"role": "user", "content": prompt}
if image_urls:
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
@@ -63,43 +72,42 @@ class ProviderDashscope(ProviderOpenAIOfficial):
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
if "_no_save" in part:
del part["_no_save"]
# 调用阿里云百炼 API
partial = functools.partial(
Application.call,
app_id = self.app_id,
api_key = self.api_key,
messages = context_query,
biz_params = session_var or None,
)
response = await asyncio.get_event_loop().run_in_executor(
None,
partial
Application.call,
app_id=self.app_id,
api_key=self.api_key,
messages=context_query,
biz_params=payload_vars or None,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
else:
# 不支持多轮对话的
# 调用阿里云百炼 API
partial = functools.partial(
Application.call,
app_id = self.app_id,
promtp = prompt,
api_key = self.api_key,
biz_params=session_var or None,
Application.call,
app_id=self.app_id,
promtp=prompt,
api_key=self.api_key,
biz_params=payload_vars or None,
)
response = await asyncio.get_event_loop().run_in_executor(
None,
partial
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
logger.debug(f"dashscope resp: {response}")
if response.status_code != 200:
logger.error(f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code")
return LLMResponse(role="err", completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}")
logger.error(
f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code"
)
return LLMResponse(
role="err",
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
)
output_text = response.output.get("text", "")
return LLMResponse(role="assistant", completion_text=output_text)
return LLMResponse(role="assistant", completion_text=output_text)
async def forget(self, session_id):
return True
@@ -117,4 +125,4 @@ class ProviderDashscope(ProviderOpenAIOfficial):
raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
async def terminate(self):
await self.api_client.close()
await self.api_client.close()
+6 -2
View File
@@ -32,6 +32,7 @@ class ProviderDify(Provider):
self.model_name = "dify"
self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output")
self.dify_query_input_key = provider_config.get("dify_query_input_key", "astrbot_text_query")
self.variables: dict = provider_config.get("variables", {})
if not self.dify_query_input_key:
self.dify_query_input_key = "astrbot_text_query"
self.timeout = provider_config.get("timeout", 120)
@@ -72,15 +73,18 @@ class ProviderDify(Provider):
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
try:
match self.api_type:
case "chat" | "agent":
async for chunk in self.api_client.chat_messages(
inputs={
**session_var
**payload_vars,
},
query=prompt,
user=session_id,
@@ -101,7 +105,7 @@ class ProviderDify(Provider):
inputs={
self.dify_query_input_key: prompt,
"astrbot_session_id": session_id,
**session_var
**payload_vars,
},
user=session_id,
files=files_payload,