129 lines
4.8 KiB
Python
129 lines
4.8 KiB
Python
import asyncio
|
|
import functools
|
|
from typing import List
|
|
from .. import Provider, Personality
|
|
from ..entites import LLMResponse
|
|
from ..func_tool_manager import FuncCall
|
|
from astrbot.core.db import BaseDatabase
|
|
from ..register import register_provider_adapter
|
|
from .openai_source import ProviderOpenAIOfficial
|
|
from astrbot.core import logger, sp
|
|
from dashscope import Application
|
|
|
|
|
|
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
|
|
class ProviderDashscope(ProviderOpenAIOfficial):
|
|
def __init__(
|
|
self,
|
|
provider_config: dict,
|
|
provider_settings: dict,
|
|
db_helper: BaseDatabase,
|
|
persistant_history=False,
|
|
default_persona: Personality = None,
|
|
) -> None:
|
|
Provider.__init__(
|
|
self,
|
|
provider_config,
|
|
provider_settings,
|
|
persistant_history,
|
|
db_helper,
|
|
default_persona,
|
|
)
|
|
self.api_key = provider_config.get("dashscope_api_key", "")
|
|
if not self.api_key:
|
|
raise Exception("阿里云百炼 API Key 不能为空。")
|
|
self.app_id = provider_config.get("dashscope_app_id", "")
|
|
if not self.app_id:
|
|
raise Exception("阿里云百炼 APP ID 不能为空。")
|
|
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
|
|
if not self.dashscope_app_type:
|
|
raise Exception("阿里云百炼 APP 类型不能为空。")
|
|
self.model_name = "dashscope"
|
|
self.variables: dict = provider_config.get("variables", {})
|
|
|
|
self.timeout = provider_config.get("timeout", 120)
|
|
if isinstance(self.timeout, str):
|
|
self.timeout = int(self.timeout)
|
|
|
|
async def text_chat(
|
|
self,
|
|
prompt: str,
|
|
session_id: str = None,
|
|
image_urls: List[str] = [],
|
|
func_tool: FuncCall = None,
|
|
contexts: List = None,
|
|
system_prompt: str = None,
|
|
**kwargs,
|
|
) -> LLMResponse:
|
|
# 获得会话变量
|
|
payload_vars = self.variables.copy()
|
|
# 动态变量
|
|
session_vars = sp.get("session_variables", {})
|
|
session_var = session_vars.get(session_id, {})
|
|
payload_vars.update(session_var)
|
|
|
|
if self.dashscope_app_type in ["agent", "dialog-workflow"]:
|
|
# 支持多轮对话的
|
|
new_record = {"role": "user", "content": prompt}
|
|
if image_urls:
|
|
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
|
|
contexts_no_img = await self._remove_image_from_context(contexts)
|
|
context_query = [*contexts_no_img, new_record]
|
|
if system_prompt:
|
|
context_query.insert(0, {"role": "system", "content": system_prompt})
|
|
for part in context_query:
|
|
if "_no_save" in part:
|
|
del part["_no_save"]
|
|
# 调用阿里云百炼 API
|
|
partial = functools.partial(
|
|
Application.call,
|
|
app_id=self.app_id,
|
|
api_key=self.api_key,
|
|
messages=context_query,
|
|
biz_params=payload_vars or None,
|
|
)
|
|
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
|
else:
|
|
# 不支持多轮对话的
|
|
# 调用阿里云百炼 API
|
|
partial = functools.partial(
|
|
Application.call,
|
|
app_id=self.app_id,
|
|
promtp=prompt,
|
|
api_key=self.api_key,
|
|
biz_params=payload_vars or None,
|
|
)
|
|
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
|
|
|
logger.debug(f"dashscope resp: {response}")
|
|
|
|
if response.status_code != 200:
|
|
logger.error(
|
|
f"阿里云百炼请求失败: request_id={response.request_id}, code={response.status_code}, message={response.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code"
|
|
)
|
|
return LLMResponse(
|
|
role="err",
|
|
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
|
|
)
|
|
|
|
output_text = response.output.get("text", "")
|
|
return LLMResponse(role="assistant", completion_text=output_text)
|
|
|
|
async def forget(self, session_id):
|
|
return True
|
|
|
|
async def get_current_key(self):
|
|
return self.api_key
|
|
|
|
async def set_key(self, key):
|
|
raise Exception("阿里云百炼 适配器不支持设置 API Key。")
|
|
|
|
async def get_models(self):
|
|
return [self.get_model()]
|
|
|
|
async def get_human_readable_context(self, session_id, page, page_size):
|
|
raise Exception("暂不支持获得 阿里云百炼 的历史消息记录。")
|
|
|
|
async def terminate(self):
|
|
pass
|