diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 68c3146b5..880e7b404 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -527,6 +527,7 @@ CONFIG_METADATA_2 = { "dify_api_base": "https://api.dify.ai/v1", "dify_workflow_output_key": "", "dify_query_input_key": "astrbot_text_query", + "variables": {}, "timeout": 60, }, "dashscope": { @@ -536,6 +537,7 @@ CONFIG_METADATA_2 = { "dashscope_app_type": "agent", "dashscope_api_key": "", "dashscope_app_id": "", + "variables": {}, "timeout": 60, }, "whisper(API)": { @@ -574,6 +576,12 @@ CONFIG_METADATA_2 = { }, }, "items": { + "variables": { + "description": "工作流固定输入变量", + "type": "object", + "obvious_hint": True, + "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。", + }, "dashscope_app_type": { "description": "应用类型", "type": "string", diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index a120efbd3..eb1db3607 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -10,6 +10,7 @@ from .openai_source import ProviderOpenAIOfficial from astrbot.core import logger, sp from dashscope import Application + @register_provider_adapter("dashscope", "Dashscope APP 适配器。") class ProviderDashscope(ProviderOpenAIOfficial): def __init__( @@ -18,10 +19,15 @@ class ProviderDashscope(ProviderOpenAIOfficial): provider_settings: dict, db_helper: BaseDatabase, persistant_history=False, - default_persona: Personality=None + default_persona: Personality = None, ) -> None: Provider.__init__( - self, provider_config, provider_settings, persistant_history, db_helper, default_persona + self, + provider_config, + provider_settings, + persistant_history, + db_helper, + default_persona, ) self.api_key = provider_config.get("dashscope_api_key", "") if not self.api_key: @@ -33,7 +39,8 @@ class ProviderDashscope(ProviderOpenAIOfficial): if not self.dashscope_app_type: raise Exception("阿里云百炼 APP 类型不能为空。") self.model_name = "dashscope" - + self.variables: dict = provider_config.get("variables", {}) + self.timeout = provider_config.get("timeout", 120) if isinstance(self.timeout, str): self.timeout = int(self.timeout) @@ -48,13 +55,15 @@ class ProviderDashscope(ProviderOpenAIOfficial): system_prompt: str = None, **kwargs, ) -> LLMResponse: - # 获得会话变量 + payload_vars = self.variables.copy() + # 动态变量 session_vars = sp.get("session_variables", {}) session_var = session_vars.get(session_id, {}) - + payload_vars.update(session_var) + if self.dashscope_app_type in ["agent", "dialog-workflow"]: - # 支持多轮对话的 + # 支持多轮对话的 new_record = {"role": "user", "content": prompt} if image_urls: logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。") @@ -63,43 +72,42 @@ class ProviderDashscope(ProviderOpenAIOfficial): if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) for part in context_query: - if '_no_save' in part: - del part['_no_save'] + if "_no_save" in part: + del part["_no_save"] # 调用阿里云百炼 API partial = functools.partial( - Application.call, - app_id = self.app_id, - api_key = self.api_key, - messages = context_query, - biz_params = session_var or None, - ) - response = await asyncio.get_event_loop().run_in_executor( - None, - partial + Application.call, + app_id=self.app_id, + api_key=self.api_key, + messages=context_query, + biz_params=payload_vars or None, ) + response = await asyncio.get_event_loop().run_in_executor(None, partial) else: # 不支持多轮对话的 # 调用阿里云百炼 API partial = functools.partial( - Application.call, - app_id = self.app_id, - promtp = prompt, - api_key = self.api_key, - biz_params=session_var or None, + Application.call, + app_id=self.app_id, + promtp=prompt, + api_key=self.api_key, + biz_params=payload_vars or None, ) - response = await asyncio.get_event_loop().run_in_executor( - None, - partial - ) - + response = await asyncio.get_event_loop().run_in_executor(None, partial) + logger.debug(f"dashscope resp: {response}") if response.status_code != 200: - logger.error(f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code") - return LLMResponse(role="err", completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}") - + logger.error( + f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code" + ) + return LLMResponse( + role="err", + completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}", + ) + output_text = response.output.get("text", "") - return LLMResponse(role="assistant", completion_text=output_text) + return LLMResponse(role="assistant", completion_text=output_text) async def forget(self, session_id): return True @@ -117,4 +125,4 @@ class ProviderDashscope(ProviderOpenAIOfficial): raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。") async def terminate(self): - await self.api_client.close() \ No newline at end of file + await self.api_client.close() diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 807520e0e..1127c7561 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -32,6 +32,7 @@ class ProviderDify(Provider): self.model_name = "dify" self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output") self.dify_query_input_key = provider_config.get("dify_query_input_key", "astrbot_text_query") + self.variables: dict = provider_config.get("variables", {}) if not self.dify_query_input_key: self.dify_query_input_key = "astrbot_text_query" self.timeout = provider_config.get("timeout", 120) @@ -72,15 +73,18 @@ class ProviderDify(Provider): logger.warning(f"未知的图片链接:{image_url},图片将忽略。") # 获得会话变量 + payload_vars = self.variables.copy() + # 动态变量 session_vars = sp.get("session_variables", {}) session_var = session_vars.get(session_id, {}) + payload_vars.update(session_var) try: match self.api_type: case "chat" | "agent": async for chunk in self.api_client.chat_messages( inputs={ - **session_var + **payload_vars, }, query=prompt, user=session_id, @@ -101,7 +105,7 @@ class ProviderDify(Provider): inputs={ self.dify_query_input_key: prompt, "astrbot_session_id": session_id, - **session_var + **payload_vars, }, user=session_id, files=files_payload,