chore: format code

This commit is contained in:
Raven95676
2025-07-11 11:23:53 +08:00
parent 8ebf087dbf
commit 4dace7c5d8
8 changed files with 78 additions and 60 deletions
+6 -6
View File
@@ -1,18 +1,18 @@
from astrbot.core.message.message_event_result import (
MessageEventResult,
EventResultType,
MessageEventResult,
)
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .session_status_check.stage import SessionStatusCheckStage
from .rate_limit_check.stage import RateLimitStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .platform_compatibility.stage import PlatformCompatibilityStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .rate_limit_check.stage import RateLimitStage
from .respond.stage import RespondStage
from .result_decorate.stage import ResultDecorateStage
from .session_status_check.stage import SessionStatusCheckStage
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
# 管道阶段顺序
STAGES_ORDER = [
@@ -2,30 +2,30 @@
本地 Agent 模式的 LLM 调用 Stage
"""
import traceback
import asyncio
import json
import copy
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
import json
import traceback
from typing import AsyncGenerator, Union
from astrbot.core import logger
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
MessageChain,
)
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import (
ProviderRequest,
LLMResponse,
ProviderRequest,
)
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from ...context import PipelineContext
from ..agent_runner.tool_loop_agent import ToolLoopAgent
from astrbot.core.provider import Provider
from ..stage import Stage
class LLMRequestSubStage(Stage):
@@ -74,13 +74,11 @@ class LLMRequestSubStage(Stage):
logger.debug("未启用 LLM 能力,跳过处理。")
return
# 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
return
provider = self._select_provider(event)
if provider is None:
return
@@ -167,7 +165,6 @@ class LLMRequestSubStage(Stage):
if not req.session_id:
req.session_id = event.unified_msg_origin
# fix messages
req.contexts = self.fix_messages(req.contexts)
@@ -184,9 +181,11 @@ class LLMRequestSubStage(Stage):
while step_idx < self.max_step:
# 在每次实际请求 LLM 前检查会话级别的启停状态,这可以防止插件或函数工具调用时绕过会话级别的限制
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,终止 LLM 请求。")
logger.debug(
f"会话 {event.unified_msg_origin} 禁用了 LLM,终止 LLM 请求。"
)
return
step_idx += 1
try:
async for resp in tool_loop_agent.step():
@@ -286,9 +285,11 @@ class LLMRequestSubStage(Stage):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
# 检查会话级别的LLM启停状态,防止标题生成功能绕过会话级别限制
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过 WebChat 标题生成。")
logger.debug(
f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过 WebChat 标题生成。"
)
return
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid
)
@@ -3,7 +3,7 @@ import time
import traceback
from typing import AsyncGenerator, Union
from astrbot.core import html_renderer, logger, file_token_service
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -177,7 +177,7 @@ class ResultDecorateStage(Stage):
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
event.unified_msg_origin
)
if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
and result.is_llm_result()
+12 -8
View File
@@ -1,14 +1,16 @@
from ..stage import Stage, register_stage
from ..context import PipelineContext
from typing import AsyncGenerator, Union
from astrbot import logger
from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from ..context import PipelineContext
from ..stage import Stage, register_stage
@register_stage
@@ -168,7 +170,9 @@ class WakingCheckStage(Stage):
event._extras.pop("parsed_params", None)
# 根据会话配置过滤插件处理器
activated_handlers = SessionPluginManager.filter_handlers_by_session(event, activated_handlers)
activated_handlers = SessionPluginManager.filter_handlers_by_session(
event, activated_handlers
)
event.set_extra("activated_handlers", activated_handlers)
event.set_extra("handlers_parsed_params", handlers_parsed_params)
+3 -2
View File
@@ -2,8 +2,9 @@
会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态
"""
from astrbot.core import sp, logger
from typing import Dict
from astrbot.core import logger, sp
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -316,7 +317,7 @@ class SessionServiceManager:
custom_name = SessionServiceManager.get_session_custom_name(session_id)
if custom_name:
return custom_name
# 如果没有自定义名称,返回session_id的最后一段
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
@@ -1,12 +1,15 @@
import traceback
from .route import Route, Response, RouteContext
from astrbot.core import logger, sp
from quart import request
from astrbot.core.db import BaseDatabase
from astrbot.core import logger, sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.provider.entities import ProviderType
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.session_plugin_manager import SessionPluginManager
from .route import Response, Route, RouteContext
class SessionManagementRoute(Route):
+16 -11
View File
@@ -1,20 +1,23 @@
import logging
import jwt
import asyncio
import logging
import os
import socket
import jwt
import psutil
from astrbot.core.config.default import VERSION
from quart import Quart, request, jsonify, g
from quart import Quart, g, jsonify, request
from quart.logging import default_handler
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .routes import *
from .routes.route import RouteContext, Response
from .routes.session_management import SessionManagementRoute
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.io import get_local_ip_addresses
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import get_local_ip_addresses
from .routes import *
from .routes.route import Response, RouteContext
from .routes.session_management import SessionManagementRoute
APP: Quart = None
@@ -36,7 +39,7 @@ class AstrBotDashboard:
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
self.app.json.sort_keys = False
self.app.before_request(self.auth_middleware)
# token 用于验证请求
# token 用于验证请求
logging.getLogger(self.app.name).removeHandler(default_handler)
self.context = RouteContext(self.config, self.app)
self.ur = UpdateRoute(
@@ -54,7 +57,9 @@ class AstrBotDashboard:
self.tools_root = ToolsRoute(self.context, core_lifecycle)
self.conversation_route = ConversationRoute(self.context, db, core_lifecycle)
self.file_route = FileRoute(self.context)
self.session_management_route = SessionManagementRoute(self.context, db, core_lifecycle)
self.session_management_route = SessionManagementRoute(
self.context, db, core_lifecycle
)
self.app.add_url_rule(
"/api/plug/<path:subpath>",
+12 -8
View File
@@ -338,31 +338,35 @@ class Main(star.Star):
async def tts(self, event: AstrMessageEvent):
"""开关文本转语音(会话级别)"""
from astrbot.core.star.session_llm_manager import SessionServiceManager
session_id = event.unified_msg_origin
current_status = SessionServiceManager.is_tts_enabled_for_session(session_id)
# 切换状态
new_status = not current_status
SessionServiceManager.set_tts_status_for_session(session_id, new_status)
status_text = "已开启" if new_status else "已关闭"
event.set_result(MessageEventResult().message(f"{status_text}当前会话的文本转语音。"))
event.set_result(
MessageEventResult().message(f"{status_text}当前会话的文本转语音。")
)
@filter.command("mcp")
async def mcp(self, event: AstrMessageEvent):
"""开关MCP工具调用(会话级别)"""
from astrbot.core.star.session_llm_manager import SessionServiceManager
session_id = event.unified_msg_origin
current_status = SessionServiceManager.is_mcp_enabled_for_session(session_id)
# 切换状态
new_status = not current_status
SessionServiceManager.set_mcp_status_for_session(session_id, new_status)
status_text = "已开启" if new_status else "已关闭"
event.set_result(MessageEventResult().message(f"{status_text}当前会话的MCP工具调用。"))
event.set_result(
MessageEventResult().message(f"{status_text}当前会话的MCP工具调用。")
)
@filter.command("sid")
async def sid(self, event: AstrMessageEvent):