diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 5e86c0f59..3cdb90686 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -278,7 +278,7 @@ CONFIG_METADATA_2 = { }, "zhipu": { "id": "zhipu_default", - "type": "openai_chat_completion", + "type": "zhipu_chat_completion", "enable": True, "key": [], "api_base": "https://open.bigmodel.cn/api/paas/v4/", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 7e9470c10..a3863c606 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -29,6 +29,8 @@ class ProviderManager(): match provider_cfg['type']: case "openai_chat_completion": from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401 + case "zhipu_chat_completion": + from .sources.zhipu_source import ProviderZhipu # noqa: F401 case "llm_tuner": logger.info("加载 LLM Tuner 工具 ...") from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401 diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 1cc792bf3..b0c9f0a58 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -162,7 +162,12 @@ class ProviderOpenAIOfficial(Provider): logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") self.pop_record(session_id) logger.warning(traceback.format_exc()) - + + await self.save_history(contexts, new_record, session_id, llm_response) + + return llm_response + + async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse): if llm_response.role == "assistant" and session_id: # 文本回复 if not contexts: @@ -180,8 +185,6 @@ class ProviderOpenAIOfficial(Provider): }] self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type']) - return llm_response - async def forget(self, session_id: str) -> bool: self.session_memory[session_id] = [] return True diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py new file mode 100644 index 000000000..14e48cab4 --- /dev/null +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -0,0 +1,73 @@ +import traceback +from astrbot.core.db import BaseDatabase +from astrbot import logger +from astrbot.core.provider.func_tool_manager import FuncCall +from typing import List +from ..register import register_provider_adapter +from astrbot.core.provider.entites import LLMResponse +from .openai_source import ProviderOpenAIOfficial + +@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器") +class ProviderZhipu(ProviderOpenAIOfficial): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + db_helper: BaseDatabase, + persistant_history = True + ) -> None: + super().__init__(provider_config, provider_settings, db_helper, persistant_history) + + async def text_chat( + self, + prompt: str, + session_id: str, + image_urls: List[str]=None, + func_tool: FuncCall=None, + contexts=None, + system_prompt=None, + **kwargs + ) -> LLMResponse: + new_record = await self.assemble_context(prompt, image_urls) + context_query = [] + + if not contexts: + context_query = [*self.session_memory[session_id], new_record] + else: + context_query = [*contexts, new_record] + + model_cfgs: dict = self.provider_config.get("model_config", {}) + # glm-4v-flash 只支持一张图片 + model: str = model_cfgs.get("model", "") + if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1: + logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片") + logger.debug(context_query) + new_context_query_ = [] + for i in range(0, len(context_query) - 1, 2): + if isinstance(context_query[i].get("content", ""), list): + continue + new_context_query_.append(context_query[i]) + new_context_query_.append(context_query[i+1]) + new_context_query_.append(context_query[-1]) # 保留最后一条记录 + context_query = new_context_query_ + logger.debug(context_query) + + if system_prompt: + context_query.insert(0, {"role": "system", "content": system_prompt}) + + payloads = { + "messages": context_query, + **model_cfgs + } + llm_response = None + try: + llm_response = await self._query(payloads, func_tool) + except Exception as e: + if "maximum context length" in str(e): + logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。") + self.pop_record(session_id) + logger.warning(traceback.format_exc()) + + await self.save_history(contexts, new_record, session_id, llm_response) + + return llm_response \ No newline at end of file