diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py index 557273acd..9b1ade50a 100644 --- a/astrbot/api/provider/__init__.py +++ b/astrbot/api/provider/__init__.py @@ -1,5 +1,5 @@ from astrbot.core.provider import Provider, STTProvider, Personality -from astrbot.core.provider.entites import ( +from astrbot.core.provider.entities import ( ProviderRequest, ProviderType, ProviderMetaData, diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 00e27c15c..13be0b498 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -50,6 +50,7 @@ DEFAULT_CONFIG = { "default_personality": "default", "prompt_prefix": "", "max_context_length": -1, + "streaming_response": False, }, "provider_stt_settings": { "enable": False, @@ -993,6 +994,11 @@ CONFIG_METADATA_2 = { "type": "int", "hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。", }, + "streaming_response": { + "description": "启用流式回复", + "type": "bool", + "hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台", + }, }, }, "persona": { diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 50e50ceb5..28c92fa89 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -1,6 +1,6 @@ import enum -from typing import List, Optional, Union +from typing import List, Optional, Union, AsyncGenerator from dataclasses import dataclass, field from astrbot.core.message.components import ( BaseMessageComponent, @@ -111,6 +111,30 @@ class MessageChain: """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) + def squash_plain(self): + """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。""" + if not self.chain: + return + + new_chain = [] + first_plain = None + plain_texts = [] + + for comp in self.chain: + if isinstance(comp, Plain): + if first_plain is None: + first_plain = comp + new_chain.append(comp) + plain_texts.append(comp.text) + else: + new_chain.append(comp) + + if first_plain is not None: + first_plain.text = "".join(plain_texts) + + self.chain = new_chain + return self + class EventResultType(enum.Enum): """用于描述事件处理的结果类型。 @@ -131,6 +155,10 @@ class ResultContentType(enum.Enum): """调用 LLM 产生的结果""" GENERAL_RESULT = enum.auto() """普通的消息结果""" + STREAMING_RESULT = enum.auto() + """调用 LLM 产生的流式结果""" + STREAMING_FINISH= enum.auto() + """流式输出完成""" @dataclass @@ -152,6 +180,9 @@ class MessageEventResult(MessageChain): default_factory=lambda: ResultContentType.GENERAL_RESULT ) + async_stream: Optional[AsyncGenerator] = None + """异步流""" + def stop_event(self) -> "MessageEventResult": """终止事件传播。""" self.result_type = EventResultType.STOP @@ -168,6 +199,11 @@ class MessageEventResult(MessageChain): """ return self.result_type == EventResultType.STOP + def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult": + """设置异步流。""" + self.async_stream = stream + return self + def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult": """设置事件处理的结果类型。 diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 674a7fd79..c6a87b37c 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -12,11 +12,12 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import ( MessageEventResult, ResultContentType, + MessageChain, ) from astrbot.core.message.components import Image from astrbot.core import logger from astrbot.core.utils.metrics import Metric -from astrbot.core.provider.entites import ( +from astrbot.core.provider.entities import ( ProviderRequest, LLMResponse, ToolCallMessageSegment, @@ -37,6 +38,9 @@ class LLMRequestSubStage(Stage): self.max_context_length = ctx.astrbot_config["provider_settings"][ "max_context_length" ] # int + self.streaming_response = ctx.astrbot_config["provider_settings"][ + "streaming_response" + ] # bool for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): @@ -137,70 +141,127 @@ class LLMRequestSubStage(Stage): if not req.session_id: req.session_id = event.unified_msg_origin - try: - need_loop = True - while need_loop: - need_loop = False - logger.debug(f"提供商请求 Payload: {req}") - llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + async def requesting(req: ProviderRequest): + try: + need_loop = True + while need_loop: + need_loop = False + logger.debug(f"提供商请求 Payload: {req}") - # 执行 LLM 响应后的事件钩子。 - handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnLLMResponseEvent - ) - for handler in handlers: - try: - logger.debug( - f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" - ) - await handler.handler(event, llm_response) - except BaseException: - logger.error(traceback.format_exc()) + final_llm_response = None - if event.is_stopped(): - logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" - ) - return - - async for result in self._handle_llm_response(event, req, llm_response): - if isinstance(result, ProviderRequest): - # 有函数工具调用并且返回了结果,我们需要再次请求 LLM - req = result - need_loop = True + if self.streaming_response: + stream = provider.text_chat_stream(**req.__dict__) + async for llm_response in stream: + if llm_response.is_chunk: + if llm_response.result_chain: + yield llm_response.result_chain # MessageChain + else: + yield MessageChain().message( + llm_response.completion_text + ) + else: + final_llm_response = llm_response else: - yield + final_llm_response = await provider.text_chat( + **req.__dict__ + ) # 请求 LLM - asyncio.create_task( - Metric.upload( - llm_tick=1, - model_name=provider.get_model(), - provider_type=provider.meta().type, + if not final_llm_response: + raise Exception("LLM response is None.") + + # 执行 LLM 响应后的事件钩子。 + handlers = star_handlers_registry.get_handlers_by_event_type( + EventType.OnLLMResponseEvent + ) + for handler in handlers: + try: + logger.debug( + f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) + await handler.handler(event, final_llm_response) + except BaseException: + logger.error(traceback.format_exc()) + + if event.is_stopped(): + logger.info( + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + ) + return + + if self.streaming_response: + # 流式输出的处理 + async for result in self._handle_llm_stream_response( + event, req, final_llm_response + ): + if isinstance(result, ProviderRequest): + # 有函数工具调用并且返回了结果,我们需要再次请求 LLM + req = result + need_loop = True + else: + yield + else: + # 非流式输出的处理 + async for result in self._handle_llm_response( + event, req, final_llm_response + ): + if isinstance(result, ProviderRequest): + # 有函数工具调用并且返回了结果,我们需要再次请求 LLM + req = result + need_loop = True + else: + yield + + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=provider.get_model(), + provider_type=provider.meta().type, + ) ) - ) - # 保存到历史记录 - await self._save_to_history(event, req, llm_response) + # 保存到历史记录 + await self._save_to_history(event, req, final_llm_response) - except BaseException as e: - logger.error(traceback.format_exc()) + except BaseException as e: + logger.error(traceback.format_exc()) + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" + ) + ) + + if not self.streaming_response: + event.set_extra("tool_call_result", None) + async for _ in requesting(req): + yield + else: event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" - ) + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream(requesting(req)) ) - return + # 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理 + yield + + if event.get_extra("tool_call_result"): + event.set_result(event.get_extra("tool_call_result")) + event.set_extra("tool_call_result", None) + yield async def _handle_llm_response( - self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse - ) -> AsyncGenerator[None, None]: - """处理 LLM 响应。 + self, + event: AstrMessageEvent, + req: ProviderRequest, + llm_response: LLMResponse, + ) -> AsyncGenerator[Union[None, ProviderRequest], None]: + """处理非流式 LLM 响应。 Returns: - bool: 是否需要继续调用 LLM + AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM Yields: - Iterator[bool]: 将 event 交付给下一个 stage + Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM """ if llm_response.role == "assistant": # text completion @@ -223,83 +284,138 @@ class LLMRequestSubStage(Stage): ) ) elif llm_response.role == "tool": - # function calling - tool_call_result: list[ToolCallMessageSegment] = [] - logger.info( - f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" + # 处理函数工具调用 + async for result in self._handle_function_tools(event, req, llm_response): + yield result + + async def _handle_llm_stream_response( + self, + event: AstrMessageEvent, + req: ProviderRequest, + llm_response: LLMResponse, + ) -> AsyncGenerator[Union[None, ProviderRequest], None]: + """处理流式 LLM 响应。 + + 专门用于处理流式输出完成后的响应,与非流式响应处理分离。 + + Returns: + AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM + + Yields: + Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM + """ + if llm_response.role == "assistant": + # text completion + if llm_response.result_chain: + event.set_result( + MessageEventResult( + chain=llm_response.result_chain.chain + ).set_result_content_type(ResultContentType.STREAMING_FINISH) + ) + else: + event.set_result( + MessageEventResult() + .message(llm_response.completion_text) + .set_result_content_type(ResultContentType.STREAMING_FINISH) + ) + elif llm_response.role == "err": + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" + ) ) - for func_tool_name, func_tool_args, func_tool_id in zip( - llm_response.tools_call_name, - llm_response.tools_call_args, - llm_response.tools_call_ids, - ): - try: - func_tool = req.func_tool.get_func(func_tool_name) - if func_tool.origin == "mcp": - logger.info( - f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" + elif llm_response.role == "tool": + # 处理函数工具调用 + async for result in self._handle_function_tools(event, req, llm_response): + yield result + + async def _handle_function_tools( + self, + event: AstrMessageEvent, + req: ProviderRequest, + llm_response: LLMResponse, + ) -> AsyncGenerator[Union[None, ProviderRequest], None]: + """处理函数工具调用。 + + Returns: + AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM + """ + # function calling + tool_call_result: list[ToolCallMessageSegment] = [] + logger.info( + f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" + ) + for func_tool_name, func_tool_args, func_tool_id in zip( + llm_response.tools_call_name, + llm_response.tools_call_args, + llm_response.tools_call_ids, + ): + try: + func_tool = req.func_tool.get_func(func_tool_name) + if func_tool.origin == "mcp": + logger.info( + f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" + ) + client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name] + res = await client.session.call_tool(func_tool.name, func_tool_args) + if res: + # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=res.content[0].text, + ) ) - client = req.func_tool.mcp_client_dict[ - func_tool.mcp_server_name - ] - res = await client.session.call_tool( - func_tool.name, func_tool_args - ) - if res: - # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 + else: + logger.info( + f"调用工具函数:{func_tool_name},参数:{func_tool_args}" + ) + # 尝试调用工具函数 + wrapper = self._call_handler( + self.ctx, event, func_tool.handler, **func_tool_args + ) + async for resp in wrapper: + if resp is not None: # 有 return 返回 tool_call_result.append( ToolCallMessageSegment( role="tool", tool_call_id=func_tool_id, - content=res.content[0].text, + content=resp, ) ) - else: - logger.info( - f"调用工具函数:{func_tool_name},参数:{func_tool_args}" - ) - # 尝试调用工具函数 - wrapper = self._call_handler( - self.ctx, event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: # 有 return 返回 - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=resp, - ) - ) - else: - yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 - except BaseException as e: - logger.warning(traceback.format_exc()) - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=f"error: {str(e)}", - ) + else: + res = event.get_result() + if res and res.chain: + event.set_extra("tool_call_result", res) + yield # 有生成器返回 + event.clear_result() # 清除上一个 handler 的结果 + except BaseException as e: + logger.warning(traceback.format_exc()) + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=f"error: {str(e)}", ) - if tool_call_result: - # 函数调用结果 - req.func_tool = None # 暂时不支持递归工具调用 - assistant_msg_seg = AssistantMessageSegment( - role="assistant", tool_calls=llm_response.to_openai_tool_calls() ) - # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。 - req.tool_calls_result = ToolCallsResult( - tool_calls_info=assistant_msg_seg, - tool_calls_result=tool_call_result, + if tool_call_result: + # 函数调用结果 + req.func_tool = None # 暂时不支持递归工具调用 + assistant_msg_seg = AssistantMessageSegment( + role="assistant", tool_calls=llm_response.to_openai_tool_calls() + ) + # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。 + req.tool_calls_result = ToolCallsResult( + tool_calls_info=assistant_msg_seg, + tool_calls_result=tool_call_result, + ) + yield req # 再次执行 LLM 请求 + else: + if llm_response.completion_text: + event.set_result( + MessageEventResult().message(llm_response.completion_text) ) - yield req # 再次执行 LLM 请求 - else: - if llm_response.completion_text: - event.set_result( - MessageEventResult().message(llm_response.completion_text) - ) async def _save_to_history( self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 4c52a4a3e..f653a9fb9 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -5,7 +5,7 @@ from .method.llm_request import LLMRequestSubStage from .method.star_request import StarRequestSubStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import StarHandlerMetadata -from astrbot.core.provider.entites import ProviderRequest +from astrbot.core.provider.entities import ProviderRequest from astrbot.core import logger diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 8fa48cfe6..0d5044054 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -7,7 +7,7 @@ from typing import Union, AsyncGenerator from ..stage import register_stage, Stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.message.message_event_result import MessageChain, ResultContentType from astrbot.core import logger from astrbot.core.message.message_event_result import BaseMessageComponent from astrbot.core.star.star_handler import star_handlers_registry, EventType @@ -138,8 +138,17 @@ class RespondStage(Stage): result = event.get_result() if result is None: return + if result.result_content_type == ResultContentType.STREAMING_FINISH: + return - if len(result.chain) > 0: + if result.result_content_type == ResultContentType.STREAMING_RESULT: + # 流式结果直接交付平台适配器处理 + logger.info(f"应用流式输出({event.get_platform_name()})") + await event._pre_send() + await event.send_streaming(result.async_stream) + await event._post_send() + return + elif len(result.chain) > 0: await event._pre_send() # 检查消息链是否为空 diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index d7bb9583c..a0be2423b 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -5,6 +5,7 @@ from typing import Union, AsyncGenerator from ..stage import Stage, register_stage, registered_stages from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import ResultContentType from astrbot.core.platform.message_type import MessageType from astrbot.core import logger from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node @@ -72,11 +73,17 @@ class ResultDecorateStage(Stage): if result is None or not result.chain: return + if result.result_content_type == ResultContentType.STREAMING_RESULT: + return + + is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH + # 回复时检查内容安全 if ( self.content_safe_check_reply and self.content_safe_check_stage and result.is_llm_result() + and not is_stream # 流式输出不检查内容安全 ): text = "" for comp in result.chain: @@ -96,6 +103,10 @@ class ResultDecorateStage(Stage): logger.debug( f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" ) + if is_stream: + logger.warning( + "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作" + ) await handler.handler(event) if event.get_result() is None or not event.get_result().chain: logger.debug( @@ -110,6 +121,11 @@ class ResultDecorateStage(Stage): ) return + # 流式输出不执行下面的逻辑 + if is_stream: + logger.info("流式输出已启用,跳过结果装饰阶段") + return + # 需要再获取一次。插件可能直接对 chain 进行了替换。 result = event.get_result() if result is None: diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 3e1b14ee6..8d3bc4c59 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,7 +1,7 @@ import abc import asyncio from dataclasses import dataclass -from typing import List, Union, Optional +from typing import List, Union, Optional, AsyncGenerator from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( @@ -16,7 +16,7 @@ from astrbot.core.message.components import ( ) from astrbot.core.message.message_event_result import MessageEventResult, MessageChain from astrbot.core.platform.message_type import MessageType -from astrbot.core.provider.entites import ProviderRequest +from astrbot.core.provider.entities import ProviderRequest from astrbot.core.utils.metrics import Metric from .astrbot_message import AstrBotMessage, Group from .platform_metadata import PlatformMetadata @@ -202,6 +202,15 @@ class AstrMessageEvent(abc.ABC): """ return self.role == "admin" + async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]): + """发送流式消息到消息平台,使用异步生成器。 + 目前仅支持: telegram,qq official 私聊。 + """ + asyncio.create_task( + Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) + ) + self._has_send_oper = True + async def _pre_send(self): """调度器会在执行 send() 前调用该方法""" diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 295014ab4..9bb8b938f 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -82,6 +82,19 @@ class AiocqhttpMessageEvent(AstrMessageEvent): await super().send(message) + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) + async def get_group(self, group_id=None, **kwargs): if isinstance(group_id, str) and group_id.isdigit(): group_id = int(group_id) diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index aac1acfc5..d850a759f 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -60,3 +60,16 @@ class DingtalkMessageEvent(AstrMessageEvent): async def send(self, message: MessageChain): await self.send_with_client(self.client, message) await super().send(message) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index 78902a4c5..829a348c6 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -216,3 +216,16 @@ class GewechatPlatformEvent(AstrMessageEvent): group_owner=data.get("chatRoomOwner"), members=members, ) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index e170b76a0..544a7a5be 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -91,3 +91,16 @@ class LarkMessageEvent(AstrMessageEvent): logger.error(f"回复飞书消息失败({response.code}): {response.msg}") await super().send(message) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index d31006618..f74edd1ce 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -2,6 +2,7 @@ import botpy import botpy.message import botpy.types import botpy.types.message +import asyncio from astrbot.core.utils.io import file_to_base64, download_image_by_url from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata @@ -9,6 +10,7 @@ from astrbot.api.message_components import Plain, Image from botpy import Client from botpy.http import Route from astrbot.api import logger +from botpy.types import message class QQOfficialMessageEvent(AstrMessageEvent): @@ -30,8 +32,46 @@ class QQOfficialMessageEvent(AstrMessageEvent): else: self.send_buffer.chain.extend(message.chain) - async def _post_send(self): + async def send_streaming(self, generator): + """流式输出仅支持消息列表私聊""" + stream_payload = {"state": 1, "id": None, "index": 0, "reset": False} + last_edit_time = 0 # 上次编辑消息的时间 + throttle_interval = 1 # 编辑消息的间隔时间 (秒) + try: + async for chain in generator: + source = self.message_obj.raw_message + if not self.send_buffer: + self.send_buffer = chain + else: + self.send_buffer.chain.extend(chain.chain) + + if isinstance(source, botpy.message.C2CMessage): + # 真流式传输 + current_time = asyncio.get_event_loop().time() + time_since_last_edit = current_time - last_edit_time + + if time_since_last_edit >= throttle_interval: + ret = await self._post_send(stream=stream_payload) + stream_payload["index"] += 1 + stream_payload["id"] = ret["id"] + last_edit_time = asyncio.get_event_loop().time() + + if isinstance(source, botpy.message.C2CMessage): + # 结束流式对话,并且传输 buffer 中剩余的消息 + stream_payload["state"] = 10 + ret = await self._post_send(stream=stream_payload) + + except Exception as e: + logger.error(f"发送流式消息时出错: {e}", exc_info=True) + self.send_buffer = None + + return await super().send_streaming(generator) + + async def _post_send(self, stream: dict = None): """QQ 官方 API 仅支持回复一次""" + if not self.send_buffer: + return + source = self.message_obj.raw_message assert isinstance( source, @@ -65,7 +105,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) payload["media"] = media payload["msg_type"] = 7 - await self.bot.api.post_group_message( + ret = await self.bot.api.post_group_message( group_openid=source.group_openid, **payload ) case botpy.message.C2CMessage: @@ -75,22 +115,34 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) payload["media"] = media payload["msg_type"] = 7 - await self.bot.api.post_c2c_message( - openid=source.author.user_openid, **payload - ) + if stream: + ret = await self.post_c2c_message( + openid=source.author.user_openid, + **payload, + stream=stream, + ) + else: + ret = await self.post_c2c_message( + openid=source.author.user_openid, **payload + ) + logger.debug(f"Message sent to C2C: {ret}") case botpy.message.Message: if image_path: payload["file_image"] = image_path - await self.bot.api.post_message(channel_id=source.channel_id, **payload) + ret = await self.bot.api.post_message( + channel_id=source.channel_id, **payload + ) case botpy.message.DirectMessage: if image_path: payload["file_image"] = image_path - await self.bot.api.post_dms(guild_id=source.guild_id, **payload) + ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload) await super().send(self.send_buffer) self.send_buffer = None + return ret + async def upload_group_and_c2c_image( self, image_base64: str, file_type: int, **kwargs ) -> botpy.types.message.Media: @@ -112,6 +164,27 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) return await self.bot.api._http.request(route, json=payload) + async def post_c2c_message( + self, + openid: str, + msg_type: int = 0, + content: str = None, + embed: message.Embed = None, + ark: message.Ark = None, + message_reference: message.Reference = None, + media: message.Media = None, + msg_id: str = None, + msg_seq: str = 1, + event_id: str = None, + markdown: message.MarkdownPayload = None, + keyboard: message.Keyboard = None, + stream: dict = None, + ) -> message.Message: + payload = locals() + payload.pop("self", None) + route = Route("POST", "/v2/users/{openid}/messages", openid=openid) + return await self.bot.api._http.request(route, json=payload) + @staticmethod async def _parse_to_qqofficial(message: MessageChain): plain_text = "" diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 226a1276d..ede09e7fd 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -116,5 +116,8 @@ class QQOfficialWebhookPlatformAdapter(Platform): async def terminate(self): self.webhook_helper.shutdown_event.set() await self.client.close() - await self.webhook_helper.server.shutdown() + try: + await self.webhook_helper.server.shutdown() + except Exception as _: + pass logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭") diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index eab41ad84..bcc9189c2 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -1,7 +1,15 @@ +import asyncio import telegramify_markdown from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType -from astrbot.api.message_components import Plain, Image, Reply, At, File, Record +from astrbot.api.message_components import ( + Plain, + Image, + Reply, + At, + File, + Record, +) from telegram.ext import ExtBot from astrbot.core.utils.io import download_file from astrbot import logger @@ -82,3 +90,109 @@ class TelegramPlatformEvent(AstrMessageEvent): else: await self.send_with_client(self.client, message, self.get_sender_id()) await super().send(message) + + async def send_streaming(self, generator): + message_thread_id = None + + if self.get_message_type() == MessageType.GROUP_MESSAGE: + user_name = self.message_obj.group_id + else: + user_name = self.get_sender_id() + + if "#" in user_name: + # it's a supergroup chat with message_thread_id + user_name, message_thread_id = user_name.split("#") + payload = { + "chat_id": user_name, + } + if message_thread_id: + payload["reply_to_message_id"] = message_thread_id + + delta = "" + current_content = "" + message_id = None + last_edit_time = 0 # 上次编辑消息的时间 + throttle_interval = 0.6 # 编辑消息的间隔时间 (秒) + + async for chain in generator: + if isinstance(chain, MessageChain): + # 处理消息链中的每个组件 + for i in chain.chain: + if isinstance(i, Plain): + delta += i.text + elif isinstance(i, Image): + image_path = await i.convert_to_file_path() + await self.client.send_photo(photo=image_path, **payload) + continue + elif isinstance(i, File): + if i.file.startswith("https://"): + path = "data/temp/" + i.name + await download_file(i.file, path) + i.file = path + + await self.client.send_document( + document=i.file, filename=i.name, **payload + ) + continue + elif isinstance(i, Record): + path = await i.convert_to_file_path() + await self.client.send_voice(voice=path, **payload) + continue + else: + logger.warning(f"不支持的消息类型: {type(i)}") + continue + + # Plain + if not message_id: + try: + msg = await self.client.send_message(text=delta, **payload) + current_content = delta + except Exception as e: + logger.warning(f"发送消息失败(streaming): {e!s}") + message_id = msg.message_id + last_edit_time = ( + asyncio.get_event_loop().time() + ) # 记录初始消息发送时间 + else: + current_time = asyncio.get_event_loop().time() + time_since_last_edit = current_time - last_edit_time + + # 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间 + if time_since_last_edit >= throttle_interval: + # 编辑消息 + try: + await self.client.edit_message_text( + text=delta, + chat_id=payload["chat_id"], + message_id=message_id, + ) + current_content = delta + except Exception as e: + logger.warning(f"编辑消息失败(streaming): {e!s}") + last_edit_time = ( + asyncio.get_event_loop().time() + ) # 更新上次编辑的时间 + + try: + if delta and current_content != delta: + try: + markdown_text = telegramify_markdown.markdownify( + delta, max_line_length=None, normalize_whitespace=False + ) + await self.client.edit_message_text( + text=markdown_text, + chat_id=payload["chat_id"], + message_id=message_id, + parse_mode="MarkdownV2" + ) + except Exception as e: + logger.warning(f"Markdown转换失败,使用普通文本: {e!s}") + await self.client.edit_message_text( + text=delta, + chat_id=payload["chat_id"], + message_id=message_id + ) + except Exception as e: + logger.warning(f"编辑消息失败(streaming): {e!s}") + + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index ef82dbfed..ef5532920 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -16,16 +16,26 @@ class WebChatMessageEvent(AstrMessageEvent): os.makedirs(imgs_dir, exist_ok=True) @staticmethod - async def _send(message: MessageChain, session_id: str): + async def _send(message: MessageChain, session_id: str, streaming: bool = False): if not message: - web_chat_back_queue.put_nowait(None) + await web_chat_back_queue.put( + {"type": "end", "data": "", "streaming": False} + ) return cid = session_id.split("!")[-1] - + data = "" for comp in message.chain: if isinstance(comp, Plain): - web_chat_back_queue.put_nowait((comp.text, cid)) + data = comp.text + await web_chat_back_queue.put( + { + "type": "plain", + "cid": cid, + "data": data, + "streaming": streaming, + } + ) elif isinstance(comp, Image): # save image to local filename = str(uuid.uuid4()) + ".jpg" @@ -46,7 +56,15 @@ class WebChatMessageEvent(AstrMessageEvent): with open(path, "wb") as f: with open(comp.file, "rb") as f2: f.write(f2.read()) - web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid)) + data = f"[IMAGE]{filename}" + await web_chat_back_queue.put( + { + "type": "image", + "cid": cid, + "data": data, + "streaming": streaming, + } + ) elif isinstance(comp, Record): # save record to local filename = str(uuid.uuid4()) + ".wav" @@ -62,11 +80,45 @@ class WebChatMessageEvent(AstrMessageEvent): with open(path, "wb") as f: with open(comp.file, "rb") as f2: f.write(f2.read()) - web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid)) + data = f"[RECORD]{filename}" + await web_chat_back_queue.put( + { + "type": "record", + "cid": cid, + "data": data, + "streaming": streaming, + } + ) else: logger.debug(f"webchat 忽略: {comp.type}") - web_chat_back_queue.put_nowait(None) + + return data async def send(self, message: MessageChain): await WebChatMessageEvent._send(message, session_id=self.session_id) + await web_chat_back_queue.put( + { + "type": "end", + "data": "", + "streaming": False, + "cid": self.session_id.split("!")[-1], + } + ) await super().send(message) + + async def send_streaming(self, generator): + final_data = "" + async for chain in generator: + final_data += await WebChatMessageEvent._send( + chain, session_id=self.session_id, streaming=True + ) + + await web_chat_back_queue.put( + { + "type": "end", + "data": final_data, + "streaming": True, + "cid": self.session_id.split("!")[-1], + } + ) + await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 470b7b1f8..d8ee8b9a3 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -84,3 +84,16 @@ class WecomPlatformEvent(AstrMessageEvent): ) await super().send(message) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py index f30d1ac32..ed7135fe6 100644 --- a/astrbot/core/provider/__init__.py +++ b/astrbot/core/provider/__init__.py @@ -1,5 +1,5 @@ from .provider import Provider, Personality, STTProvider -from .entites import ProviderMetaData +from .entities import ProviderMetaData __all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"] diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entities.py similarity index 98% rename from astrbot/core/provider/entites.py rename to astrbot/core/provider/entities.py index a8ffcdf64..99824fd0e 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entities.py @@ -204,6 +204,9 @@ class LLMResponse: _completion_text: str = "" + is_chunk: bool = False + """是否是流式输出的单个 Chunk""" + def __init__( self, role: str, @@ -214,6 +217,7 @@ class LLMResponse: tools_call_ids: List[str] = [], raw_completion: ChatCompletion = None, _new_record: Dict[str, any] = None, + is_chunk: bool = False, ): """初始化 LLMResponse @@ -233,6 +237,7 @@ class LLMResponse: self.tools_call_ids = tools_call_ids self.raw_completion = raw_completion self._new_record = _new_record + self.is_chunk = is_chunk @property def completion_text(self): diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index a3fa65e86..9812a7e6a 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -2,7 +2,7 @@ import traceback import asyncio from astrbot.core.config.astrbot_config import AstrBotConfig from .provider import Provider, STTProvider, TTSProvider, Personality -from .entites import ProviderType +from .entities import ProviderType from typing import List from astrbot.core.db import BaseDatabase from .register import provider_cls_map, llm_tools diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 8dcff9a52..96547c5c2 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,9 +1,9 @@ import abc from typing import List from astrbot.core.db import BaseDatabase -from typing import TypedDict +from typing import TypedDict, AsyncGenerator from astrbot.core.provider.func_tool_manager import FuncCall -from astrbot.core.provider.entites import LLMResponse, ToolCallsResult +from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from dataclasses import dataclass @@ -108,7 +108,35 @@ class Provider(AbstractProvider): - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ - raise NotImplementedError() + ... + + async def text_chat_stream( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = None, + func_tool: FuncCall = None, + contexts: List = None, + system_prompt: str = None, + tool_calls_result: ToolCallsResult = None, + **kwargs, + ) -> AsyncGenerator[LLMResponse, None]: + """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 + + Args: + prompt: 提示词 + session_id: 会话 ID(此属性已经被废弃) + image_urls: 图片 URL 列表 + tools: Function-calling 工具 + contexts: 上下文 + tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling + kwargs: 其他参数 + + Notes: + - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 + - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 + """ + ... async def pop_record(self, context: List): """ diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 41a7a29d5..02d7934d1 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,5 +1,5 @@ from typing import List, Dict -from .entites import ProviderMetaData, ProviderType +from .entities import ProviderMetaData, ProviderType from astrbot.core import logger from .func_tool_manager import FuncCall diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index fd19c40ca..319515c52 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -10,7 +10,8 @@ from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse, ToolCallsResult +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from .openai_source import ProviderOpenAIOfficial @@ -72,7 +73,8 @@ class ProviderAnthropic(ProviderOpenAIOfficial): if content.type == "text": # text completion completion_text = str(content.text).strip() - llm_response.completion_text = completion_text + # llm_response.completion_text = completion_text + llm_response.result_chain = MessageChain().message(completion_text) # Anthropic每次只返回一个函数调用 if completion.stop_reason == "tool_use": @@ -145,7 +147,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial): messages=context_query, **model_config ) llm_response = LLMResponse("assistant") - llm_response.completion_text = response.content[0].text + llm_response.result_chain = MessageChain().message(response.content[0].text) llm_response.raw_completion = response return llm_response except Exception as e: @@ -160,6 +162,33 @@ class ProviderAnthropic(ProviderOpenAIOfficial): return llm_response + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response + async def assemble_context(self, text: str, image_urls: List[str] = None): """组装上下文,支持文本和图片""" if not image_urls: diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index a23814bfc..2c4930692 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -3,10 +3,11 @@ import asyncio import functools from typing import List from .. import Provider, Personality -from ..entites import LLMResponse +from ..entities import LLMResponse from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter +from astrbot.core.message.message_event_result import MessageChain from .openai_source import ProviderOpenAIOfficial from astrbot.core import logger, sp from dashscope import Application @@ -132,7 +133,9 @@ class ProviderDashscope(ProviderOpenAIOfficial): ) return LLMResponse( role="err", - completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}", + result_chain=MessageChain().message( + f"阿里云百炼请求失败: message={response.message} code={response.status_code}" + ), ) output_text = response.output.get("text", "") @@ -149,7 +152,37 @@ class ProviderDashscope(ProviderOpenAIOfficial): ref_str += f"{ref['index_id']}. {ref_title}\n" output_text += f"\n\n回答来源:\n{ref_str}" - return LLMResponse(role="assistant", completion_text=output_text) + llm_response = LLMResponse("assistant") + llm_response.result_chain = MessageChain().message(output_text) + + return llm_response + + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response async def forget(self, session_id): return True diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index 5ecf0d9be..7a038e4ba 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -3,7 +3,7 @@ import uuid import asyncio from dashscope.audio.tts_v2 import * from ..provider import TTSProvider -from ..entites import ProviderType +from ..entities import ProviderType from ..register import register_provider_adapter diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 8b5890c28..1adb0f884 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -2,7 +2,7 @@ import astrbot.core.message.components as Comp from typing import List from .. import Provider, Personality -from ..entites import LLMResponse +from ..entities import LLMResponse from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter @@ -189,6 +189,33 @@ class ProviderDify(Provider): return LLMResponse(role="assistant", result_chain=chain) + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response + async def parse_dify_result(self, chunk: dict | str) -> MessageChain: if isinstance(chunk, str): # Chat diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 0eadb2190..338abe263 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -4,7 +4,7 @@ import edge_tts import subprocess import asyncio from ..provider import TTSProvider -from ..entites import ProviderType +from ..entities import ProviderType from ..register import register_provider_adapter from astrbot.core import logger diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index 84b4b677e..07d0c32ab 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, conint from httpx import AsyncClient from typing import Annotated, Literal from ..provider import TTSProvider -from ..entites import ProviderType +from ..entities import ProviderType from ..register import register_provider_adapter diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 9f5f7c3c1..7def203db 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -12,7 +12,7 @@ from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from typing import List from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entities import LLMResponse class SimpleGoogleGenAIClient: @@ -78,6 +78,39 @@ class SimpleGoogleGenAIClient: logger.error(f"Gemini 返回了非 json 数据: {text}") raise Exception("Gemini 返回了非 json 数据: ") + async def stream_generate_content( + self, + contents: List[dict], + model: str = "gemini-1.5-flash", + system_instruction: str = "", + tools: dict = None, + modalities: List[str] = ["Text"], + safety_settings: List[dict] = [], + ): + payload = {} + if system_instruction: + payload["system_instruction"] = {"parts": {"text": system_instruction}} + if tools: + payload["tools"] = [tools] + payload["contents"] = contents + payload["generationConfig"] = { + "responseModalities": modalities, + "stream": True, + } + payload["safetySettings"] = [ + {"category": s["category"], "threshold": s["threshold"]} + for s in safety_settings + ] + logger.debug(f"payload: {payload}") + request_url = ( + f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}" + ) + async with self.client.post( + request_url, json=payload, timeout=self.timeout + ) as resp: + async for line in resp.content: + if line: + yield line @register_provider_adapter( "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器" @@ -338,6 +371,33 @@ class ProviderGoogleGenAI(Provider): return llm_response + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response + def get_current_key(self) -> str: return self.client.api_key diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py index b57932edf..581eef4dc 100644 --- a/astrbot/core/provider/sources/gsvi_tts_source.py +++ b/astrbot/core/provider/sources/gsvi_tts_source.py @@ -2,7 +2,7 @@ import uuid import aiohttp import urllib.parse from ..provider import TTSProvider -from ..entites import ProviderType +from ..entities import ProviderType from ..register import register_provider_adapter diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index bfd9e03a5..85994fd59 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -2,7 +2,7 @@ import os from llmtuner.chat import ChatModel from typing import List from .. import Provider -from ..entites import LLMResponse +from ..entities import LLMResponse from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter @@ -95,6 +95,33 @@ class LLMTunerModelLoader(Provider): return llm_response + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response + async def get_current_key(self): return "none" diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index f8d392404..9f3db42a4 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -4,19 +4,23 @@ import os import inspect import random import asyncio +import astrbot.core.message.components as Comp from openai import AsyncOpenAI, AsyncAzureOpenAI from openai.types.chat.chat_completion import ChatCompletion +# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai._exceptions import NotFoundError, UnprocessableEntityError +from openai.lib.streaming.chat._completions import ChatCompletionStreamState from astrbot.core.utils.io import download_image_by_url +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.db import BaseDatabase from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall -from typing import List +from typing import List, AsyncGenerator from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entities import LLMResponse @register_provider_adapter( @@ -107,16 +111,67 @@ class ProviderOpenAIOfficial(Provider): logger.debug(f"completion: {completion}") + llm_response = await self.parse_openai_completion(completion, tools) + + return llm_response + + async def _query_stream( + self, payloads: dict, tools: FuncCall + ) -> AsyncGenerator[LLMResponse, None]: + """流式查询API,逐步返回结果""" + if tools: + tool_list = tools.get_func_desc_openai_style() + if tool_list: + payloads["tools"] = tool_list + + # 不在默认参数中的参数放在 extra_body 中 + extra_body = {} + to_del = [] + for key in payloads.keys(): + if key not in self.default_params: + extra_body[key] = payloads[key] + to_del.append(key) + for key in to_del: + del payloads[key] + + stream = await self.client.chat.completions.create( + **payloads, stream=True, extra_body=extra_body + ) + + llm_response = LLMResponse("assistant", is_chunk=True) + + state = ChatCompletionStreamState() + + async for chunk in stream: + state.handle_chunk(chunk) + if len(chunk.choices) == 0: + continue + delta = chunk.choices[0].delta + # 处理文本内容 + if delta.content: + completion_text = delta.content + llm_response.result_chain = MessageChain(chain=[Comp.Plain(completion_text)]) + yield llm_response + + final_completion = state.get_final_completion() + llm_response = await self.parse_openai_completion(final_completion, tools) + + yield llm_response + + async def parse_openai_completion( + self, completion: ChatCompletion, tools: FuncCall + ): + """解析 OpenAI 的 ChatCompletion 响应""" + llm_response = LLMResponse("assistant") + if len(completion.choices) == 0: raise Exception("API 返回的 completion 为空。") choice = completion.choices[0] - llm_response = LLMResponse("assistant") - if choice.message.content: # text completion completion_text = str(choice.message.content).strip() - llm_response.completion_text = completion_text + llm_response.result_chain = MessageChain().message(completion_text) if choice.message.tool_calls: # tools call (function calling) @@ -148,7 +203,7 @@ class ProviderOpenAIOfficial(Provider): return llm_response - async def text_chat( + async def _prepare_chat_payload( self, prompt: str, session_id: str = None, @@ -158,7 +213,8 @@ class ProviderOpenAIOfficial(Provider): system_prompt=None, tool_calls_result=None, **kwargs, - ) -> LLMResponse: + ) -> tuple: + """准备聊天所需的有效载荷和上下文""" new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] if system_prompt: @@ -177,8 +233,117 @@ class ProviderOpenAIOfficial(Provider): payloads = {"messages": context_query, **model_config} - llm_response = None + return payloads, context_query, func_tool + async def _handle_api_error( + self, + e: Exception, + payloads: dict, + context_query: list, + func_tool: FuncCall, + chosen_key: str, + available_api_keys: List[str], + retry_cnt: int, + max_retries: int, + ) -> tuple: + """处理API错误并尝试恢复""" + if "429" in str(e): + logger.warning( + f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}" + ) + # 最后一次不等待 + if retry_cnt < max_retries - 1: + await asyncio.sleep(1) + available_api_keys.remove(chosen_key) + if len(available_api_keys) > 0: + chosen_key = random.choice(available_api_keys) + return ( + False, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) + else: + raise e + elif "maximum context length" in str(e): + logger.warning( + f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" + ) + await self.pop_record(context_query) + payloads["messages"] = context_query + return ( + False, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) + elif "The model is not a VLM" in str(e): # siliconcloud + # 尝试删除所有 image + new_contexts = await self._remove_image_from_context(context_query) + payloads["messages"] = new_contexts + context_query = new_contexts + return ( + False, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) + elif ( + "Function calling is not enabled" in str(e) + or ("tool" in str(e).lower() and "support" in str(e).lower()) + or ("function" in str(e).lower() and "support" in str(e).lower()) + ): + # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配 + logger.info( + f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。" + ) + if "tools" in payloads: + del payloads["tools"] + return False, chosen_key, available_api_keys, payloads, context_query, None + else: + logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}") + + if "tool" in str(e).lower() and "support" in str(e).lower(): + logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all") + + if "Connection error." in str(e): + proxy = os.environ.get("http_proxy", None) + if proxy: + logger.error( + f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}" + ) + + raise e + + async def text_chat( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts=[], + system_prompt=None, + tool_calls_result=None, + **kwargs, + ) -> LLMResponse: + payloads, context_query, func_tool = await self._prepare_chat_payload( + prompt, + session_id, + image_urls, + func_tool, + contexts, + system_prompt, + tool_calls_result, + **kwargs, + ) + + llm_response = None max_retries = 10 available_api_keys = self.api_keys.copy() chosen_key = random.choice(available_api_keys) @@ -197,64 +362,97 @@ class ProviderOpenAIOfficial(Provider): payloads["messages"] = new_contexts context_query = new_contexts except Exception as e: - if "429" in str(e): - logger.warning( - f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}" - ) - # 最后一次不等待 - if retry_cnt < max_retries - 1: - await asyncio.sleep(1) - available_api_keys.remove(chosen_key) - if len(available_api_keys) > 0: - chosen_key = random.choice(available_api_keys) - continue - else: - raise e - elif "maximum context length" in str(e): - logger.warning( - f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" - ) - await self.pop_record(context_query) - elif "The model is not a VLM" in str(e): # siliconcloud - # 尝试删除所有 image - new_contexts = await self._remove_image_from_context(context_query) - payloads["messages"] = new_contexts - elif ( - "Function calling is not enabled" in str(e) - or ("tool" in str(e).lower() and "support" in str(e).lower()) - or ("function" in str(e).lower() and "support" in str(e).lower()) - ): - # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配 - logger.info( - f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。" - ) - if "tools" in payloads: - del payloads["tools"] - func_tool = None - else: - logger.error( - f"发生了错误。Provider 配置如下: {self.provider_config}" - ) - - if "tool" in str(e).lower() and "support" in str(e).lower(): - logger.error( - "疑似该模型不支持函数调用工具调用。请输入 /tool off_all" - ) - - if "Connection error." in str(e): - proxy = os.environ.get("http_proxy", None) - if proxy: - logger.error( - f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}" - ) - - raise e + ( + success, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) = await self._handle_api_error( + e, + payloads, + context_query, + func_tool, + chosen_key, + available_api_keys, + retry_cnt, + max_retries, + ) + if success: + break if retry_cnt == max_retries - 1: logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。") raise e return llm_response + async def text_chat_stream( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts=[], + system_prompt=None, + tool_calls_result=None, + **kwargs, + ) -> AsyncGenerator[LLMResponse, None]: + """流式对话,与服务商交互并逐步返回结果""" + payloads, context_query, func_tool = await self._prepare_chat_payload( + prompt, + session_id, + image_urls, + func_tool, + contexts, + system_prompt, + tool_calls_result, + **kwargs, + ) + + max_retries = 10 + available_api_keys = self.api_keys.copy() + chosen_key = random.choice(available_api_keys) + + e = None + retry_cnt = 0 + for retry_cnt in range(max_retries): + try: + self.client.api_key = chosen_key + async for response in self._query_stream(payloads, func_tool): + yield response + break + except UnprocessableEntityError as e: + logger.warning(f"不可处理的实体错误:{e},尝试删除图片。") + # 尝试删除所有 image + new_contexts = await self._remove_image_from_context(context_query) + payloads["messages"] = new_contexts + context_query = new_contexts + except Exception as e: + ( + success, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) = await self._handle_api_error( + e, + payloads, + context_query, + func_tool, + chosen_key, + available_api_keys, + retry_cnt, + max_retries, + ) + if success: + break + + if retry_cnt == max_retries - 1: + logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。") + raise e + async def _remove_image_from_context(self, contexts: List): """ 从上下文中删除所有带有 image 的记录 diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index f120a6a59..20b00f949 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -1,7 +1,7 @@ import uuid from openai import AsyncOpenAI, NOT_GIVEN from ..provider import TTSProvider -from ..entites import ProviderType +from ..entities import ProviderType from ..register import register_provider_adapter diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index 84087ecf6..b6e3331f8 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -11,7 +11,7 @@ import re from funasr_onnx import SenseVoiceSmall from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess from ..provider import STTProvider -from ..entites import ProviderType +from ..entities import ProviderType from astrbot.core.utils.io import download_file from ..register import register_provider_adapter from astrbot.core import logger diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index e38a81de9..0009af906 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -2,7 +2,7 @@ import uuid import os from openai import AsyncOpenAI, NOT_GIVEN from ..provider import STTProvider -from ..entites import ProviderType +from ..entities import ProviderType from astrbot.core.utils.io import download_file from ..register import register_provider_adapter from astrbot.core import logger diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index cfd1267d0..96f0b6f6d 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -3,7 +3,7 @@ import os import asyncio import whisper from ..provider import STTProvider -from ..entites import ProviderType +from ..entities import ProviderType from astrbot.core.utils.io import download_file from ..register import register_provider_adapter from astrbot.core import logger diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py index 3e819d633..2f7490317 100644 --- a/astrbot/core/provider/sources/zhipu_source.py +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -3,7 +3,7 @@ from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from typing import List from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entities import LLMResponse from .openai_source import ProviderOpenAIOfficial diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index db1461f59..0f4e82ea4 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -161,42 +161,53 @@ class ChatRoute(Route): username = g.get("username", "guest") if username in self.curr_chat_sse: - return "[ERROR]\n" + return Response().error("Already connected").__dict__ self.curr_chat_sse[username] = None + heartbeat = json.dumps({"type": "heartbeat", "data": "ping"}) + async def stream(): try: - yield "[HB]\n" + yield f"data: {heartbeat}\n\n" # 心跳包 while True: try: result = await asyncio.wait_for( web_chat_back_queue.get(), timeout=10 ) # 设置超时时间为5秒 except asyncio.TimeoutError: - yield "[HB]\n" # 心跳包 + yield f"data: {heartbeat}\n\n" # 心跳包 continue if not result: continue - result_text, cid = result + + result_text = result["data"] + type = result.get("type") + cid = result.get("cid") + streaming = result.get("streaming", False) if cid != self.curr_user_cid.get(username): # 丢弃 continue - yield result_text + "\n" + yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" + await asyncio.sleep(0.15) - conversation = self.db.get_conversation_by_user_id(username, cid) - try: - history = json.loads(conversation.history) - except BaseException as e: - print(e) - history = [] - history.append({"type": "bot", "message": result_text}) - self.db.update_conversation( - username, cid, history=json.dumps(history) - ) + if streaming and type != "end": + continue - await asyncio.sleep(0.5) + if result_text: + conversation = self.db.get_conversation_by_user_id( + username, cid + ) + try: + history = json.loads(conversation.history) + except BaseException as e: + print(e) + history = [] + history.append({"type": "bot", "message": result_text}) + self.db.update_conversation( + username, cid, history=json.dumps(history) + ) except BaseException as _: logger.debug(f"用户 {username} 断开聊天长连接。") self.curr_chat_sse.pop(username) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 629a424f1..2747865e4 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -179,7 +179,7 @@ class ConfigRoute(Route): await self._save_astrbot_configs(post_configs) return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__ except Exception as e: - logger.error(e) + logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ async def post_plugin_configs(self): diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index 81c02374b..b00e96d66 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -1,6 +1,7 @@