From 395786187881f5cc56cb4e58addb271627ed4419 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Thu, 13 Nov 2025 10:08:57 +0800 Subject: [PATCH] refactor: streamline llm processing logic (#3607) * refactor: streamline llm processing logic * perf: merge-nested-ifs Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> * fix: ruff format * refactor: remove unnecessary debug logs in FunctionToolExecutor and LLMRequestSubStage --------- Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com> --- .../process_stage/method/llm_request.py | 395 ++++++++++-------- 1 file changed, 217 insertions(+), 178 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index f7677d373..69bf31a55 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -35,6 +35,7 @@ from astrbot.core.provider.register import llm_tools from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.star_handler import EventType, star_map from astrbot.core.utils.metrics import Metric +from astrbot.core.utils.session_lock import session_lock_manager from ...context import PipelineContext, call_event_hook, call_local_llm_tool from ..stage import Stage @@ -186,7 +187,6 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): is_override_call = False for ty in type(tool).mro(): if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call: - logger.debug(f"Found call in: {ty}") is_override_call = True break @@ -413,67 +413,12 @@ class LLMRequestSubStage(Stage): raise RuntimeError("无法创建新的对话。") return conversation - async def process( + async def _apply_kb_context( self, event: AstrMessageEvent, - _nested: bool = False, - ) -> None | AsyncGenerator[None, None]: - req: ProviderRequest | None = None - - if not self.ctx.astrbot_config["provider_settings"]["enable"]: - logger.debug("未启用 LLM 能力,跳过处理。") - return - - # 检查会话级别的LLM启停状态 - if not SessionServiceManager.should_process_llm_request(event): - logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。") - return - - provider = self._select_provider(event) - if provider is None: - return - if not isinstance(provider, Provider): - logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。") - return - - streaming_response = self.streaming_response - if (enable_streaming := event.get_extra("enable_streaming")) is not None: - streaming_response = bool(enable_streaming) - - if event.get_extra("provider_request"): - req = event.get_extra("provider_request") - assert isinstance(req, ProviderRequest), ( - "provider_request 必须是 ProviderRequest 类型。" - ) - - if req.conversation: - req.contexts = json.loads(req.conversation.history) - - else: - req = ProviderRequest(prompt="", image_urls=[]) - if sel_model := event.get_extra("selected_model"): - req.model = sel_model - if self.provider_wake_prefix: - if not event.message_str.startswith(self.provider_wake_prefix): - return - req.prompt = event.message_str[len(self.provider_wake_prefix) :] - # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。 - # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() - for comp in event.message_obj.message: - if isinstance(comp, Image): - image_path = await comp.convert_to_file_path() - req.image_urls.append(image_path) - - conversation = await self._get_session_conv(event) - req.conversation = conversation - req.contexts = json.loads(conversation.history) - - event.set_extra("provider_request", req) - - if not req.prompt and not req.image_urls: - return - - # 应用知识库 + req: ProviderRequest, + ): + """应用知识库上下文到请求中""" try: await inject_kb_context( umo=event.unified_msg_origin, @@ -483,43 +428,40 @@ class LLMRequestSubStage(Stage): except Exception as e: logger.error(f"调用知识库时遇到问题: {e}") - # 执行请求 LLM 前事件钩子。 - if await call_event_hook(event, EventType.OnLLMRequestEvent, req): - return + def _truncate_contexts( + self, + contexts: list[dict], + ) -> list[dict]: + """截断上下文列表,确保不超过最大长度""" + if self.max_context_length == -1: + return contexts - if isinstance(req.contexts, str): - req.contexts = json.loads(req.contexts) + if len(contexts) // 2 <= self.max_context_length: + return contexts - # max context length - if ( - self.max_context_length != -1 # -1 为不限制 - and len(req.contexts) // 2 > self.max_context_length - ): - logger.debug("上下文长度超过限制,将截断。") - req.contexts = req.contexts[ - -(self.max_context_length - self.dequeue_context_length + 1) * 2 : - ] - # 找到第一个role 为 user 的索引,确保上下文格式正确 - index = next( - ( - i - for i, item in enumerate(req.contexts) - if item.get("role") == "user" - ), - None, - ) - if index is not None and index > 0: - req.contexts = req.contexts[index:] + truncated_contexts = contexts[ + -(self.max_context_length - self.dequeue_context_length + 1) * 2 : + ] + # 找到第一个role 为 user 的索引,确保上下文格式正确 + index = next( + ( + i + for i, item in enumerate(truncated_contexts) + if item.get("role") == "user" + ), + None, + ) + if index is not None and index > 0: + truncated_contexts = truncated_contexts[index:] - # session_id - if not req.session_id: - req.session_id = event.unified_msg_origin + return truncated_contexts - # fix messages - req.contexts = self.fix_messages(req.contexts) - - # check provider modalities - # 如果提供商不支持图像/工具使用,但请求中包含图像/工具列表,则清空。图片转述等的检测和调用发生在这之前,因此这里可以这样处理。 + def _modalities_fix( + self, + provider: Provider, + req: ProviderRequest, + ): + """检查提供商的模态能力,清理请求中的不支持内容""" if req.image_urls: provider_cfg = provider.provider_config.get("modalities", ["image"]) if "image" not in provider_cfg: @@ -533,7 +475,13 @@ class LLMRequestSubStage(Stage): f"用户设置提供商 {provider} 不支持工具使用,清空工具列表。", ) req.func_tool = None - # 插件可用性设置 + + def _plugin_tool_fix( + self, + event: AstrMessageEvent, + req: ProviderRequest, + ): + """根据事件中的插件设置,过滤请求中的工具列表""" if event.plugins_name is not None and req.func_tool: new_tool_set = ToolSet() for tool in req.func_tool.tools: @@ -547,86 +495,6 @@ class LLMRequestSubStage(Stage): new_tool_set.add_tool(tool) req.func_tool = new_tool_set - stream_to_general = ( - self.unsupported_streaming_strategy == "turn_off" - and not event.platform_meta.support_streaming_message - ) - # 备份 req.contexts - backup_contexts = copy.deepcopy(req.contexts) - - # run agent - agent_runner = AgentRunner() - logger.debug( - f"handle provider[id: {provider.provider_config['id']}] request: {req}", - ) - astr_agent_ctx = AstrAgentContext( - provider=provider, - first_provider_request=req, - curr_provider_request=req, - streaming=streaming_response, - event=event, - ) - await agent_runner.reset( - provider=provider, - request=req, - run_context=AgentContextWrapper( - context=astr_agent_ctx, - tool_call_timeout=self.tool_call_timeout, - ), - tool_executor=FunctionToolExecutor(), - agent_hooks=MAIN_AGENT_HOOKS, - streaming=streaming_response, - ) - - if streaming_response and not stream_to_general: - # 流式响应 - event.set_result( - MessageEventResult() - .set_result_content_type(ResultContentType.STREAMING_RESULT) - .set_async_stream( - run_agent(agent_runner, self.max_step, self.show_tool_use), - ), - ) - yield - if agent_runner.done(): - if final_llm_resp := agent_runner.get_final_llm_resp(): - if final_llm_resp.completion_text: - chain = ( - MessageChain().message(final_llm_resp.completion_text).chain - ) - elif final_llm_resp.result_chain: - chain = final_llm_resp.result_chain.chain - else: - chain = MessageChain().chain - event.set_result( - MessageEventResult( - chain=chain, - result_content_type=ResultContentType.STREAMING_FINISH, - ), - ) - else: - async for _ in run_agent( - agent_runner, self.max_step, self.show_tool_use, stream_to_general - ): - yield - - # 恢复备份的 contexts - req.contexts = backup_contexts - - await self._save_to_history(event, req, agent_runner.get_final_llm_resp()) - - # 异步处理 WebChat 特殊情况 - if event.get_platform_name() == "webchat": - asyncio.create_task(self._handle_webchat(event, req, provider)) - - asyncio.create_task( - Metric.upload( - llm_tick=1, - model_name=agent_runner.provider.get_model(), - provider_type=agent_runner.provider.meta().type, - ), - ) - async def _handle_webchat( self, event: AstrMessageEvent, @@ -674,9 +542,6 @@ class LLMRequestSubStage(Stage): ), ) if llm_resp and llm_resp.completion_text: - logger.debug( - f"WebChat 对话标题生成响应: {llm_resp.completion_text.strip()}", - ) title = llm_resp.completion_text.strip() if not title or "" in title: return @@ -723,7 +588,7 @@ class LLMRequestSubStage(Stage): history=messages, ) - def fix_messages(self, messages: list[dict]) -> list[dict]: + def _fix_messages(self, messages: list[dict]) -> list[dict]: """验证并且修复上下文""" fixed_messages = [] for message in messages: @@ -738,3 +603,177 @@ class LLMRequestSubStage(Stage): else: fixed_messages.append(message) return fixed_messages + + async def process( + self, + event: AstrMessageEvent, + _nested: bool = False, + ) -> None | AsyncGenerator[None, None]: + req: ProviderRequest | None = None + + if not self.ctx.astrbot_config["provider_settings"]["enable"]: + logger.debug("未启用 LLM 能力,跳过处理。") + return + + # 检查会话级别的LLM启停状态 + if not SessionServiceManager.should_process_llm_request(event): + logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。") + return + + provider = self._select_provider(event) + if provider is None: + return + if not isinstance(provider, Provider): + logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。") + return + + streaming_response = self.streaming_response + if (enable_streaming := event.get_extra("enable_streaming")) is not None: + streaming_response = bool(enable_streaming) + + logger.debug("ready to request llm provider") + async with session_lock_manager.acquire_lock(event.unified_msg_origin): + logger.debug("acquired session lock for llm request") + if event.get_extra("provider_request"): + req = event.get_extra("provider_request") + assert isinstance(req, ProviderRequest), ( + "provider_request 必须是 ProviderRequest 类型。" + ) + + if req.conversation: + req.contexts = json.loads(req.conversation.history) + + else: + req = ProviderRequest(prompt="", image_urls=[]) + if sel_model := event.get_extra("selected_model"): + req.model = sel_model + if self.provider_wake_prefix and not event.message_str.startswith( + self.provider_wake_prefix + ): + return + + req.prompt = event.message_str[len(self.provider_wake_prefix) :] + # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。 + # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_path = await comp.convert_to_file_path() + req.image_urls.append(image_path) + + conversation = await self._get_session_conv(event) + req.conversation = conversation + req.contexts = json.loads(conversation.history) + + event.set_extra("provider_request", req) + + if not req.prompt and not req.image_urls: + return + + # apply knowledge base context + await self._apply_kb_context(event, req) + + # call event hook + if await call_event_hook(event, EventType.OnLLMRequestEvent, req): + return + + # fix contexts json str + if isinstance(req.contexts, str): + req.contexts = json.loads(req.contexts) + + # truncate contexts to fit max length + req.contexts = self._truncate_contexts(req.contexts) + + # session_id + if not req.session_id: + req.session_id = event.unified_msg_origin + + # fix messages + req.contexts = self._fix_messages(req.contexts) + + # check provider modalities, if provider does not support image/tool_use, clear them in request. + self._modalities_fix(provider, req) + + # filter tools, only keep tools from this pipeline's selected plugins + self._plugin_tool_fix(event, req) + + stream_to_general = ( + self.unsupported_streaming_strategy == "turn_off" + and not event.platform_meta.support_streaming_message + ) + # 备份 req.contexts + backup_contexts = copy.deepcopy(req.contexts) + + # run agent + agent_runner = AgentRunner() + logger.debug( + f"handle provider[id: {provider.provider_config['id']}] request: {req}", + ) + astr_agent_ctx = AstrAgentContext( + provider=provider, + first_provider_request=req, + curr_provider_request=req, + streaming=streaming_response, + event=event, + ) + await agent_runner.reset( + provider=provider, + request=req, + run_context=AgentContextWrapper( + context=astr_agent_ctx, + tool_call_timeout=self.tool_call_timeout, + ), + tool_executor=FunctionToolExecutor(), + agent_hooks=MAIN_AGENT_HOOKS, + streaming=streaming_response, + ) + + if streaming_response and not stream_to_general: + # 流式响应 + event.set_result( + MessageEventResult() + .set_result_content_type(ResultContentType.STREAMING_RESULT) + .set_async_stream( + run_agent(agent_runner, self.max_step, self.show_tool_use), + ), + ) + yield + if agent_runner.done(): + if final_llm_resp := agent_runner.get_final_llm_resp(): + if final_llm_resp.completion_text: + chain = ( + MessageChain() + .message(final_llm_resp.completion_text) + .chain + ) + elif final_llm_resp.result_chain: + chain = final_llm_resp.result_chain.chain + else: + chain = MessageChain().chain + event.set_result( + MessageEventResult( + chain=chain, + result_content_type=ResultContentType.STREAMING_FINISH, + ), + ) + else: + async for _ in run_agent( + agent_runner, self.max_step, self.show_tool_use, stream_to_general + ): + yield + + # 恢复备份的 contexts + req.contexts = backup_contexts + + await self._save_to_history(event, req, agent_runner.get_final_llm_resp()) + + # 异步处理 WebChat 特殊情况 + if event.get_platform_name() == "webchat": + asyncio.create_task(self._handle_webchat(event, req, provider)) + + asyncio.create_task( + Metric.upload( + llm_tick=1, + model_name=agent_runner.provider.get_model(), + provider_type=agent_runner.provider.meta().type, + ), + )