diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index da603d465..6b226d48a 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,3 +8,7 @@ ### Modifications + +### Check +- [ ] 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary) +- [ ] 我新增/修复/优化的功能经过良好的测试 diff --git a/README.md b/README.md index 9791c037b..5444b7dd8 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,8 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_ Static Badge [![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e) ![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=10800&style=for-the-badge&color=3b618e) +![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=7200) + English日本語 | diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py index 557273acd..9b1ade50a 100644 --- a/astrbot/api/provider/__init__.py +++ b/astrbot/api/provider/__init__.py @@ -1,5 +1,5 @@ from astrbot.core.provider import Provider, STTProvider, Personality -from astrbot.core.provider.entites import ( +from astrbot.core.provider.entities import ( ProviderRequest, ProviderType, ProviderMetaData, diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 00e27c15c..4a6364979 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -50,6 +50,7 @@ DEFAULT_CONFIG = { "default_personality": "default", "prompt_prefix": "", "max_context_length": -1, + "streaming_response": False, }, "provider_stt_settings": { "enable": False, @@ -247,6 +248,9 @@ CONFIG_METADATA_2 = { "description": "平台设置", "type": "object", "items": { + "plugin_enable": { + "invisible": True, # 隐藏插件启用配置 + }, "unique_session": { "description": "会话隔离", "type": "bool", @@ -993,6 +997,11 @@ CONFIG_METADATA_2 = { "type": "int", "hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。", }, + "streaming_response": { + "description": "启用流式回复", + "type": "bool", + "hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台", + }, }, }, "persona": { diff --git a/astrbot/core/log.py b/astrbot/core/log.py index e1e2cde2e..6609b8246 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -141,11 +141,13 @@ class LogQueueHandler(logging.Handler): record (logging.LogRecord): 日志记录对象, 包含日志信息 """ log_entry = self.format(record) - self.log_broker.publish({ - "level": record.levelname, - "time": record.asctime, - "data": log_entry, - }) + self.log_broker.publish( + { + "level": record.levelname, + "time": record.asctime, + "data": log_entry, + } + ) class LogManager: diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 50e50ceb5..28c92fa89 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -1,6 +1,6 @@ import enum -from typing import List, Optional, Union +from typing import List, Optional, Union, AsyncGenerator from dataclasses import dataclass, field from astrbot.core.message.components import ( BaseMessageComponent, @@ -111,6 +111,30 @@ class MessageChain: """获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。""" return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)]) + def squash_plain(self): + """将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。""" + if not self.chain: + return + + new_chain = [] + first_plain = None + plain_texts = [] + + for comp in self.chain: + if isinstance(comp, Plain): + if first_plain is None: + first_plain = comp + new_chain.append(comp) + plain_texts.append(comp.text) + else: + new_chain.append(comp) + + if first_plain is not None: + first_plain.text = "".join(plain_texts) + + self.chain = new_chain + return self + class EventResultType(enum.Enum): """用于描述事件处理的结果类型。 @@ -131,6 +155,10 @@ class ResultContentType(enum.Enum): """调用 LLM 产生的结果""" GENERAL_RESULT = enum.auto() """普通的消息结果""" + STREAMING_RESULT = enum.auto() + """调用 LLM 产生的流式结果""" + STREAMING_FINISH= enum.auto() + """流式输出完成""" @dataclass @@ -152,6 +180,9 @@ class MessageEventResult(MessageChain): default_factory=lambda: ResultContentType.GENERAL_RESULT ) + async_stream: Optional[AsyncGenerator] = None + """异步流""" + def stop_event(self) -> "MessageEventResult": """终止事件传播。""" self.result_type = EventResultType.STOP @@ -168,6 +199,11 @@ class MessageEventResult(MessageChain): """ return self.result_type == EventResultType.STOP + def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult": + """设置异步流。""" + self.async_stream = stream + return self + def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult": """设置事件处理的结果类型。 diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index b97fc0f12..406fcc796 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -7,6 +7,7 @@ from .waking_check.stage import WakingCheckStage from .whitelist_check.stage import WhitelistCheckStage from .rate_limit_check.stage import RateLimitStage from .content_safety_check.stage import ContentSafetyCheckStage +from .platform_compatibility.stage import PlatformCompatibilityStage from .preprocess_stage.stage import PreProcessStage from .process_stage.stage import ProcessStage from .result_decorate.stage import ResultDecorateStage @@ -18,6 +19,7 @@ STAGES_ORDER = [ "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 "RateLimitStage", # 检查会话是否超过频率限制 "ContentSafetyCheckStage", # 检查内容安全 + "PlatformCompatibilityStage", # 检查所有处理器的平台兼容性 "PreProcessStage", # 预处理 "ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用 "ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等 @@ -29,6 +31,7 @@ __all__ = [ "WhitelistCheckStage", "RateLimitStage", "ContentSafetyCheckStage", + "PlatformCompatibilityStage", "PreProcessStage", "ProcessStage", "ResultDecorateStage", diff --git a/astrbot/core/pipeline/platform_compatibility/stage.py b/astrbot/core/pipeline/platform_compatibility/stage.py new file mode 100644 index 000000000..644912c26 --- /dev/null +++ b/astrbot/core/pipeline/platform_compatibility/stage.py @@ -0,0 +1,56 @@ +from ..stage import Stage, register_stage +from ..context import PipelineContext +from typing import Union, AsyncGenerator +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.star.star import star_map +from astrbot.core.star.star_handler import StarHandlerMetadata +from astrbot.core import logger + + +@register_stage +class PlatformCompatibilityStage(Stage): + """检查所有处理器的平台兼容性。 + + 这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。 + """ + + async def initialize(self, ctx: PipelineContext) -> None: + """初始化平台兼容性检查阶段 + + Args: + ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器 + """ + self.ctx = ctx + + async def process( + self, event: AstrMessageEvent + ) -> Union[None, AsyncGenerator[None, None]]: + # 获取当前平台ID + platform_id = event.get_platform_id() + + # 获取已激活的处理器 + activated_handlers = event.get_extra("activated_handlers") + if activated_handlers is None: + activated_handlers = [] + + # 标记不兼容的处理器 + for handler in activated_handlers: + if not isinstance(handler, StarHandlerMetadata): + continue + # 检查处理器是否在当前平台启用 + enabled = handler.is_enabled_for_platform(platform_id) + if not enabled: + if handler.handler_module_path in star_map: + plugin_name = star_map[handler.handler_module_path].name + logger.debug( + f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容" + ) + # 设置处理器为平台不兼容状态 + # TODO: 更好的标记方式 + handler.platform_compatible = False + else: + # 确保处理器为平台兼容状态 + handler.platform_compatible = True + + # 更新已激活的处理器列表 + event.set_extra("activated_handlers", activated_handlers) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 99b09b5a1..8a7062fd7 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -12,11 +12,12 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import ( MessageEventResult, ResultContentType, + MessageChain, ) from astrbot.core.message.components import Image from astrbot.core import logger from astrbot.core.utils.metrics import Metric -from astrbot.core.provider.entites import ( +from astrbot.core.provider.entities import ( ProviderRequest, LLMResponse, ToolCallMessageSegment, @@ -37,6 +38,9 @@ class LLMRequestSubStage(Stage): self.max_context_length = ctx.astrbot_config["provider_settings"][ "max_context_length" ] # int + self.streaming_response = ctx.astrbot_config["provider_settings"][ + "streaming_response" + ] # bool for bwp in self.bot_wake_prefixs: if self.provider_wake_prefix.startswith(bwp): @@ -146,8 +150,10 @@ class LLMRequestSubStage(Stage): # 执行请求 LLM 前事件钩子。 # 装饰 system_prompt 等功能 + # 获取当前平台ID + platform_id = event.get_platform_id() handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnLLMRequestEvent + EventType.OnLLMRequestEvent, platform_id=platform_id ) for handler in handlers: try: @@ -179,70 +185,127 @@ class LLMRequestSubStage(Stage): if not req.session_id: req.session_id = event.unified_msg_origin - try: - need_loop = True - while need_loop: - need_loop = False - logger.debug(f"提供商请求 Payload: {req}") - llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM + async def requesting(req: ProviderRequest): + try: + need_loop = True + while need_loop: + need_loop = False + logger.debug(f"提供商请求 Payload: {req}") - # 执行 LLM 响应后的事件钩子。 - handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnLLMResponseEvent - ) - for handler in handlers: - try: - logger.debug( - f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" - ) - await handler.handler(event, llm_response) - except BaseException: - logger.error(traceback.format_exc()) + final_llm_response = None - if event.is_stopped(): - logger.info( - f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" - ) - return - - async for result in self._handle_llm_response(event, req, llm_response): - if isinstance(result, ProviderRequest): - # 有函数工具调用并且返回了结果,我们需要再次请求 LLM - req = result - need_loop = True + if self.streaming_response: + stream = provider.text_chat_stream(**req.__dict__) + async for llm_response in stream: + if llm_response.is_chunk: + if llm_response.result_chain: + yield llm_response.result_chain # MessageChain + else: + yield MessageChain().message( + llm_response.completion_text + ) + else: + final_llm_response = llm_response else: - yield + final_llm_response = await provider.text_chat( + **req.__dict__ + ) # 请求 LLM - asyncio.create_task( - Metric.upload( - llm_tick=1, - model_name=provider.get_model(), - provider_type=provider.meta().type, + if not final_llm_response: + raise Exception("LLM response is None.") + + # 执行 LLM 响应后的事件钩子。 + handlers = star_handlers_registry.get_handlers_by_event_type( + EventType.OnLLMResponseEvent + ) + for handler in handlers: + try: + logger.debug( + f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" + ) + await handler.handler(event, final_llm_response) + except BaseException: + logger.error(traceback.format_exc()) + + if event.is_stopped(): + logger.info( + f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。" + ) + return + + if self.streaming_response: + # 流式输出的处理 + async for result in self._handle_llm_stream_response( + event, req, final_llm_response + ): + if isinstance(result, ProviderRequest): + # 有函数工具调用并且返回了结果,我们需要再次请求 LLM + req = result + need_loop = True + else: + yield + else: + # 非流式输出的处理 + async for result in self._handle_llm_response( + event, req, final_llm_response + ): + if isinstance(result, ProviderRequest): + # 有函数工具调用并且返回了结果,我们需要再次请求 LLM + req = result + need_loop = True + else: + yield + + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=provider.get_model(), + provider_type=provider.meta().type, + ) ) - ) - # 保存到历史记录 - await self._save_to_history(event, req, llm_response) + # 保存到历史记录 + await self._save_to_history(event, req, final_llm_response) - except BaseException as e: - logger.error(traceback.format_exc()) + except BaseException as e: + logger.error(traceback.format_exc()) + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" + ) + ) + + if not self.streaming_response: + event.set_extra("tool_call_result", None) + async for _ in requesting(req): + yield + else: event.set_result( - MessageEventResult().message( - f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}" - ) + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream(requesting(req)) ) - return + # 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理 + yield + + if event.get_extra("tool_call_result"): + event.set_result(event.get_extra("tool_call_result")) + event.set_extra("tool_call_result", None) + yield async def _handle_llm_response( - self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse - ) -> AsyncGenerator[None, None]: - """处理 LLM 响应。 + self, + event: AstrMessageEvent, + req: ProviderRequest, + llm_response: LLMResponse, + ) -> AsyncGenerator[Union[None, ProviderRequest], None]: + """处理非流式 LLM 响应。 Returns: - bool: 是否需要继续调用 LLM + AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM Yields: - Iterator[bool]: 将 event 交付给下一个 stage + Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM """ if llm_response.role == "assistant": # text completion @@ -265,83 +328,147 @@ class LLMRequestSubStage(Stage): ) ) elif llm_response.role == "tool": - # function calling - tool_call_result: list[ToolCallMessageSegment] = [] - logger.info( - f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" + # 处理函数工具调用 + async for result in self._handle_function_tools(event, req, llm_response): + yield result + + async def _handle_llm_stream_response( + self, + event: AstrMessageEvent, + req: ProviderRequest, + llm_response: LLMResponse, + ) -> AsyncGenerator[Union[None, ProviderRequest], None]: + """处理流式 LLM 响应。 + + 专门用于处理流式输出完成后的响应,与非流式响应处理分离。 + + Returns: + AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM + + Yields: + Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM + """ + if llm_response.role == "assistant": + # text completion + if llm_response.result_chain: + event.set_result( + MessageEventResult( + chain=llm_response.result_chain.chain + ).set_result_content_type(ResultContentType.STREAMING_FINISH) + ) + else: + event.set_result( + MessageEventResult() + .message(llm_response.completion_text) + .set_result_content_type(ResultContentType.STREAMING_FINISH) + ) + elif llm_response.role == "err": + event.set_result( + MessageEventResult().message( + f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}" + ) ) - for func_tool_name, func_tool_args, func_tool_id in zip( - llm_response.tools_call_name, - llm_response.tools_call_args, - llm_response.tools_call_ids, - ): - try: - func_tool = req.func_tool.get_func(func_tool_name) - if func_tool.origin == "mcp": - logger.info( - f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" + elif llm_response.role == "tool": + # 处理函数工具调用 + async for result in self._handle_function_tools(event, req, llm_response): + yield result + + async def _handle_function_tools( + self, + event: AstrMessageEvent, + req: ProviderRequest, + llm_response: LLMResponse, + ) -> AsyncGenerator[Union[None, ProviderRequest], None]: + """处理函数工具调用。 + + Returns: + AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM + """ + # function calling + tool_call_result: list[ToolCallMessageSegment] = [] + logger.info( + f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}" + ) + for func_tool_name, func_tool_args, func_tool_id in zip( + llm_response.tools_call_name, + llm_response.tools_call_args, + llm_response.tools_call_ids, + ): + try: + func_tool = req.func_tool.get_func(func_tool_name) + if func_tool.origin == "mcp": + logger.info( + f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" + ) + client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name] + res = await client.session.call_tool(func_tool.name, func_tool_args) + if res: + # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=res.content[0].text, + ) ) - client = req.func_tool.mcp_client_dict[ - func_tool.mcp_server_name - ] - res = await client.session.call_tool( - func_tool.name, func_tool_args + else: + # 获取处理器,过滤掉平台不兼容的处理器 + platform_id = event.get_platform_id() + if not func_tool.handler.is_enabled_for_platform(platform_id): + logger.debug( + f"处理器 {func_tool_name} 在当前平台不兼容,跳过执行" ) - if res: - # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 + # 直接跳过,不添加任何消息到tool_call_result + continue + + logger.info( + f"调用工具函数:{func_tool_name},参数:{func_tool_args}" + ) + # 尝试调用工具函数 + wrapper = self._call_handler( + self.ctx, event, func_tool.handler, **func_tool_args + ) + async for resp in wrapper: + if resp is not None: # 有 return 返回 tool_call_result.append( ToolCallMessageSegment( role="tool", tool_call_id=func_tool_id, - content=res.content[0].text, + content=resp, ) ) - else: - logger.info( - f"调用工具函数:{func_tool_name},参数:{func_tool_args}" - ) - # 尝试调用工具函数 - wrapper = self._call_handler( - self.ctx, event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: # 有 return 返回 - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=resp, - ) - ) - else: - yield # 有生成器返回 - event.clear_result() # 清除上一个 handler 的结果 - except BaseException as e: - logger.warning(traceback.format_exc()) - tool_call_result.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=f"error: {str(e)}", - ) + else: + res = event.get_result() + if res and res.chain: + event.set_extra("tool_call_result", res) + yield # 有生成器返回 + event.clear_result() # 清除上一个 handler 的结果 + except BaseException as e: + logger.warning(traceback.format_exc()) + tool_call_result.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=f"error: {str(e)}", ) - if tool_call_result: - # 函数调用结果 - req.func_tool = None # 暂时不支持递归工具调用 - assistant_msg_seg = AssistantMessageSegment( - role="assistant", tool_calls=llm_response.to_openai_tool_calls() ) - # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。 - req.tool_calls_result = ToolCallsResult( - tool_calls_info=assistant_msg_seg, - tool_calls_result=tool_call_result, + if tool_call_result: + # 函数调用结果 + req.func_tool = None # 暂时不支持递归工具调用 + assistant_msg_seg = AssistantMessageSegment( + role="assistant", tool_calls=llm_response.to_openai_tool_calls() + ) + # 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。 + req.tool_calls_result = ToolCallsResult( + tool_calls_info=assistant_msg_seg, + tool_calls_result=tool_call_result, + ) + yield req # 再次执行 LLM 请求 + else: + if llm_response.completion_text: + event.set_result( + MessageEventResult().message(llm_response.completion_text) ) - yield req # 再次执行 LLM 请求 - else: - if llm_response.completion_text: - event.set_result( - MessageEventResult().message(llm_response.completion_text) - ) async def _save_to_history( self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index d369e53ed..c7817e49c 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -31,7 +31,18 @@ class StarRequestSubStage(Stage): ) if not handlers_parsed_params: handlers_parsed_params = {} + for handler in activated_handlers: + # 检查处理器是否在当前平台兼容 + if ( + hasattr(handler, "platform_compatible") + and handler.platform_compatible is False + ): + logger.debug( + f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行" + ) + continue + params = handlers_parsed_params.get(handler.handler_full_name, {}) try: if handler.handler_module_path not in star_map: diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index 4c52a4a3e..f653a9fb9 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -5,7 +5,7 @@ from .method.llm_request import LLMRequestSubStage from .method.star_request import StarRequestSubStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import StarHandlerMetadata -from astrbot.core.provider.entites import ProviderRequest +from astrbot.core.provider.entities import ProviderRequest from astrbot.core import logger diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 86c165945..60a052454 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -7,7 +7,7 @@ from typing import Union, AsyncGenerator from ..stage import register_stage, Stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.message.message_event_result import MessageChain, ResultContentType from astrbot.core import logger from astrbot.core.message.message_event_result import BaseMessageComponent from astrbot.core.star.star_handler import star_handlers_registry, EventType @@ -18,7 +18,9 @@ from astrbot.core.star.star import star_map class RespondStage(Stage): # 组件类型到其非空判断函数的映射 _component_validators = { - Comp.Plain: lambda comp: bool(comp.text and comp.text.strip()), # 纯文本消息需要strip + Comp.Plain: lambda comp: bool( + comp.text and comp.text.strip() + ), # 纯文本消息需要strip Comp.Face: lambda comp: comp.id is not None, # QQ表情 Comp.Record: lambda comp: bool(comp.file), # 语音 Comp.Video: lambda comp: bool(comp.file), # 视频 @@ -31,13 +33,17 @@ class RespondStage(Stage): Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享 Comp.Contact: lambda comp: True, # 联系人(未完成) Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置 - Comp.Music: lambda comp: bool(comp._type) and bool(comp.url) and bool(comp.audio), # 音乐 + Comp.Music: lambda comp: bool(comp._type) + and bool(comp.url) + and bool(comp.audio), # 音乐 Comp.Image: lambda comp: bool(comp.file), # 图片 Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复 Comp.RedBag: lambda comp: bool(comp.title), # 红包 Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳 Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发 - Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点 + Comp.Node: lambda comp: bool(comp.name) + and comp.uin != 0 + and bool(comp.content), # 一个转发节点 Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML Comp.Json: lambda comp: bool(comp.data), # JSON @@ -132,8 +138,17 @@ class RespondStage(Stage): result = event.get_result() if result is None: return + if result.result_content_type == ResultContentType.STREAMING_FINISH: + return - if len(result.chain) > 0: + if result.result_content_type == ResultContentType.STREAMING_RESULT: + # 流式结果直接交付平台适配器处理 + logger.info(f"应用流式输出({event.get_platform_name()})") + await event._pre_send() + await event.send_streaming(result.async_stream) + await event._post_send() + return + elif len(result.chain) > 0: await event._pre_send() # 检查消息链是否为空 @@ -183,7 +198,7 @@ class RespondStage(Stage): ) handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnAfterMessageSentEvent + EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id() ) for handler in handlers: try: diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index d7bb9583c..957e2a491 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -5,6 +5,7 @@ from typing import Union, AsyncGenerator from ..stage import Stage, register_stage, registered_stages from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.message.message_event_result import ResultContentType from astrbot.core.platform.message_type import MessageType from astrbot.core import logger from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node @@ -72,11 +73,17 @@ class ResultDecorateStage(Stage): if result is None or not result.chain: return + if result.result_content_type == ResultContentType.STREAMING_RESULT: + return + + is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH + # 回复时检查内容安全 if ( self.content_safe_check_reply and self.content_safe_check_stage and result.is_llm_result() + and not is_stream # 流式输出不检查内容安全 ): text = "" for comp in result.chain: @@ -89,13 +96,17 @@ class ResultDecorateStage(Stage): # 发送消息前事件钩子 handlers = star_handlers_registry.get_handlers_by_event_type( - EventType.OnDecoratingResultEvent + EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id() ) for handler in handlers: try: logger.debug( f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}" ) + if is_stream: + logger.warning( + "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作" + ) await handler.handler(event) if event.get_result() is None or not event.get_result().chain: logger.debug( @@ -110,6 +121,11 @@ class ResultDecorateStage(Stage): ) return + # 流式输出不执行下面的逻辑 + if is_stream: + logger.info("流式输出已启用,跳过结果装饰阶段") + return + # 需要再获取一次。插件可能直接对 chain 进行了替换。 result = event.get_result() if result is None: diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index dfe19dc85..cfc905693 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -1,5 +1,6 @@ from ..stage import Stage, register_stage from ..context import PipelineContext +from astrbot import logger from typing import Union, AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import MessageEventResult, MessageChain @@ -93,6 +94,7 @@ class WakingCheckStage(Stage): # filter 需满足 AND 逻辑关系 passed = True permission_not_pass = False + permission_filter_raise_error = False if len(handler.event_filters) == 0: continue @@ -101,6 +103,7 @@ class WakingCheckStage(Stage): if isinstance(filter, PermissionTypeFilter): if not filter.filter(event, self.ctx.astrbot_config): permission_not_pass = True + permission_filter_raise_error = filter.raise_error else: if not filter.filter(event, self.ctx.astrbot_config): passed = False @@ -117,6 +120,9 @@ class WakingCheckStage(Stage): break if passed: if permission_not_pass: + if not permission_filter_raise_error: + # 跳过 + continue if self.no_permission_reply: await event.send( MessageChain().message( @@ -124,6 +130,9 @@ class WakingCheckStage(Stage): ) ) await event._post_send() + logger.info( + f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。" + ) event.stop_event() return diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 3e1b14ee6..96a7ad6f1 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,7 +1,7 @@ import abc import asyncio from dataclasses import dataclass -from typing import List, Union, Optional +from typing import List, Union, Optional, AsyncGenerator from astrbot.core.db.po import Conversation from astrbot.core.message.components import ( @@ -16,7 +16,7 @@ from astrbot.core.message.components import ( ) from astrbot.core.message.message_event_result import MessageEventResult, MessageChain from astrbot.core.platform.message_type import MessageType -from astrbot.core.provider.entites import ProviderRequest +from astrbot.core.provider.entities import ProviderRequest from astrbot.core.utils.metrics import Metric from .astrbot_message import AstrBotMessage, Group from .platform_metadata import PlatformMetadata @@ -81,6 +81,9 @@ class AstrMessageEvent(abc.ABC): def get_platform_name(self): return self.platform_meta.name + def get_platform_id(self): + return self.platform_meta.id + def get_message_str(self) -> str: """ 获取消息字符串。 @@ -202,6 +205,15 @@ class AstrMessageEvent(abc.ABC): """ return self.role == "admin" + async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]): + """发送流式消息到消息平台,使用异步生成器。 + 目前仅支持: telegram,qq official 私聊。 + """ + asyncio.create_task( + Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name) + ) + self._has_send_oper = True + async def _pre_send(self): """调度器会在执行 send() 前调用该方法""" diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index 48fe23af7..dd0e93fec 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -7,6 +7,8 @@ class PlatformMetadata: """平台的名称""" description: str """平台的描述""" + id: str = None + """平台的唯一标识符,用于配置中识别特定平台""" default_config_tmpl: dict = None """平台的默认配置模板""" diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 295014ab4..9bb8b938f 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -82,6 +82,19 @@ class AiocqhttpMessageEvent(AstrMessageEvent): await super().send(message) + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) + async def get_group(self, group_id=None, **kwargs): if isinstance(group_id, str) and group_id.isdigit(): group_id = int(group_id) diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index e41071a56..88f2ae3fc 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -39,8 +39,9 @@ class AiocqhttpAdapter(Platform): self.port = platform_config["ws_reverse_port"] self.metadata = PlatformMetadata( - "aiocqhttp", - "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", + name="aiocqhttp", + description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", + id=self.config.get("id"), ) self.bot = CQHttp( diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 95347172b..7a83a8abe 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -73,8 +73,9 @@ class DingtalkPlatformAdapter(Platform): def meta(self) -> PlatformMetadata: return PlatformMetadata( - "dingtalk", - "钉钉机器人官方 API 适配器", + name="dingtalk", + description="钉钉机器人官方 API 适配器", + id=self.config.get("id"), ) async def convert_msg( diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index 7980ecd55..d850a759f 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -24,7 +24,11 @@ class DingtalkMessageEvent(AstrMessageEvent): if isinstance(segment, Comp.Plain): segment.text = segment.text.strip() await asyncio.get_event_loop().run_in_executor( - None, client.reply_markdown, "AstrBot", segment.text, self.message_obj.raw_message + None, + client.reply_markdown, + "AstrBot", + segment.text, + self.message_obj.raw_message, ) elif isinstance(segment, Comp.Image): markdown_str = "" @@ -56,3 +60,16 @@ class DingtalkMessageEvent(AstrMessageEvent): async def send(self, message: MessageChain): await self.send_with_client(self.client, message) await super().send(message) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index 78902a4c5..829a348c6 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -216,3 +216,16 @@ class GewechatPlatformEvent(AstrMessageEvent): group_owner=data.get("chatRoomOwner"), members=members, ) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py index acf39197f..930359837 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_platform_adapter.py @@ -60,8 +60,9 @@ class GewechatPlatformAdapter(Platform): @override def meta(self) -> PlatformMetadata: return PlatformMetadata( - "gewechat", - "基于 gewechat 的 Wechat 适配器", + name="gewechat", + description="基于 gewechat 的 Wechat 适配器", + id=self.config.get("id"), ) async def terminate(self): diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index cbc3a45bb..8ea2ce36b 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -70,8 +70,9 @@ class LarkPlatformAdapter(Platform): def meta(self) -> PlatformMetadata: return PlatformMetadata( - "lark", - "飞书机器人官方 API 适配器", + name="lark", + description="飞书机器人官方 API 适配器", + id=self.config.get("id"), ) async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index e170b76a0..544a7a5be 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -91,3 +91,16 @@ class LarkMessageEvent(AstrMessageEvent): logger.error(f"回复飞书消息失败({response.code}): {response.msg}") await super().send(message) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index d31006618..f74edd1ce 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -2,6 +2,7 @@ import botpy import botpy.message import botpy.types import botpy.types.message +import asyncio from astrbot.core.utils.io import file_to_base64, download_image_by_url from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata @@ -9,6 +10,7 @@ from astrbot.api.message_components import Plain, Image from botpy import Client from botpy.http import Route from astrbot.api import logger +from botpy.types import message class QQOfficialMessageEvent(AstrMessageEvent): @@ -30,8 +32,46 @@ class QQOfficialMessageEvent(AstrMessageEvent): else: self.send_buffer.chain.extend(message.chain) - async def _post_send(self): + async def send_streaming(self, generator): + """流式输出仅支持消息列表私聊""" + stream_payload = {"state": 1, "id": None, "index": 0, "reset": False} + last_edit_time = 0 # 上次编辑消息的时间 + throttle_interval = 1 # 编辑消息的间隔时间 (秒) + try: + async for chain in generator: + source = self.message_obj.raw_message + if not self.send_buffer: + self.send_buffer = chain + else: + self.send_buffer.chain.extend(chain.chain) + + if isinstance(source, botpy.message.C2CMessage): + # 真流式传输 + current_time = asyncio.get_event_loop().time() + time_since_last_edit = current_time - last_edit_time + + if time_since_last_edit >= throttle_interval: + ret = await self._post_send(stream=stream_payload) + stream_payload["index"] += 1 + stream_payload["id"] = ret["id"] + last_edit_time = asyncio.get_event_loop().time() + + if isinstance(source, botpy.message.C2CMessage): + # 结束流式对话,并且传输 buffer 中剩余的消息 + stream_payload["state"] = 10 + ret = await self._post_send(stream=stream_payload) + + except Exception as e: + logger.error(f"发送流式消息时出错: {e}", exc_info=True) + self.send_buffer = None + + return await super().send_streaming(generator) + + async def _post_send(self, stream: dict = None): """QQ 官方 API 仅支持回复一次""" + if not self.send_buffer: + return + source = self.message_obj.raw_message assert isinstance( source, @@ -65,7 +105,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) payload["media"] = media payload["msg_type"] = 7 - await self.bot.api.post_group_message( + ret = await self.bot.api.post_group_message( group_openid=source.group_openid, **payload ) case botpy.message.C2CMessage: @@ -75,22 +115,34 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) payload["media"] = media payload["msg_type"] = 7 - await self.bot.api.post_c2c_message( - openid=source.author.user_openid, **payload - ) + if stream: + ret = await self.post_c2c_message( + openid=source.author.user_openid, + **payload, + stream=stream, + ) + else: + ret = await self.post_c2c_message( + openid=source.author.user_openid, **payload + ) + logger.debug(f"Message sent to C2C: {ret}") case botpy.message.Message: if image_path: payload["file_image"] = image_path - await self.bot.api.post_message(channel_id=source.channel_id, **payload) + ret = await self.bot.api.post_message( + channel_id=source.channel_id, **payload + ) case botpy.message.DirectMessage: if image_path: payload["file_image"] = image_path - await self.bot.api.post_dms(guild_id=source.guild_id, **payload) + ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload) await super().send(self.send_buffer) self.send_buffer = None + return ret + async def upload_group_and_c2c_image( self, image_base64: str, file_type: int, **kwargs ) -> botpy.types.message.Media: @@ -112,6 +164,27 @@ class QQOfficialMessageEvent(AstrMessageEvent): ) return await self.bot.api._http.request(route, json=payload) + async def post_c2c_message( + self, + openid: str, + msg_type: int = 0, + content: str = None, + embed: message.Embed = None, + ark: message.Ark = None, + message_reference: message.Reference = None, + media: message.Media = None, + msg_id: str = None, + msg_seq: str = 1, + event_id: str = None, + markdown: message.MarkdownPayload = None, + keyboard: message.Keyboard = None, + stream: dict = None, + ) -> message.Message: + payload = locals() + payload.pop("self", None) + route = Route("POST", "/v2/users/{openid}/messages", openid=openid) + return await self.bot.api._http.request(route, json=payload) + @staticmethod async def _parse_to_qqofficial(message: MessageChain): plain_text = "" diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 57bc8683f..d5285f759 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -126,8 +126,9 @@ class QQOfficialPlatformAdapter(Platform): def meta(self) -> PlatformMetadata: return PlatformMetadata( - "qq_official", - "QQ 机器人官方 API 适配器", + name="qq_official", + description="QQ 机器人官方 API 适配器", + id=self.config.get("id"), ) @staticmethod diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 226a1276d..cc12e9765 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -99,8 +99,9 @@ class QQOfficialWebhookPlatformAdapter(Platform): def meta(self) -> PlatformMetadata: return PlatformMetadata( - "qq_official_webhook", - "QQ 机器人官方 API 适配器", + name="qq_official_webhook", + description="QQ 机器人官方 API 适配器", + id=self.config.get("id"), ) async def run(self): @@ -116,5 +117,8 @@ class QQOfficialWebhookPlatformAdapter(Platform): async def terminate(self): self.webhook_helper.shutdown_event.set() await self.client.close() - await self.webhook_helper.server.shutdown() + try: + await self.webhook_helper.server.shutdown() + except Exception as _: + pass logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭") diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 12f17a819..9ff761c06 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -80,8 +80,7 @@ class TelegramPlatformAdapter(Platform): @override def meta(self) -> PlatformMetadata: return PlatformMetadata( - "telegram", - "telegram 适配器", + name="telegram", description="telegram 适配器", id=self.config.get("id") ) @override diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index eab41ad84..bcc9189c2 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -1,7 +1,15 @@ +import asyncio import telegramify_markdown from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType -from astrbot.api.message_components import Plain, Image, Reply, At, File, Record +from astrbot.api.message_components import ( + Plain, + Image, + Reply, + At, + File, + Record, +) from telegram.ext import ExtBot from astrbot.core.utils.io import download_file from astrbot import logger @@ -82,3 +90,109 @@ class TelegramPlatformEvent(AstrMessageEvent): else: await self.send_with_client(self.client, message, self.get_sender_id()) await super().send(message) + + async def send_streaming(self, generator): + message_thread_id = None + + if self.get_message_type() == MessageType.GROUP_MESSAGE: + user_name = self.message_obj.group_id + else: + user_name = self.get_sender_id() + + if "#" in user_name: + # it's a supergroup chat with message_thread_id + user_name, message_thread_id = user_name.split("#") + payload = { + "chat_id": user_name, + } + if message_thread_id: + payload["reply_to_message_id"] = message_thread_id + + delta = "" + current_content = "" + message_id = None + last_edit_time = 0 # 上次编辑消息的时间 + throttle_interval = 0.6 # 编辑消息的间隔时间 (秒) + + async for chain in generator: + if isinstance(chain, MessageChain): + # 处理消息链中的每个组件 + for i in chain.chain: + if isinstance(i, Plain): + delta += i.text + elif isinstance(i, Image): + image_path = await i.convert_to_file_path() + await self.client.send_photo(photo=image_path, **payload) + continue + elif isinstance(i, File): + if i.file.startswith("https://"): + path = "data/temp/" + i.name + await download_file(i.file, path) + i.file = path + + await self.client.send_document( + document=i.file, filename=i.name, **payload + ) + continue + elif isinstance(i, Record): + path = await i.convert_to_file_path() + await self.client.send_voice(voice=path, **payload) + continue + else: + logger.warning(f"不支持的消息类型: {type(i)}") + continue + + # Plain + if not message_id: + try: + msg = await self.client.send_message(text=delta, **payload) + current_content = delta + except Exception as e: + logger.warning(f"发送消息失败(streaming): {e!s}") + message_id = msg.message_id + last_edit_time = ( + asyncio.get_event_loop().time() + ) # 记录初始消息发送时间 + else: + current_time = asyncio.get_event_loop().time() + time_since_last_edit = current_time - last_edit_time + + # 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间 + if time_since_last_edit >= throttle_interval: + # 编辑消息 + try: + await self.client.edit_message_text( + text=delta, + chat_id=payload["chat_id"], + message_id=message_id, + ) + current_content = delta + except Exception as e: + logger.warning(f"编辑消息失败(streaming): {e!s}") + last_edit_time = ( + asyncio.get_event_loop().time() + ) # 更新上次编辑的时间 + + try: + if delta and current_content != delta: + try: + markdown_text = telegramify_markdown.markdownify( + delta, max_line_length=None, normalize_whitespace=False + ) + await self.client.edit_message_text( + text=markdown_text, + chat_id=payload["chat_id"], + message_id=message_id, + parse_mode="MarkdownV2" + ) + except Exception as e: + logger.warning(f"Markdown转换失败,使用普通文本: {e!s}") + await self.client.edit_message_text( + text=delta, + chat_id=payload["chat_id"], + message_id=message_id + ) + except Exception as e: + logger.warning(f"编辑消息失败(streaming): {e!s}") + + return await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 6fa3d5c59..01a042fb8 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -43,8 +43,7 @@ class WebChatAdapter(Platform): self.imgs_dir = "data/webchat/imgs" self.metadata = PlatformMetadata( - "webchat", - "webchat", + name="webchat", description="webchat", id=self.config.get("id") ) async def send_by_session( diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index ef82dbfed..ef5532920 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -16,16 +16,26 @@ class WebChatMessageEvent(AstrMessageEvent): os.makedirs(imgs_dir, exist_ok=True) @staticmethod - async def _send(message: MessageChain, session_id: str): + async def _send(message: MessageChain, session_id: str, streaming: bool = False): if not message: - web_chat_back_queue.put_nowait(None) + await web_chat_back_queue.put( + {"type": "end", "data": "", "streaming": False} + ) return cid = session_id.split("!")[-1] - + data = "" for comp in message.chain: if isinstance(comp, Plain): - web_chat_back_queue.put_nowait((comp.text, cid)) + data = comp.text + await web_chat_back_queue.put( + { + "type": "plain", + "cid": cid, + "data": data, + "streaming": streaming, + } + ) elif isinstance(comp, Image): # save image to local filename = str(uuid.uuid4()) + ".jpg" @@ -46,7 +56,15 @@ class WebChatMessageEvent(AstrMessageEvent): with open(path, "wb") as f: with open(comp.file, "rb") as f2: f.write(f2.read()) - web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid)) + data = f"[IMAGE]{filename}" + await web_chat_back_queue.put( + { + "type": "image", + "cid": cid, + "data": data, + "streaming": streaming, + } + ) elif isinstance(comp, Record): # save record to local filename = str(uuid.uuid4()) + ".wav" @@ -62,11 +80,45 @@ class WebChatMessageEvent(AstrMessageEvent): with open(path, "wb") as f: with open(comp.file, "rb") as f2: f.write(f2.read()) - web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid)) + data = f"[RECORD]{filename}" + await web_chat_back_queue.put( + { + "type": "record", + "cid": cid, + "data": data, + "streaming": streaming, + } + ) else: logger.debug(f"webchat 忽略: {comp.type}") - web_chat_back_queue.put_nowait(None) + + return data async def send(self, message: MessageChain): await WebChatMessageEvent._send(message, session_id=self.session_id) + await web_chat_back_queue.put( + { + "type": "end", + "data": "", + "streaming": False, + "cid": self.session_id.split("!")[-1], + } + ) await super().send(message) + + async def send_streaming(self, generator): + final_data = "" + async for chain in generator: + final_data += await WebChatMessageEvent._send( + chain, session_id=self.session_id, streaming=True + ) + + await web_chat_back_queue.put( + { + "type": "end", + "data": final_data, + "streaming": True, + "cid": self.session_id.split("!")[-1], + } + ) + await super().send_streaming(generator) diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 470b7b1f8..d8ee8b9a3 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -84,3 +84,16 @@ class WecomPlatformEvent(AstrMessageEvent): ) await super().send(message) + + async def send_streaming(self, generator): + buffer = None + async for chain in generator: + if not buffer: + buffer = chain + else: + buffer.chain.extend(chain.chain) + if not buffer: + return + buffer.squash_plain() + await self.send(buffer) + return await super().send_streaming(generator) diff --git a/astrbot/core/provider/__init__.py b/astrbot/core/provider/__init__.py index f30d1ac32..ed7135fe6 100644 --- a/astrbot/core/provider/__init__.py +++ b/astrbot/core/provider/__init__.py @@ -1,5 +1,5 @@ from .provider import Provider, Personality, STTProvider -from .entites import ProviderMetaData +from .entities import ProviderMetaData __all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"] diff --git a/astrbot/core/provider/entites.py b/astrbot/core/provider/entities.py similarity index 98% rename from astrbot/core/provider/entites.py rename to astrbot/core/provider/entities.py index a8ffcdf64..99824fd0e 100644 --- a/astrbot/core/provider/entites.py +++ b/astrbot/core/provider/entities.py @@ -204,6 +204,9 @@ class LLMResponse: _completion_text: str = "" + is_chunk: bool = False + """是否是流式输出的单个 Chunk""" + def __init__( self, role: str, @@ -214,6 +217,7 @@ class LLMResponse: tools_call_ids: List[str] = [], raw_completion: ChatCompletion = None, _new_record: Dict[str, any] = None, + is_chunk: bool = False, ): """初始化 LLMResponse @@ -233,6 +237,7 @@ class LLMResponse: self.tools_call_ids = tools_call_ids self.raw_completion = raw_completion self._new_record = _new_record + self.is_chunk = is_chunk @property def completion_text(self): diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index a3fa65e86..9812a7e6a 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -2,7 +2,7 @@ import traceback import asyncio from astrbot.core.config.astrbot_config import AstrBotConfig from .provider import Provider, STTProvider, TTSProvider, Personality -from .entites import ProviderType +from .entities import ProviderType from typing import List from astrbot.core.db import BaseDatabase from .register import provider_cls_map, llm_tools diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 8dcff9a52..96547c5c2 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,9 +1,9 @@ import abc from typing import List from astrbot.core.db import BaseDatabase -from typing import TypedDict +from typing import TypedDict, AsyncGenerator from astrbot.core.provider.func_tool_manager import FuncCall -from astrbot.core.provider.entites import LLMResponse, ToolCallsResult +from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from dataclasses import dataclass @@ -108,7 +108,35 @@ class Provider(AbstractProvider): - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ - raise NotImplementedError() + ... + + async def text_chat_stream( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = None, + func_tool: FuncCall = None, + contexts: List = None, + system_prompt: str = None, + tool_calls_result: ToolCallsResult = None, + **kwargs, + ) -> AsyncGenerator[LLMResponse, None]: + """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 + + Args: + prompt: 提示词 + session_id: 会话 ID(此属性已经被废弃) + image_urls: 图片 URL 列表 + tools: Function-calling 工具 + contexts: 上下文 + tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling + kwargs: 其他参数 + + Notes: + - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 + - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 + """ + ... async def pop_record(self, context: List): """ diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 41a7a29d5..02d7934d1 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,5 +1,5 @@ from typing import List, Dict -from .entites import ProviderMetaData, ProviderType +from .entities import ProviderMetaData, ProviderType from astrbot.core import logger from .func_tool_manager import FuncCall diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index fd19c40ca..319515c52 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -10,7 +10,8 @@ from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse, ToolCallsResult +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from .openai_source import ProviderOpenAIOfficial @@ -72,7 +73,8 @@ class ProviderAnthropic(ProviderOpenAIOfficial): if content.type == "text": # text completion completion_text = str(content.text).strip() - llm_response.completion_text = completion_text + # llm_response.completion_text = completion_text + llm_response.result_chain = MessageChain().message(completion_text) # Anthropic每次只返回一个函数调用 if completion.stop_reason == "tool_use": @@ -145,7 +147,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial): messages=context_query, **model_config ) llm_response = LLMResponse("assistant") - llm_response.completion_text = response.content[0].text + llm_response.result_chain = MessageChain().message(response.content[0].text) llm_response.raw_completion = response return llm_response except Exception as e: @@ -160,6 +162,33 @@ class ProviderAnthropic(ProviderOpenAIOfficial): return llm_response + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response + async def assemble_context(self, text: str, image_urls: List[str] = None): """组装上下文,支持文本和图片""" if not image_urls: diff --git a/astrbot/core/provider/sources/dashscope_source.py b/astrbot/core/provider/sources/dashscope_source.py index 14aefceef..2c4930692 100644 --- a/astrbot/core/provider/sources/dashscope_source.py +++ b/astrbot/core/provider/sources/dashscope_source.py @@ -3,10 +3,11 @@ import asyncio import functools from typing import List from .. import Provider, Personality -from ..entites import LLMResponse +from ..entities import LLMResponse from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter +from astrbot.core.message.message_event_result import MessageChain from .openai_source import ProviderOpenAIOfficial from astrbot.core import logger, sp from dashscope import Application @@ -132,7 +133,9 @@ class ProviderDashscope(ProviderOpenAIOfficial): ) return LLMResponse( role="err", - completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}", + result_chain=MessageChain().message( + f"阿里云百炼请求失败: message={response.message} code={response.status_code}" + ), ) output_text = response.output.get("text", "") @@ -141,11 +144,45 @@ class ProviderDashscope(ProviderOpenAIOfficial): if self.output_reference and response.output.get("doc_references", None): ref_str = "" for ref in response.output.get("doc_references", []): - ref_title = ref.get("title", "") if ref.get("title") else ref.get("doc_name", "") + ref_title = ( + ref.get("title", "") + if ref.get("title") + else ref.get("doc_name", "") + ) ref_str += f"{ref['index_id']}. {ref_title}\n" output_text += f"\n\n回答来源:\n{ref_str}" - return LLMResponse(role="assistant", completion_text=output_text) + llm_response = LLMResponse("assistant") + llm_response.result_chain = MessageChain().message(output_text) + + return llm_response + + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response async def forget(self, session_id): return True diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index 06b390fcd..7a038e4ba 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -3,7 +3,7 @@ import uuid import asyncio from dashscope.audio.tts_v2 import * from ..provider import TTSProvider -from ..entites import ProviderType +from ..entities import ProviderType from ..register import register_provider_adapter @@ -20,7 +20,7 @@ class ProviderDashscopeTTSAPI(TTSProvider): self.chosen_api_key: str = provider_config.get("api_key", "") self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella") self.set_model(provider_config.get("model", None)) - self.timeout_ms = float(provider_config.get("timeout", 20))*1000 + self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000 dashscope.api_key = self.chosen_api_key self.synthesizer = SpeechSynthesizer( diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 8b5890c28..1adb0f884 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -2,7 +2,7 @@ import astrbot.core.message.components as Comp from typing import List from .. import Provider, Personality -from ..entites import LLMResponse +from ..entities import LLMResponse from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter @@ -189,6 +189,33 @@ class ProviderDify(Provider): return LLMResponse(role="assistant", result_chain=chain) + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response + async def parse_dify_result(self, chunk: dict | str) -> MessageChain: if isinstance(chunk, str): # Chat diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 0eadb2190..338abe263 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -4,7 +4,7 @@ import edge_tts import subprocess import asyncio from ..provider import TTSProvider -from ..entites import ProviderType +from ..entities import ProviderType from ..register import register_provider_adapter from astrbot.core import logger diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index 84b4b677e..07d0c32ab 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, conint from httpx import AsyncClient from typing import Annotated, Literal from ..provider import TTSProvider -from ..entites import ProviderType +from ..entities import ProviderType from ..register import register_provider_adapter diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 11f3f7eaa..a3ca8c8f2 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -12,7 +12,7 @@ from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from typing import List from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entities import LLMResponse class SimpleGoogleGenAIClient: @@ -78,6 +78,39 @@ class SimpleGoogleGenAIClient: logger.error(f"Gemini 返回了非 json 数据: {text}") raise Exception("Gemini 返回了非 json 数据: ") + async def stream_generate_content( + self, + contents: List[dict], + model: str = "gemini-1.5-flash", + system_instruction: str = "", + tools: dict = None, + modalities: List[str] = ["Text"], + safety_settings: List[dict] = [], + ): + payload = {} + if system_instruction: + payload["system_instruction"] = {"parts": {"text": system_instruction}} + if tools: + payload["tools"] = [tools] + payload["contents"] = contents + payload["generationConfig"] = { + "responseModalities": modalities, + "stream": True, + } + payload["safetySettings"] = [ + {"category": s["category"], "threshold": s["threshold"]} + for s in safety_settings + ] + logger.debug(f"payload: {payload}") + request_url = ( + f"{self.api_base}/v1beta/models/{model}:streamGenerateContent?key={self.api_key}" + ) + async with self.client.post( + request_url, json=payload, timeout=self.timeout + ) as resp: + async for line in resp.content: + if line: + yield line @register_provider_adapter( "googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器" @@ -338,6 +371,33 @@ class ProviderGoogleGenAI(Provider): return llm_response + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response + def get_current_key(self) -> str: return self.client.api_key diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py index b57932edf..581eef4dc 100644 --- a/astrbot/core/provider/sources/gsvi_tts_source.py +++ b/astrbot/core/provider/sources/gsvi_tts_source.py @@ -2,7 +2,7 @@ import uuid import aiohttp import urllib.parse from ..provider import TTSProvider -from ..entites import ProviderType +from ..entities import ProviderType from ..register import register_provider_adapter diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index bfd9e03a5..85994fd59 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -2,7 +2,7 @@ import os from llmtuner.chat import ChatModel from typing import List from .. import Provider -from ..entites import LLMResponse +from ..entities import LLMResponse from ..func_tool_manager import FuncCall from astrbot.core.db import BaseDatabase from ..register import register_provider_adapter @@ -95,6 +95,33 @@ class LLMTunerModelLoader(Provider): return llm_response + async def text_chat_stream( + self, + prompt, + session_id=None, + image_urls=..., + func_tool=None, + contexts=..., + system_prompt=None, + tool_calls_result=None, + **kwargs, + ): + # raise NotImplementedError("This method is not implemented yet.") + # 调用 text_chat 模拟流式 + llm_response = await self.text_chat( + prompt=prompt, + session_id=session_id, + image_urls=image_urls, + func_tool=func_tool, + contexts=contexts, + system_prompt=system_prompt, + tool_calls_result=tool_calls_result, + ) + llm_response.is_chunk = True + yield llm_response + llm_response.is_chunk = False + yield llm_response + async def get_current_key(self): return "none" diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index f8d392404..8023d18d1 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -4,19 +4,24 @@ import os import inspect import random import asyncio +import astrbot.core.message.components as Comp from openai import AsyncOpenAI, AsyncAzureOpenAI from openai.types.chat.chat_completion import ChatCompletion + +# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai._exceptions import NotFoundError, UnprocessableEntityError +from openai.lib.streaming.chat._completions import ChatCompletionStreamState from astrbot.core.utils.io import download_image_by_url +from astrbot.core.message.message_event_result import MessageChain from astrbot.core.db import BaseDatabase from astrbot.api.provider import Provider, Personality from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall -from typing import List +from typing import List, AsyncGenerator from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entities import LLMResponse @register_provider_adapter( @@ -107,16 +112,72 @@ class ProviderOpenAIOfficial(Provider): logger.debug(f"completion: {completion}") + llm_response = await self.parse_openai_completion(completion, tools) + + return llm_response + + async def _query_stream( + self, payloads: dict, tools: FuncCall + ) -> AsyncGenerator[LLMResponse, None]: + """流式查询API,逐步返回结果""" + if tools: + tool_list = tools.get_func_desc_openai_style() + if tool_list: + payloads["tools"] = tool_list + + # 不在默认参数中的参数放在 extra_body 中 + extra_body = {} + to_del = [] + for key in payloads.keys(): + if key not in self.default_params: + extra_body[key] = payloads[key] + to_del.append(key) + for key in to_del: + del payloads[key] + + stream = await self.client.chat.completions.create( + **payloads, stream=True, extra_body=extra_body + ) + + llm_response = LLMResponse("assistant", is_chunk=True) + + state = ChatCompletionStreamState() + + async for chunk in stream: + try: + state.handle_chunk(chunk) + except Exception as e: + logger.warning("Saving chunk state error: " + str(e)) + if len(chunk.choices) == 0: + continue + delta = chunk.choices[0].delta + # 处理文本内容 + if delta.content: + completion_text = delta.content + llm_response.result_chain = MessageChain( + chain=[Comp.Plain(completion_text)] + ) + yield llm_response + + final_completion = state.get_final_completion() + llm_response = await self.parse_openai_completion(final_completion, tools) + + yield llm_response + + async def parse_openai_completion( + self, completion: ChatCompletion, tools: FuncCall + ): + """解析 OpenAI 的 ChatCompletion 响应""" + llm_response = LLMResponse("assistant") + if len(completion.choices) == 0: raise Exception("API 返回的 completion 为空。") choice = completion.choices[0] - llm_response = LLMResponse("assistant") - if choice.message.content: # text completion completion_text = str(choice.message.content).strip() - llm_response.completion_text = completion_text + llm_response.result_chain = MessageChain().message(completion_text) if choice.message.tool_calls: # tools call (function calling) @@ -148,7 +209,7 @@ class ProviderOpenAIOfficial(Provider): return llm_response - async def text_chat( + async def _prepare_chat_payload( self, prompt: str, session_id: str = None, @@ -158,7 +219,8 @@ class ProviderOpenAIOfficial(Provider): system_prompt=None, tool_calls_result=None, **kwargs, - ) -> LLMResponse: + ) -> tuple: + """准备聊天所需的有效载荷和上下文""" new_record = await self.assemble_context(prompt, image_urls) context_query = [*contexts, new_record] if system_prompt: @@ -177,8 +239,117 @@ class ProviderOpenAIOfficial(Provider): payloads = {"messages": context_query, **model_config} - llm_response = None + return payloads, context_query, func_tool + async def _handle_api_error( + self, + e: Exception, + payloads: dict, + context_query: list, + func_tool: FuncCall, + chosen_key: str, + available_api_keys: List[str], + retry_cnt: int, + max_retries: int, + ) -> tuple: + """处理API错误并尝试恢复""" + if "429" in str(e): + logger.warning( + f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}" + ) + # 最后一次不等待 + if retry_cnt < max_retries - 1: + await asyncio.sleep(1) + available_api_keys.remove(chosen_key) + if len(available_api_keys) > 0: + chosen_key = random.choice(available_api_keys) + return ( + False, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) + else: + raise e + elif "maximum context length" in str(e): + logger.warning( + f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" + ) + await self.pop_record(context_query) + payloads["messages"] = context_query + return ( + False, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) + elif "The model is not a VLM" in str(e): # siliconcloud + # 尝试删除所有 image + new_contexts = await self._remove_image_from_context(context_query) + payloads["messages"] = new_contexts + context_query = new_contexts + return ( + False, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) + elif ( + "Function calling is not enabled" in str(e) + or ("tool" in str(e).lower() and "support" in str(e).lower()) + or ("function" in str(e).lower() and "support" in str(e).lower()) + ): + # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配 + logger.info( + f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。" + ) + if "tools" in payloads: + del payloads["tools"] + return False, chosen_key, available_api_keys, payloads, context_query, None + else: + logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}") + + if "tool" in str(e).lower() and "support" in str(e).lower(): + logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all") + + if "Connection error." in str(e): + proxy = os.environ.get("http_proxy", None) + if proxy: + logger.error( + f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}" + ) + + raise e + + async def text_chat( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts=[], + system_prompt=None, + tool_calls_result=None, + **kwargs, + ) -> LLMResponse: + payloads, context_query, func_tool = await self._prepare_chat_payload( + prompt, + session_id, + image_urls, + func_tool, + contexts, + system_prompt, + tool_calls_result, + **kwargs, + ) + + llm_response = None max_retries = 10 available_api_keys = self.api_keys.copy() chosen_key = random.choice(available_api_keys) @@ -197,64 +368,97 @@ class ProviderOpenAIOfficial(Provider): payloads["messages"] = new_contexts context_query = new_contexts except Exception as e: - if "429" in str(e): - logger.warning( - f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}" - ) - # 最后一次不等待 - if retry_cnt < max_retries - 1: - await asyncio.sleep(1) - available_api_keys.remove(chosen_key) - if len(available_api_keys) > 0: - chosen_key = random.choice(available_api_keys) - continue - else: - raise e - elif "maximum context length" in str(e): - logger.warning( - f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}" - ) - await self.pop_record(context_query) - elif "The model is not a VLM" in str(e): # siliconcloud - # 尝试删除所有 image - new_contexts = await self._remove_image_from_context(context_query) - payloads["messages"] = new_contexts - elif ( - "Function calling is not enabled" in str(e) - or ("tool" in str(e).lower() and "support" in str(e).lower()) - or ("function" in str(e).lower() and "support" in str(e).lower()) - ): - # openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配 - logger.info( - f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。" - ) - if "tools" in payloads: - del payloads["tools"] - func_tool = None - else: - logger.error( - f"发生了错误。Provider 配置如下: {self.provider_config}" - ) - - if "tool" in str(e).lower() and "support" in str(e).lower(): - logger.error( - "疑似该模型不支持函数调用工具调用。请输入 /tool off_all" - ) - - if "Connection error." in str(e): - proxy = os.environ.get("http_proxy", None) - if proxy: - logger.error( - f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}" - ) - - raise e + ( + success, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) = await self._handle_api_error( + e, + payloads, + context_query, + func_tool, + chosen_key, + available_api_keys, + retry_cnt, + max_retries, + ) + if success: + break if retry_cnt == max_retries - 1: logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。") raise e return llm_response + async def text_chat_stream( + self, + prompt: str, + session_id: str = None, + image_urls: List[str] = [], + func_tool: FuncCall = None, + contexts=[], + system_prompt=None, + tool_calls_result=None, + **kwargs, + ) -> AsyncGenerator[LLMResponse, None]: + """流式对话,与服务商交互并逐步返回结果""" + payloads, context_query, func_tool = await self._prepare_chat_payload( + prompt, + session_id, + image_urls, + func_tool, + contexts, + system_prompt, + tool_calls_result, + **kwargs, + ) + + max_retries = 10 + available_api_keys = self.api_keys.copy() + chosen_key = random.choice(available_api_keys) + + e = None + retry_cnt = 0 + for retry_cnt in range(max_retries): + try: + self.client.api_key = chosen_key + async for response in self._query_stream(payloads, func_tool): + yield response + break + except UnprocessableEntityError as e: + logger.warning(f"不可处理的实体错误:{e},尝试删除图片。") + # 尝试删除所有 image + new_contexts = await self._remove_image_from_context(context_query) + payloads["messages"] = new_contexts + context_query = new_contexts + except Exception as e: + ( + success, + chosen_key, + available_api_keys, + payloads, + context_query, + func_tool, + ) = await self._handle_api_error( + e, + payloads, + context_query, + func_tool, + chosen_key, + available_api_keys, + retry_cnt, + max_retries, + ) + if success: + break + + if retry_cnt == max_retries - 1: + logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。") + raise e + async def _remove_image_from_context(self, contexts: List): """ 从上下文中删除所有带有 image 的记录 diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index f120a6a59..20b00f949 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -1,7 +1,7 @@ import uuid from openai import AsyncOpenAI, NOT_GIVEN from ..provider import TTSProvider -from ..entites import ProviderType +from ..entities import ProviderType from ..register import register_provider_adapter diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index 84087ecf6..b6e3331f8 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -11,7 +11,7 @@ import re from funasr_onnx import SenseVoiceSmall from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess from ..provider import STTProvider -from ..entites import ProviderType +from ..entities import ProviderType from astrbot.core.utils.io import download_file from ..register import register_provider_adapter from astrbot.core import logger diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index e38a81de9..0009af906 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -2,7 +2,7 @@ import uuid import os from openai import AsyncOpenAI, NOT_GIVEN from ..provider import STTProvider -from ..entites import ProviderType +from ..entities import ProviderType from astrbot.core.utils.io import download_file from ..register import register_provider_adapter from astrbot.core import logger diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index cfd1267d0..96f0b6f6d 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -3,7 +3,7 @@ import os import asyncio import whisper from ..provider import STTProvider -from ..entites import ProviderType +from ..entities import ProviderType from astrbot.core.utils.io import download_file from ..register import register_provider_adapter from astrbot.core import logger diff --git a/astrbot/core/provider/sources/zhipu_source.py b/astrbot/core/provider/sources/zhipu_source.py index 3e819d633..2f7490317 100644 --- a/astrbot/core/provider/sources/zhipu_source.py +++ b/astrbot/core/provider/sources/zhipu_source.py @@ -3,7 +3,7 @@ from astrbot import logger from astrbot.core.provider.func_tool_manager import FuncCall from typing import List from ..register import register_provider_adapter -from astrbot.core.provider.entites import LLMResponse +from astrbot.core.provider.entities import LLMResponse from .openai_source import ProviderOpenAIOfficial diff --git a/astrbot/core/star/filter/command.py b/astrbot/core/star/filter/command.py old mode 100644 new mode 100755 diff --git a/astrbot/core/star/filter/command_group.py b/astrbot/core/star/filter/command_group.py old mode 100644 new mode 100755 diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py index 521513449..10cf90c8b 100644 --- a/astrbot/core/star/star.py +++ b/astrbot/core/star/star.py @@ -47,5 +47,29 @@ class StarMetadata: star_handler_full_names: List[str] = field(default_factory=list) """注册的 Handler 的全名列表""" + supported_platforms: Dict[str, bool] = field(default_factory=dict) + """插件支持的平台ID字典,key为平台ID,value为是否支持""" + def __str__(self) -> str: return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})" + + def update_platform_compatibility(self, plugin_enable_config: dict) -> None: + """更新插件支持的平台列表 + + Args: + plugin_enable_config: 平台插件启用配置,即platform_settings.plugin_enable配置项 + """ + if not plugin_enable_config: + return + + # 清空之前的配置 + self.supported_platforms.clear() + + # 遍历所有平台配置 + for platform_id, plugins in plugin_enable_config.items(): + # 检查该插件在当前平台的配置 + if self.name in plugins: + self.supported_platforms[platform_id] = plugins[self.name] + else: + # 如果没有明确配置,默认为启用 + self.supported_platforms[platform_id] = True diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 7be0e053c..0764f15f6 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -30,21 +30,36 @@ class StarHandlerRegistry(Generic[T]): print(handler.handler_full_name) def get_handlers_by_event_type( - self, event_type: EventType, only_activated=True + self, event_type: EventType, only_activated=True, platform_id=None ) -> List[StarHandlerMetadata]: - """通过事件类型获取 Handler""" - handlers = [ - handler - for _, handler in self._handlers - if handler.event_type == event_type - and ( - not only_activated - or ( - star_map[handler.handler_module_path] - and star_map[handler.handler_module_path].activated - ) - ) - ] + """通过事件类型获取 Handler + + Args: + event_type: 事件类型 + only_activated: 是否只返回已激活的插件的处理器 + platform_id: 平台ID,如果提供此参数,将过滤掉在此平台不兼容的处理器 + + Returns: + List[StarHandlerMetadata]: 处理器列表 + """ + handlers = [] + for _, handler in self._handlers: + if handler.event_type != event_type: + continue + + # 只激活的插件处理器 + if only_activated: + plugin = star_map.get(handler.handler_module_path) + if not (plugin and plugin.activated): + continue + + # 平台兼容性过滤 + if platform_id and event_type != EventType.OnAstrBotLoadedEvent: + if not handler.is_enabled_for_platform(platform_id): + continue + + handlers.append(handler) + return handlers def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata: @@ -139,3 +154,32 @@ class StarHandlerMetadata: return self.extras_configs.get("priority", 0) < other.extras_configs.get( "priority", 0 ) + + def is_enabled_for_platform(self, platform_id: str) -> bool: + """检查插件是否在指定平台启用 + + Args: + platform_id: 平台ID,这是从event.get_platform_id()获取的,用于唯一标识平台实例 + + Returns: + bool: 是否启用,True表示启用,False表示禁用 + """ + plugin = star_map.get(self.handler_module_path) + + # 如果插件元数据不存在,默认允许执行 + if not plugin or not plugin.name: + return True + + # 先检查插件是否被激活 + if not plugin.activated: + return False + + # 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性 + if ( + hasattr(plugin, "supported_platforms") + and platform_id in plugin.supported_platforms + ): + return plugin.supported_platforms[platform_id] + + # 如果没有缓存数据,默认允许执行 + return True diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 2d610f15c..a4ae48250 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -209,7 +209,31 @@ class PluginManager: await self._unbind_plugin(smd.name, specified_module_path) - return await self.load(specified_module_path) + result = await self.load(specified_module_path) + + # 更新所有插件的平台兼容性 + await self.update_all_platform_compatibility() + + return result + + async def update_all_platform_compatibility(self): + """更新所有插件的平台兼容性设置""" + # 获取最新的平台插件启用配置 + plugin_enable_config = self.config.get("platform_settings", {}).get( + "plugin_enable", {} + ) + logger.debug( + f"更新所有插件的平台兼容性设置,平台数量: {len(plugin_enable_config)}" + ) + + # 遍历所有插件,更新平台兼容性 + for plugin in self.context.get_all_stars(): + plugin.update_platform_compatibility(plugin_enable_config) + logger.debug( + f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}" + ) + + return True async def load(self, specified_module_path=None, specified_dir_name=None): """载入插件。 @@ -320,6 +344,12 @@ class PluginManager: metadata.root_dir_name = root_dir_name metadata.reserved = reserved + # 更新插件的平台兼容性 + plugin_enable_config = self.config.get("platform_settings", {}).get( + "plugin_enable", {} + ) + metadata.update_platform_compatibility(plugin_enable_config) + # 绑定 handler related_handlers = ( star_handlers_registry.get_handlers_by_module_name( diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 68468e353..405ccc631 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -1,9 +1,12 @@ +import inspect from typing import Union, Awaitable, List, Optional, ClassVar from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.message_event_result import MessageChain from astrbot.api.platform import MessageMember, AstrBotMessage from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.star.context import Context +from astrbot.core.star.star import star_map +from pathlib import Path class StarTools: @@ -142,3 +145,48 @@ class StarTools: name (str): 工具名称 """ cls._context.unregister_llm_tool(name) + + @classmethod + def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path: + """ + 返回插件数据目录的绝对路径。 + + 此方法会在 data/plugin_data 目录下为插件创建一个专属的数据目录。如果未提供插件名称, + 会自动从调用栈中获取插件信息。 + + Args: + plugin_name: 可选的插件名称。如果为None,将自动检测调用者的插件名称。 + + Returns: + Path (Path): 插件数据目录的绝对路径,位于 data/plugin_data/{plugin_name}。 + + Raises: + RuntimeError: 当出现以下情况时抛出: + - 无法获取调用者模块信息 + - 无法获取模块的元数据信息 + - 创建目录失败(权限不足或其他IO错误) + """ + if not plugin_name: + frame = inspect.currentframe().f_back + module = inspect.getmodule(frame) + + if not module: + raise RuntimeError("无法获取调用者模块信息") + + metadata = star_map.get(module.__name__, None) + + if not metadata: + raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") + + plugin_name = metadata.name + + data_dir = Path("data/plugin_data") / plugin_name + + try: + data_dir.mkdir(parents=True, exist_ok=True) + except OSError as e: + if isinstance(e, PermissionError): + raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e + raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e + + return data_dir.resolve() diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index bf88ba8db..b11987322 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -15,7 +15,7 @@ class SharedPreferences: def _save_preferences(self): with open(self.path, "w") as f: - json.dump(self._data, f, indent=4) + json.dump(self._data, f, indent=4, ensure_ascii=False) f.flush() def get(self, key, default=None): diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index db1461f59..d767ddea4 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -161,42 +161,53 @@ class ChatRoute(Route): username = g.get("username", "guest") if username in self.curr_chat_sse: - return "[ERROR]\n" + return Response().error("Already connected").__dict__ self.curr_chat_sse[username] = None + heartbeat = json.dumps({"type": "heartbeat", "data": "ping"}) + async def stream(): try: - yield "[HB]\n" + yield f"data: {heartbeat}\n\n" # 心跳包 while True: try: result = await asyncio.wait_for( web_chat_back_queue.get(), timeout=10 ) # 设置超时时间为5秒 except asyncio.TimeoutError: - yield "[HB]\n" # 心跳包 + yield f"data: {heartbeat}\n\n" # 心跳包 continue if not result: continue - result_text, cid = result + + result_text = result["data"] + type = result.get("type") + cid = result.get("cid") + streaming = result.get("streaming", False) if cid != self.curr_user_cid.get(username): # 丢弃 continue - yield result_text + "\n" + yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" + await asyncio.sleep(0.05) - conversation = self.db.get_conversation_by_user_id(username, cid) - try: - history = json.loads(conversation.history) - except BaseException as e: - print(e) - history = [] - history.append({"type": "bot", "message": result_text}) - self.db.update_conversation( - username, cid, history=json.dumps(history) - ) + if streaming and type != "end": + continue - await asyncio.sleep(0.5) + if result_text: + conversation = self.db.get_conversation_by_user_id( + username, cid + ) + try: + history = json.loads(conversation.history) + except BaseException as e: + print(e) + history = [] + history.append({"type": "bot", "message": result_text}) + self.db.update_conversation( + username, cid, history=json.dumps(history) + ) except BaseException as _: logger.debug(f"用户 {username} 断开聊天长连接。") self.curr_chat_sse.pop(username) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 629a424f1..2747865e4 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -179,7 +179,7 @@ class ConfigRoute(Route): await self._save_astrbot_configs(post_configs) return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__ except Exception as e: - logger.error(e) + logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ async def post_plugin_configs(self): diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py index 6f3940c0a..f99110530 100644 --- a/astrbot/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -20,7 +20,7 @@ class LogRoute(Route): message = await queue.get() payload = { "type": "log", - **message # see astrbot/core/log.py + **message, # see astrbot/core/log.py } yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n" except asyncio.CancelledError: diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index e369ad054..9fb9d231a 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -1,5 +1,6 @@ import traceback import aiohttp +import os import ssl import certifi @@ -36,6 +37,9 @@ class PluginRoute(Route): "/plugin/off": ("POST", self.off_plugin), "/plugin/on": ("POST", self.on_plugin), "/plugin/reload": ("POST", self.reload_plugins), + "/plugin/readme": ("GET", self.get_plugin_readme), + "/plugin/platform_enable/get": ("GET", self.get_plugin_platform_enable), + "/plugin/platform_enable/set": ("POST", self.set_plugin_platform_enable), } self.core_lifecycle = core_lifecycle self.plugin_manager = plugin_manager @@ -317,3 +321,135 @@ class PluginRoute(Route): except Exception as e: logger.error(f"/api/plugin/on: {traceback.format_exc()}") return Response().error(str(e)).__dict__ + + async def get_plugin_readme(self): + plugin_name = request.args.get("name") + logger.debug(f"正在获取插件 {plugin_name} 的README文件内容") + + if not plugin_name: + logger.warning("插件名称为空") + return Response().error("插件名称不能为空").__dict__ + + plugin_obj = None + for plugin in self.plugin_manager.context.get_all_stars(): + if plugin.name == plugin_name: + plugin_obj = plugin + break + + if not plugin_obj: + logger.warning(f"插件 {plugin_name} 不存在") + return Response().error(f"插件 {plugin_name} 不存在").__dict__ + + plugin_dir = os.path.join( + self.plugin_manager.plugin_store_path, plugin_obj.root_dir_name + ) + + if not os.path.isdir(plugin_dir): + logger.warning(f"无法找到插件目录: {plugin_dir}") + return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ + + readme_path = os.path.join(plugin_dir, "README.md") + + if not os.path.isfile(readme_path): + logger.warning(f"插件 {plugin_name} 没有README文件") + return Response().error(f"插件 {plugin_name} 没有README文件").__dict__ + + try: + with open(readme_path, "r", encoding="utf-8") as f: + readme_content = f.read() + + return ( + Response() + .ok({"content": readme_content}, "成功获取README内容") + .__dict__ + ) + except Exception as e: + logger.error(f"/api/plugin/readme: {traceback.format_exc()}") + return Response().error(f"读取README文件失败: {str(e)}").__dict__ + + async def get_plugin_platform_enable(self): + """获取插件在各平台的可用性配置""" + try: + platform_enable = self.core_lifecycle.astrbot_config.get( + "platform_settings", {} + ).get("plugin_enable", {}) + + # 获取所有可用平台 + platforms = [] + + for platform in self.core_lifecycle.astrbot_config.get("platform", []): + platform_type = platform.get("type", "") + platform_id = platform.get("id", "") + + platforms.append( + { + "name": platform_id, # 使用type作为name,这是系统内部使用的平台名称 + "id": platform_id, # 保留id字段以便前端可以显示 + "type": platform_type, + "display_name": f"{platform_type}({platform_id})", + } + ) + + adjusted_platform_enable = {} + for platform_id, plugins in platform_enable.items(): + adjusted_platform_enable[platform_id] = plugins + + # 获取所有插件,包括系统内部插件 + plugins = [] + for plugin in self.plugin_manager.context.get_all_stars(): + plugins.append( + { + "name": plugin.name, + "desc": plugin.desc, + "reserved": plugin.reserved, # 添加reserved标志 + } + ) + + logger.debug( + f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}" + ) + + return ( + Response() + .ok( + { + "platforms": platforms, + "plugins": plugins, + "platform_enable": adjusted_platform_enable, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"/api/plugin/platform_enable/get: {traceback.format_exc()}") + return Response().error(str(e)).__dict__ + + async def set_plugin_platform_enable(self): + """设置插件在各平台的可用性配置""" + if DEMO_MODE: + return ( + Response() + .error("You are not permitted to do this operation in demo mode") + .__dict__ + ) + + try: + data = await request.json + platform_enable = data.get("platform_enable", {}) + + # 更新配置 + config = self.core_lifecycle.astrbot_config + platform_settings = config.get("platform_settings", {}) + platform_settings["plugin_enable"] = platform_enable + config["platform_settings"] = platform_settings + config.save_config() + + # 更新插件的平台兼容性缓存 + await self.plugin_manager.update_all_platform_compatibility() + + logger.info(f"插件平台可用性配置已更新: {platform_enable}") + + return Response().ok(None, "插件平台可用性配置已更新").__dict__ + except Exception as e: + logger.error(f"/api/plugin/platform_enable/set: {traceback.format_exc()}") + return Response().error(str(e)).__dict__ diff --git a/dashboard/src/components/shared/ExtensionCard.vue b/dashboard/src/components/shared/ExtensionCard.vue index 18f845646..e6573260c 100644 --- a/dashboard/src/components/shared/ExtensionCard.vue +++ b/dashboard/src/components/shared/ExtensionCard.vue @@ -24,13 +24,10 @@ const emit = defineEmits([ 'install', 'uninstall', 'toggle-activation', - 'view-handlers' + 'view-handlers', + 'view-readme' ]); -const open = (link: string | undefined) => { - window.open(link, '_blank'); -}; - const reveal = ref(false); // 操作函数 @@ -70,6 +67,10 @@ const toggleActivation = () => { const viewHandlers = () => { emit('view-handlers', props.extension); }; + +const viewReadme = () => { + emit('view-readme', props.extension); +};