diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 7ce9c6635..3501a5271 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -1,18 +1,18 @@ from astrbot.core.message.message_event_result import ( - MessageEventResult, EventResultType, + MessageEventResult, ) -from .waking_check.stage import WakingCheckStage -from .whitelist_check.stage import WhitelistCheckStage -from .session_status_check.stage import SessionStatusCheckStage -from .rate_limit_check.stage import RateLimitStage from .content_safety_check.stage import ContentSafetyCheckStage from .platform_compatibility.stage import PlatformCompatibilityStage from .preprocess_stage.stage import PreProcessStage from .process_stage.stage import ProcessStage -from .result_decorate.stage import ResultDecorateStage +from .rate_limit_check.stage import RateLimitStage from .respond.stage import RespondStage +from .result_decorate.stage import ResultDecorateStage +from .session_status_check.stage import SessionStatusCheckStage +from .waking_check.stage import WakingCheckStage +from .whitelist_check.stage import WhitelistCheckStage # 管道阶段顺序 STAGES_ORDER = [ diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 770bd65e8..6c6d10799 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -2,30 +2,30 @@ 本地 Agent 模式的 LLM 调用 Stage """ -import traceback import asyncio -import json import copy -from typing import Union, AsyncGenerator -from ...context import PipelineContext -from ..stage import Stage -from astrbot.core.platform.astr_message_event import AstrMessageEvent +import json +import traceback +from typing import AsyncGenerator, Union +from astrbot.core import logger +from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( + MessageChain, MessageEventResult, ResultContentType, - MessageChain, ) -from astrbot.core.message.components import Image -from astrbot.core import logger -from astrbot.core.utils.metrics import Metric +from astrbot.core.platform.astr_message_event import AstrMessageEvent +from astrbot.core.provider import Provider from astrbot.core.provider.entities import ( - ProviderRequest, LLMResponse, + ProviderRequest, ) from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.star_handler import EventType +from astrbot.core.utils.metrics import Metric +from ...context import PipelineContext from ..agent_runner.tool_loop_agent import ToolLoopAgent -from astrbot.core.provider import Provider +from ..stage import Stage class LLMRequestSubStage(Stage): @@ -74,13 +74,11 @@ class LLMRequestSubStage(Stage): logger.debug("未启用 LLM 能力,跳过处理。") return - # 检查会话级别的LLM启停状态 if not SessionServiceManager.should_process_llm_request(event): logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。") return - provider = self._select_provider(event) if provider is None: return @@ -167,7 +165,6 @@ class LLMRequestSubStage(Stage): if not req.session_id: req.session_id = event.unified_msg_origin - # fix messages req.contexts = self.fix_messages(req.contexts) @@ -184,9 +181,11 @@ class LLMRequestSubStage(Stage): while step_idx < self.max_step: # 在每次实际请求 LLM 前检查会话级别的启停状态,这可以防止插件或函数工具调用时绕过会话级别的限制 if not SessionServiceManager.should_process_llm_request(event): - logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,终止 LLM 请求。") + logger.debug( + f"会话 {event.unified_msg_origin} 禁用了 LLM,终止 LLM 请求。" + ) return - + step_idx += 1 try: async for resp in tool_loop_agent.step(): @@ -286,9 +285,11 @@ class LLMRequestSubStage(Stage): """处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title""" # 检查会话级别的LLM启停状态,防止标题生成功能绕过会话级别限制 if not SessionServiceManager.should_process_llm_request(event): - logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过 WebChat 标题生成。") + logger.debug( + f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过 WebChat 标题生成。" + ) return - + conversation = await self.conv_manager.get_conversation( event.unified_msg_origin, req.conversation.cid ) diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 6ec90fdb7..c9b8b4b8a 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -3,7 +3,7 @@ import time import traceback from typing import AsyncGenerator, Union -from astrbot.core import html_renderer, logger, file_token_service +from astrbot.core import file_token_service, html_renderer, logger from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply from astrbot.core.message.message_event_result import ResultContentType from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -177,7 +177,7 @@ class ResultDecorateStage(Stage): tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider( event.unified_msg_origin ) - + if ( self.ctx.astrbot_config["provider_tts_settings"]["enable"] and result.is_llm_result() diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index 3797751bf..2345b6466 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -1,14 +1,16 @@ -from ..stage import Stage, register_stage -from ..context import PipelineContext +from typing import AsyncGenerator, Union + from astrbot import logger -from typing import Union, AsyncGenerator -from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageEventResult, MessageChain from astrbot.core.message.components import At, AtAll, Reply -from astrbot.core.star.star_handler import star_handlers_registry, EventType -from astrbot.core.star.star import star_map +from astrbot.core.message.message_event_result import MessageChain, MessageEventResult +from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.filter.permission import PermissionTypeFilter from astrbot.core.star.session_plugin_manager import SessionPluginManager +from astrbot.core.star.star import star_map +from astrbot.core.star.star_handler import EventType, star_handlers_registry + +from ..context import PipelineContext +from ..stage import Stage, register_stage @register_stage @@ -168,7 +170,9 @@ class WakingCheckStage(Stage): event._extras.pop("parsed_params", None) # 根据会话配置过滤插件处理器 - activated_handlers = SessionPluginManager.filter_handlers_by_session(event, activated_handlers) + activated_handlers = SessionPluginManager.filter_handlers_by_session( + event, activated_handlers + ) event.set_extra("activated_handlers", activated_handlers) event.set_extra("handlers_parsed_params", handlers_parsed_params) diff --git a/astrbot/core/star/session_llm_manager.py b/astrbot/core/star/session_llm_manager.py index 6121cbfa1..0c78c5637 100644 --- a/astrbot/core/star/session_llm_manager.py +++ b/astrbot/core/star/session_llm_manager.py @@ -2,8 +2,9 @@ 会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态 """ -from astrbot.core import sp, logger from typing import Dict + +from astrbot.core import logger, sp from astrbot.core.platform.astr_message_event import AstrMessageEvent @@ -316,7 +317,7 @@ class SessionServiceManager: custom_name = SessionServiceManager.get_session_custom_name(session_id) if custom_name: return custom_name - + # 如果没有自定义名称,返回session_id的最后一段 return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py index d1f5f5631..852323d29 100644 --- a/astrbot/dashboard/routes/session_management.py +++ b/astrbot/dashboard/routes/session_management.py @@ -1,12 +1,15 @@ import traceback -from .route import Route, Response, RouteContext -from astrbot.core import logger, sp + from quart import request -from astrbot.core.db import BaseDatabase + +from astrbot.core import logger, sp from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db import BaseDatabase from astrbot.core.provider.entities import ProviderType -from astrbot.core.star.session_plugin_manager import SessionPluginManager from astrbot.core.star.session_llm_manager import SessionServiceManager +from astrbot.core.star.session_plugin_manager import SessionPluginManager + +from .route import Response, Route, RouteContext class SessionManagementRoute(Route): diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 889c2f327..06f6f8e60 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -1,20 +1,23 @@ -import logging -import jwt import asyncio +import logging import os import socket + +import jwt import psutil -from astrbot.core.config.default import VERSION -from quart import Quart, request, jsonify, g +from quart import Quart, g, jsonify, request from quart.logging import default_handler -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from .routes import * -from .routes.route import RouteContext, Response -from .routes.session_management import SessionManagementRoute + from astrbot.core import logger +from astrbot.core.config.default import VERSION +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.db import BaseDatabase -from astrbot.core.utils.io import get_local_ip_addresses from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.io import get_local_ip_addresses + +from .routes import * +from .routes.route import Response, RouteContext +from .routes.session_management import SessionManagementRoute APP: Quart = None @@ -36,7 +39,7 @@ class AstrBotDashboard: ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB self.app.json.sort_keys = False self.app.before_request(self.auth_middleware) - # token 用于验证请求 + # token 用于验证请求 logging.getLogger(self.app.name).removeHandler(default_handler) self.context = RouteContext(self.config, self.app) self.ur = UpdateRoute( @@ -54,7 +57,9 @@ class AstrBotDashboard: self.tools_root = ToolsRoute(self.context, core_lifecycle) self.conversation_route = ConversationRoute(self.context, db, core_lifecycle) self.file_route = FileRoute(self.context) - self.session_management_route = SessionManagementRoute(self.context, db, core_lifecycle) + self.session_management_route = SessionManagementRoute( + self.context, db, core_lifecycle + ) self.app.add_url_rule( "/api/plug/", diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 0f404320e..2be7887e8 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -338,31 +338,35 @@ class Main(star.Star): async def tts(self, event: AstrMessageEvent): """开关文本转语音(会话级别)""" from astrbot.core.star.session_llm_manager import SessionServiceManager - + session_id = event.unified_msg_origin current_status = SessionServiceManager.is_tts_enabled_for_session(session_id) - + # 切换状态 new_status = not current_status SessionServiceManager.set_tts_status_for_session(session_id, new_status) - + status_text = "已开启" if new_status else "已关闭" - event.set_result(MessageEventResult().message(f"{status_text}当前会话的文本转语音。")) + event.set_result( + MessageEventResult().message(f"{status_text}当前会话的文本转语音。") + ) @filter.command("mcp") async def mcp(self, event: AstrMessageEvent): """开关MCP工具调用(会话级别)""" from astrbot.core.star.session_llm_manager import SessionServiceManager - + session_id = event.unified_msg_origin current_status = SessionServiceManager.is_mcp_enabled_for_session(session_id) - + # 切换状态 new_status = not current_status SessionServiceManager.set_mcp_status_for_session(session_id, new_status) - + status_text = "已开启" if new_status else "已关闭" - event.set_result(MessageEventResult().message(f"{status_text}当前会话的MCP工具调用。")) + event.set_result( + MessageEventResult().message(f"{status_text}当前会话的MCP工具调用。") + ) @filter.command("sid") async def sid(self, event: AstrMessageEvent):