From 109650faf3bf6b6e6aea23e9c46f30e614caa77e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sun, 6 Apr 2025 00:56:33 +0800 Subject: [PATCH 01/17] =?UTF-8?q?=E2=9C=A8=20feat:=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E6=B5=81=E5=BC=8F=E8=BE=93=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/default.py | 6 + astrbot/core/message/message_event_result.py | 12 +- .../process_stage/method/llm_request.py | 122 ++++--- astrbot/core/pipeline/respond/stage.py | 11 +- .../core/pipeline/result_decorate/stage.py | 4 + astrbot/core/platform/astr_message_event.py | 11 +- .../platform/sources/telegram/tg_event.py | 95 +++++- .../platform/sources/webchat/webchat_event.py | 66 +++- astrbot/core/provider/entites.py | 5 + astrbot/core/provider/provider.py | 32 +- .../core/provider/sources/anthropic_source.py | 13 + .../core/provider/sources/dashscope_source.py | 19 +- astrbot/core/provider/sources/dify_source.py | 13 + .../core/provider/sources/gemini_source.py | 13 + .../core/provider/sources/llmtuner_source.py | 13 + .../core/provider/sources/openai_source.py | 313 ++++++++++++++---- astrbot/dashboard/routes/chat.py | 43 ++- dashboard/src/views/ChatPage.vue | 218 ++++++------ 18 files changed, 762 insertions(+), 247 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 7e2344816..f94d49aca 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -50,6 +50,7 @@ DEFAULT_CONFIG = { "default_personality": "default", "prompt_prefix": "", "max_context_length": -1, + "streaming_response": False, }, "provider_stt_settings": { "enable": False, @@ -992,6 +993,11 @@ CONFIG_METADATA_2 = { "type": "int", "hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。", }, + "streaming_response": { + "description": "启用流式回复", + "type": "bool", + "hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API 以及 Telegram 平台,并且暂不支持工具调用(后续将更新)", + }, }, }, "persona": { diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 50e50ceb5..0f7c4c7af 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -1,6 +1,6 @@ import enum -from typing import List, Optional, Union +from typing import List, Optional, Union, AsyncGenerator from dataclasses import dataclass, field from astrbot.core.message.components import ( BaseMessageComponent, @@ -131,6 +131,8 @@ class ResultContentType(enum.Enum): """调用 LLM 产生的结果""" GENERAL_RESULT = enum.auto() """普通的消息结果""" + STREAMING_RESULT = enum.auto() + """调用 LLM 产生的流式结果""" @dataclass @@ -152,6 +154,9 @@ class MessageEventResult(MessageChain): default_factory=lambda: ResultContentType.GENERAL_RESULT ) + async_stream: Optional[AsyncGenerator] = None + """异步流""" + def stop_event(self) -> "MessageEventResult": """终止事件传播。""" self.result_type = EventResultType.STOP @@ -168,6 +173,11 @@ class MessageEventResult(MessageChain): """ return self.result_type == EventResultType.STOP + def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult": + """设置异步流。""" + self.async_stream = stream + return self + def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult": """设置事件处理的结果类型。 diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 674a7fd79..5f20b13ab 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -37,6 +37,9 @@ class LLMRequestSubStage(Stage): self.max_context_length = ctx.astrbot_config["provider_settings"][ "max_context_length" ] # int + self.streaming_response = ctx.astrbot_config["provider_settings"][ + "streaming_response" + ] # bool for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): @@ -137,59 +140,90 @@ class LLMRequestSubStage(Stage): if not req.session_id: req.session_id = event.unified_msg_origin - try: - need_loop = True - while need_loop: - need_loop = False - logger.debug(f"提供商请求 Payload: {req}") - llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + async def requesting(req: ProviderRequest): + try: + need_loop = True + while need_loop: + need_loop = False + logger.debug(f"提供商请求 Payload: {req}") - # 执行 LLM 响应后的事件钩子。 - handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnLLMResponseEvent - ) - for handler in handlers: - try: - logger.debug( - f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + final_llm_response = None + + if self.streaming_response: + stream = provider.text_chat_stream( + **req.__dict__ ) - await handler.handler(event, llm_response) - except BaseException: - logger.error(traceback.format_exc()) - - if event.is_stopped(): - logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" - ) - return - - async for result in self._handle_llm_response(event, req, llm_response): - if isinstance(result, ProviderRequest): - # 有函数工具调用并且返回了结果,我们需要再次请求 LLM - req = result - need_loop = True + async for llm_response in stream: + if llm_response.is_chunk: + logger.debug(llm_response) + yield llm_response.result_chain + else: + final_llm_response = llm_response else: - yield + final_llm_response = await provider.text_chat( + **req.__dict__ + ) # 请求 LLM - asyncio.create_task( - Metric.upload( - llm_tick=1, - model_name=provider.get_model(), - provider_type=provider.meta().type, + if not final_llm_response: + raise Exception("LLM response is None.") + + # 执行 LLM 响应后的事件钩子。 + handlers = star_handlers_registry.get_handlers_by_event_type( + EventType.OnLLMResponseEvent + ) + for handler in handlers: + try: + logger.debug( + f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) + await handler.handler(event, final_llm_response) + except BaseException: + logger.error(traceback.format_exc()) + + if event.is_stopped(): + logger.info( + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + ) + return + + async for result in self._handle_llm_response( + event, req, final_llm_response + ): + if isinstance(result, ProviderRequest): + # 有函数工具调用并且返回了结果,我们需要再次请求 LLM + req = result + need_loop = True + else: + yield + + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=provider.get_model(), + provider_type=provider.meta().type, + ) ) - ) - # 保存到历史记录 - await self._save_to_history(event, req, llm_response) + # 保存到历史记录 + await self._save_to_history(event, req, final_llm_response) - except BaseException as e: - logger.error(traceback.format_exc()) + except BaseException as e: + logger.error(traceback.format_exc()) + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" + ) + ) + + if not self.streaming_response: + async for _ in requesting(req): + yield + else: event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" - ) + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream(requesting(req)) ) - return async def _handle_llm_response( self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index a43f0b32d..ce77f53ca 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -6,7 +6,7 @@ from typing import Union, AsyncGenerator from ..stage import register_stage, Stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.message.message_event_result import MessageChain, ResultContentType from astrbot.core import logger from astrbot.core.message.message_event_result import BaseMessageComponent from astrbot.core.star.star_handler import star_handlers_registry, EventType @@ -79,7 +79,14 @@ class RespondStage(Stage): if result is None: return - if len(result.chain) > 0: + if result.result_content_type == ResultContentType.STREAMING_RESULT: + # 流式结果直接交付平台适配器处理 + logger.info(f"应用流式输出({event.get_platform_name()})") + await event._pre_send() + await event.send_streaming(result.async_stream) + await event._post_send() + return + elif len(result.chain) > 0: await event._pre_send() if self.enable_seg and ( diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index d7bb9583c..4fa7861d4 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -5,6 +5,7 @@ from typing import Union, AsyncGenerator from ..stage import Stage, register_stage, registered_stages from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import ResultContentType from astrbot.core.platform.message_type import MessageType from astrbot.core import logger from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node @@ -71,6 +72,9 @@ class ResultDecorateStage(Stage): result = event.get_result() if result is None or not result.chain: return + if result.result_content_type == ResultContentType.STREAMING_RESULT: + # 流式结果暂时不进行处理 + return # 回复时检查内容安全 if ( diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 3e1b14ee6..0f91b7087 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,7 +1,7 @@ import abc import asyncio from dataclasses import dataclass -from typing import List, Union, Optional +from typing import List, Union, Optional, AsyncGenerator from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( @@ -202,6 +202,15 @@ class AstrMessageEvent(abc.ABC): """ return self.role == "admin" + async def send_streaming(self, generator: AsyncGenerator[List[BaseMessageComponent], None]): + """发送流式消息到消息平台,使用异步生成器。 + 目前仅支持: telegram。 + """ + asyncio.create_task( + Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) + ) + self._has_send_oper = True + async def _pre_send(self): """调度器会在执行 send() 前调用该方法""" diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index eab41ad84..6374f8623 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -1,7 +1,16 @@ +import asyncio import telegramify_markdown from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType -from astrbot.api.message_components import Plain, Image, Reply, At, File, Record +from astrbot.api.message_components import ( + Plain, + Image, + Reply, + At, + File, + Record, + BaseMessageComponent, +) from telegram.ext import ExtBot from astrbot.core.utils.io import download_file from astrbot import logger @@ -82,3 +91,87 @@ class TelegramPlatformEvent(AstrMessageEvent): else: await self.send_with_client(self.client, message, self.get_sender_id()) await super().send(message) + + async def send_streaming(self, generator): + message_thread_id = None + + if self.get_message_type() == MessageType.GROUP_MESSAGE: + user_name = self.message_obj.group_id + else: + user_name = self.get_sender_id() + + if "#" in user_name: + # it's a supergroup chat with message_thread_id + user_name, message_thread_id = user_name.split("#") + payload = { + "chat_id": user_name, + } + if message_thread_id: + payload["reply_to_message_id"] = message_thread_id + + delta = "" + message_id = None + last_edit_time = 0 # 上次编辑消息的时间 + throttle_interval = 0.6 # 编辑消息的间隔时间 (秒) + + async for chain in generator: + logger.debug(f"streaming: {chain}") + if isinstance(chain, list): + # 处理消息链中的每个组件 + for i in chain: + if isinstance(i, Plain): + delta += i.text + elif isinstance(i, Image): + image_path = await i.convert_to_file_path() + await self.client.send_photo(photo=image_path, **payload) + continue + elif isinstance(i, File): + if i.file.startswith("https://"): + path = "data/temp/" + i.name + await download_file(i.file, path) + i.file = path + + await self.client.send_document( + document=i.file, filename=i.name, **payload + ) + continue + elif isinstance(i, Record): + path = await i.convert_to_file_path() + await self.client.send_voice(voice=path, **payload) + continue + + # Plain + if not message_id: + try: + msg = await self.client.send_message(text=delta, **payload) + except Exception as e: + logger.warning(f"发送消息失败(streaming): {e}") + message_id = msg.message_id + last_edit_time = ( + asyncio.get_event_loop().time() + ) # 记录初始消息发送时间 + else: + current_time = asyncio.get_event_loop().time() + time_since_last_edit = current_time - last_edit_time + + # 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间 + if time_since_last_edit >= throttle_interval: + # 编辑消息 + try: + await self.client.edit_message_text( + text=delta, + chat_id=payload["chat_id"], + message_id=message_id, + ) + except Exception as e: + logger.warning(f"编辑消息失败(streaming): {e}") + last_edit_time = ( + asyncio.get_event_loop().time() + ) # 更新上次编辑的时间 + + if delta: + await self.client.edit_message_text( + text=delta, chat_id=payload["chat_id"], message_id=message_id + ) + + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index ef82dbfed..9aac55c23 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -16,16 +16,26 @@ class WebChatMessageEvent(AstrMessageEvent): os.makedirs(imgs_dir, exist_ok=True) @staticmethod - async def _send(message: MessageChain, session_id: str): + async def _send(message: MessageChain, session_id: str, streaming: bool = False): if not message: - web_chat_back_queue.put_nowait(None) + await web_chat_back_queue.put( + {"type": "end", "data": "", "streaming": False} + ) return cid = session_id.split("!")[-1] - + data = "" for comp in message.chain: if isinstance(comp, Plain): - web_chat_back_queue.put_nowait((comp.text, cid)) + data = comp.text + await web_chat_back_queue.put( + { + "type": "plain", + "cid": cid, + "data": data, + "streaming": streaming, + } + ) elif isinstance(comp, Image): # save image to local filename = str(uuid.uuid4()) + ".jpg" @@ -46,7 +56,15 @@ class WebChatMessageEvent(AstrMessageEvent): with open(path, "wb") as f: with open(comp.file, "rb") as f2: f.write(f2.read()) - web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid)) + data = f"[IMAGE]{filename}" + await web_chat_back_queue.put( + { + "type": "image", + "cid": cid, + "data": data, + "streaming": streaming, + } + ) elif isinstance(comp, Record): # save record to local filename = str(uuid.uuid4()) + ".wav" @@ -62,11 +80,45 @@ class WebChatMessageEvent(AstrMessageEvent): with open(path, "wb") as f: with open(comp.file, "rb") as f2: f.write(f2.read()) - web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid)) + data = f"[RECORD]{filename}" + await web_chat_back_queue.put( + { + "type": "record", + "cid": cid, + "data": data, + "streaming": streaming, + } + ) else: logger.debug(f"webchat 忽略: {comp.type}") - web_chat_back_queue.put_nowait(None) + + return data async def send(self, message: MessageChain): await WebChatMessageEvent._send(message, session_id=self.session_id) + await web_chat_back_queue.put( + { + "type": "end", + "data": "", + "streaming": False, + "cid": self.session_id.split("!")[-1], + } + ) await super().send(message) + + async def send_streaming(self, generator): + final_data = "" + async for chain in generator: + final_data += await WebChatMessageEvent._send( + MessageChain(chain=chain), session_id=self.session_id, streaming=True + ) + + await web_chat_back_queue.put( + { + "type": "end", + "data": final_data, + "streaming": True, + "cid": self.session_id.split("!")[-1], + } + ) + await super().send_streaming(generator) diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entites.py index a8ffcdf64..99824fd0e 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entites.py @@ -204,6 +204,9 @@ class LLMResponse: _completion_text: str = "" + is_chunk: bool = False + """是否是流式输出的单个 Chunk""" + def __init__( self, role: str, @@ -214,6 +217,7 @@ class LLMResponse: tools_call_ids: List[str] = [], raw_completion: ChatCompletion = None, _new_record: Dict[str, any] = None, + is_chunk: bool = False, ): """初始化 LLMResponse @@ -233,6 +237,7 @@ class LLMResponse: self.tools_call_ids = tools_call_ids self.raw_completion = raw_completion self._new_record = _new_record + self.is_chunk = is_chunk @property def completion_text(self): diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 8dcff9a52..21185d6e3 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,7 +1,7 @@ import abc from typing import List from astrbot.core.db import BaseDatabase -from typing import TypedDict +from typing import TypedDict, AsyncGenerator from astrbot.core.provider.func_tool_manager import FuncCall from astrbot.core.provider.entites import LLMResponse, ToolCallsResult from dataclasses import dataclass @@ -108,7 +108,35 @@ class Provider(AbstractProvider): - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ - raise NotImplementedError() + ... + + async def text_chat_stream( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = None, + func_tool: FuncCall = None, + contexts: List = None, + system_prompt: str = None, + tool_calls_result: ToolCallsResult = None, + **kwargs, + ) -> AsyncGenerator[LLMResponse, None]: + """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 + + Args: + prompt: 提示词 + session_id: 会话 ID(此属性已经被废弃) + image_urls: 图片 URL 列表 + tools: Function-calling 工具 + contexts: 上下文 + tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling + kwargs: 其他参数 + + Notes: + - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 + - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 + """ + ... async def pop_record(self, context: List): """ diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index fd19c40ca..7e26018b1 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -160,6 +160,19 @@ class ProviderAnthropic(ProviderOpenAIOfficial): return llm_response + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + raise NotImplementedError("This method is not implemented yet.") + async def assemble_context(self, text: str, image_urls: List[str] = None): """组装上下文,支持文本和图片""" if not image_urls: diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 14aefceef..cf1559f5e 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -141,12 +141,29 @@ class ProviderDashscope(ProviderOpenAIOfficial): if self.output_reference and response.output.get("doc_references", None): ref_str = "" for ref in response.output.get("doc_references", []): - ref_title = ref.get("title", "") if ref.get("title") else ref.get("doc_name", "") + ref_title = ( + ref.get("title", "") + if ref.get("title") + else ref.get("doc_name", "") + ) ref_str += f"{ref['index_id']}. {ref_title}\n" output_text += f"\n\n回答来源:\n{ref_str}" return LLMResponse(role="assistant", completion_text=output_text) + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + raise NotImplementedError("This method is not implemented yet.") + async def forget(self, session_id): return True diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 8b5890c28..f0c7225e3 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -189,6 +189,19 @@ class ProviderDify(Provider): return LLMResponse(role="assistant", result_chain=chain) + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + raise NotImplementedError("This method is not implemented yet.") + async def parse_dify_result(self, chunk: dict | str) -> MessageChain: if isinstance(chunk, str): # Chat diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index c316544ff..19da42f1a 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -338,6 +338,19 @@ class ProviderGoogleGenAI(Provider): return llm_response + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + raise NotImplementedError("This method is not implemented yet.") + def get_current_key(self) -> str: return self.client.api_key diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index bfd9e03a5..c43d03580 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -95,6 +95,19 @@ class LLMTunerModelLoader(Provider): return llm_response + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + raise NotImplementedError("This method is not implemented yet.") + async def get_current_key(self): return "none" diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index f8d392404..c29227926 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -4,17 +4,20 @@ import os import inspect import random import asyncio +import astrbot.core.message.components as Comp from openai import AsyncOpenAI, AsyncAzureOpenAI from openai.types.chat.chat_completion import ChatCompletion +# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai._exceptions import NotFoundError, UnprocessableEntityError +from openai.lib.streaming.chat._completions import ChatCompletionStreamState from astrbot.core.utils.io import download_image_by_url from astrbot.core.db import BaseDatabase from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall -from typing import List +from typing import List, AsyncGenerator from ..register import register_provider_adapter from astrbot.core.provider.entites import LLMResponse @@ -107,12 +110,63 @@ class ProviderOpenAIOfficial(Provider): logger.debug(f"completion: {completion}") + llm_response = await self.parse_openai_completion(completion, tools) + + return llm_response + + async def _query_stream( + self, payloads: dict, tools: FuncCall + ) -> AsyncGenerator[LLMResponse, None]: + """流式查询API,逐步返回结果""" + if tools: + tool_list = tools.get_func_desc_openai_style() + if tool_list: + payloads["tools"] = tool_list + + # 不在默认参数中的参数放在 extra_body 中 + extra_body = {} + to_del = [] + for key in payloads.keys(): + if key not in self.default_params: + extra_body[key] = payloads[key] + to_del.append(key) + for key in to_del: + del payloads[key] + + stream = await self.client.chat.completions.create( + **payloads, stream=True, extra_body=extra_body + ) + + llm_response = LLMResponse("assistant", is_chunk=True) + + state = ChatCompletionStreamState() + + async for chunk in stream: + state.handle_chunk(chunk) + if len(chunk.choices) == 0: + continue + delta = chunk.choices[0].delta + # 处理文本内容 + if delta.content: + completion_text = delta.content + llm_response.result_chain = [Comp.Plain(completion_text)] + yield llm_response + + final_completion = state.get_final_completion() + llm_response = await self.parse_openai_completion(final_completion, tools) + + yield llm_response + + async def parse_openai_completion( + self, completion: ChatCompletion, tools: FuncCall + ): + """解析 OpenAI 的 ChatCompletion 响应""" + llm_response = LLMResponse("assistant") + if len(completion.choices) == 0: raise Exception("API 返回的 completion 为空。") choice = completion.choices[0] - llm_response = LLMResponse("assistant") - if choice.message.content: # text completion completion_text = str(choice.message.content).strip() @@ -148,7 +202,7 @@ class ProviderOpenAIOfficial(Provider): return llm_response - async def text_chat( + async def _prepare_chat_payload( self, prompt: str, session_id: str = None, @@ -158,7 +212,8 @@ class ProviderOpenAIOfficial(Provider): system_prompt=None, tool_calls_result=None, **kwargs, - ) -> LLMResponse: + ) -> tuple: + """准备聊天所需的有效载荷和上下文""" new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] if system_prompt: @@ -177,8 +232,117 @@ class ProviderOpenAIOfficial(Provider): payloads = {"messages": context_query, **model_config} - llm_response = None + return payloads, context_query, func_tool + async def _handle_api_error( + self, + e: Exception, + payloads: dict, + context_query: list, + func_tool: FuncCall, + chosen_key: str, + available_api_keys: List[str], + retry_cnt: int, + max_retries: int, + ) -> tuple: + """处理API错误并尝试恢复""" + if "429" in str(e): + logger.warning( + f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}" + ) + # 最后一次不等待 + if retry_cnt < max_retries - 1: + await asyncio.sleep(1) + available_api_keys.remove(chosen_key) + if len(available_api_keys) > 0: + chosen_key = random.choice(available_api_keys) + return ( + False, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) + else: + raise e + elif "maximum context length" in str(e): + logger.warning( + f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" + ) + await self.pop_record(context_query) + payloads["messages"] = context_query + return ( + False, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) + elif "The model is not a VLM" in str(e): # siliconcloud + # 尝试删除所有 image + new_contexts = await self._remove_image_from_context(context_query) + payloads["messages"] = new_contexts + context_query = new_contexts + return ( + False, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) + elif ( + "Function calling is not enabled" in str(e) + or ("tool" in str(e).lower() and "support" in str(e).lower()) + or ("function" in str(e).lower() and "support" in str(e).lower()) + ): + # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配 + logger.info( + f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。" + ) + if "tools" in payloads: + del payloads["tools"] + return False, chosen_key, available_api_keys, payloads, context_query, None + else: + logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}") + + if "tool" in str(e).lower() and "support" in str(e).lower(): + logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all") + + if "Connection error." in str(e): + proxy = os.environ.get("http_proxy", None) + if proxy: + logger.error( + f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}" + ) + + raise e + + async def text_chat( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts=[], + system_prompt=None, + tool_calls_result=None, + **kwargs, + ) -> LLMResponse: + payloads, context_query, func_tool = await self._prepare_chat_payload( + prompt, + session_id, + image_urls, + func_tool, + contexts, + system_prompt, + tool_calls_result, + **kwargs, + ) + + llm_response = None max_retries = 10 available_api_keys = self.api_keys.copy() chosen_key = random.choice(available_api_keys) @@ -197,64 +361,97 @@ class ProviderOpenAIOfficial(Provider): payloads["messages"] = new_contexts context_query = new_contexts except Exception as e: - if "429" in str(e): - logger.warning( - f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}" - ) - # 最后一次不等待 - if retry_cnt < max_retries - 1: - await asyncio.sleep(1) - available_api_keys.remove(chosen_key) - if len(available_api_keys) > 0: - chosen_key = random.choice(available_api_keys) - continue - else: - raise e - elif "maximum context length" in str(e): - logger.warning( - f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" - ) - await self.pop_record(context_query) - elif "The model is not a VLM" in str(e): # siliconcloud - # 尝试删除所有 image - new_contexts = await self._remove_image_from_context(context_query) - payloads["messages"] = new_contexts - elif ( - "Function calling is not enabled" in str(e) - or ("tool" in str(e).lower() and "support" in str(e).lower()) - or ("function" in str(e).lower() and "support" in str(e).lower()) - ): - # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配 - logger.info( - f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。" - ) - if "tools" in payloads: - del payloads["tools"] - func_tool = None - else: - logger.error( - f"发生了错误。Provider 配置如下: {self.provider_config}" - ) - - if "tool" in str(e).lower() and "support" in str(e).lower(): - logger.error( - "疑似该模型不支持函数调用工具调用。请输入 /tool off_all" - ) - - if "Connection error." in str(e): - proxy = os.environ.get("http_proxy", None) - if proxy: - logger.error( - f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}" - ) - - raise e + ( + success, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) = await self._handle_api_error( + e, + payloads, + context_query, + func_tool, + chosen_key, + available_api_keys, + retry_cnt, + max_retries, + ) + if success: + break if retry_cnt == max_retries - 1: logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。") raise e return llm_response + async def text_chat_stream( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts=[], + system_prompt=None, + tool_calls_result=None, + **kwargs, + ) -> AsyncGenerator[LLMResponse, None]: + """流式对话,与服务商交互并逐步返回结果""" + payloads, context_query, func_tool = await self._prepare_chat_payload( + prompt, + session_id, + image_urls, + func_tool, + contexts, + system_prompt, + tool_calls_result, + **kwargs, + ) + + max_retries = 10 + available_api_keys = self.api_keys.copy() + chosen_key = random.choice(available_api_keys) + + e = None + retry_cnt = 0 + for retry_cnt in range(max_retries): + try: + self.client.api_key = chosen_key + async for response in self._query_stream(payloads, func_tool): + yield response + break + except UnprocessableEntityError as e: + logger.warning(f"不可处理的实体错误:{e},尝试删除图片。") + # 尝试删除所有 image + new_contexts = await self._remove_image_from_context(context_query) + payloads["messages"] = new_contexts + context_query = new_contexts + except Exception as e: + ( + success, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) = await self._handle_api_error( + e, + payloads, + context_query, + func_tool, + chosen_key, + available_api_keys, + retry_cnt, + max_retries, + ) + if success: + break + + if retry_cnt == max_retries - 1: + logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。") + raise e + async def _remove_image_from_context(self, contexts: List): """ 从上下文中删除所有带有 image 的记录 diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index db1461f59..0f4e82ea4 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -161,42 +161,53 @@ class ChatRoute(Route): username = g.get("username", "guest") if username in self.curr_chat_sse: - return "[ERROR]\n" + return Response().error("Already connected").__dict__ self.curr_chat_sse[username] = None + heartbeat = json.dumps({"type": "heartbeat", "data": "ping"}) + async def stream(): try: - yield "[HB]\n" + yield f"data: {heartbeat}\n\n" # 心跳包 while True: try: result = await asyncio.wait_for( web_chat_back_queue.get(), timeout=10 ) # 设置超时时间为5秒 except asyncio.TimeoutError: - yield "[HB]\n" # 心跳包 + yield f"data: {heartbeat}\n\n" # 心跳包 continue if not result: continue - result_text, cid = result + + result_text = result["data"] + type = result.get("type") + cid = result.get("cid") + streaming = result.get("streaming", False) if cid != self.curr_user_cid.get(username): # 丢弃 continue - yield result_text + "\n" + yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" + await asyncio.sleep(0.15) - conversation = self.db.get_conversation_by_user_id(username, cid) - try: - history = json.loads(conversation.history) - except BaseException as e: - print(e) - history = [] - history.append({"type": "bot", "message": result_text}) - self.db.update_conversation( - username, cid, history=json.dumps(history) - ) + if streaming and type != "end": + continue - await asyncio.sleep(0.5) + if result_text: + conversation = self.db.get_conversation_by_user_id( + username, cid + ) + try: + history = json.loads(conversation.history) + except BaseException as e: + print(e) + history = [] + history.append({"type": "bot", "message": result_text}) + self.db.update_conversation( + username, cid, history=json.dumps(history) + ) except BaseException as _: logger.debug(f"用户 {username} 断开聊天长连接。") self.curr_chat_sse.pop(username) diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index 81c02374b..196a8040d 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -1,6 +1,7 @@