chore: format code
This commit is contained in:
@@ -1,18 +1,18 @@
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
EventResultType,
|
||||
MessageEventResult,
|
||||
)
|
||||
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .session_status_check.stage import SessionStatusCheckStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .respond.stage import RespondStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .session_status_check.stage import SessionStatusCheckStage
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
|
||||
# 管道阶段顺序
|
||||
STAGES_ORDER = [
|
||||
|
||||
@@ -2,30 +2,30 @@
|
||||
本地 Agent 模式的 LLM 调用 Stage
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import asyncio
|
||||
import json
|
||||
import copy
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
import json
|
||||
import traceback
|
||||
from typing import AsyncGenerator, Union
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
MessageChain,
|
||||
)
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
ProviderRequest,
|
||||
)
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from ...context import PipelineContext
|
||||
from ..agent_runner.tool_loop_agent import ToolLoopAgent
|
||||
from astrbot.core.provider import Provider
|
||||
from ..stage import Stage
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -74,13 +74,11 @@ class LLMRequestSubStage(Stage):
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
@@ -167,7 +165,6 @@ class LLMRequestSubStage(Stage):
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
|
||||
# fix messages
|
||||
req.contexts = self.fix_messages(req.contexts)
|
||||
|
||||
@@ -184,9 +181,11 @@ class LLMRequestSubStage(Stage):
|
||||
while step_idx < self.max_step:
|
||||
# 在每次实际请求 LLM 前检查会话级别的启停状态,这可以防止插件或函数工具调用时绕过会话级别的限制
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,终止 LLM 请求。")
|
||||
logger.debug(
|
||||
f"会话 {event.unified_msg_origin} 禁用了 LLM,终止 LLM 请求。"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
step_idx += 1
|
||||
try:
|
||||
async for resp in tool_loop_agent.step():
|
||||
@@ -286,9 +285,11 @@ class LLMRequestSubStage(Stage):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
# 检查会话级别的LLM启停状态,防止标题生成功能绕过会话级别限制
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过 WebChat 标题生成。")
|
||||
logger.debug(
|
||||
f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过 WebChat 标题生成。"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ import time
|
||||
import traceback
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot.core import html_renderer, logger, file_token_service
|
||||
from astrbot.core import file_token_service, html_renderer, logger
|
||||
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -177,7 +177,7 @@ class ResultDecorateStage(Stage):
|
||||
tts_provider = self.ctx.plugin_manager.context.get_using_tts_provider(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
|
||||
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
and result.is_llm_result()
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from astrbot import logger
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
|
||||
|
||||
@register_stage
|
||||
@@ -168,7 +170,9 @@ class WakingCheckStage(Stage):
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
# 根据会话配置过滤插件处理器
|
||||
activated_handlers = SessionPluginManager.filter_handlers_by_session(event, activated_handlers)
|
||||
activated_handlers = SessionPluginManager.filter_handlers_by_session(
|
||||
event, activated_handlers
|
||||
)
|
||||
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
event.set_extra("handlers_parsed_params", handlers_parsed_params)
|
||||
|
||||
@@ -2,8 +2,9 @@
|
||||
会话服务管理器 - 负责管理每个会话的LLM、TTS等服务的启停状态
|
||||
"""
|
||||
|
||||
from astrbot.core import sp, logger
|
||||
from typing import Dict
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
@@ -316,7 +317,7 @@ class SessionServiceManager:
|
||||
custom_name = SessionServiceManager.get_session_custom_name(session_id)
|
||||
if custom_name:
|
||||
return custom_name
|
||||
|
||||
|
||||
# 如果没有自定义名称,返回session_id的最后一段
|
||||
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id
|
||||
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger, sp
|
||||
|
||||
from quart import request
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
class SessionManagementRoute(Route):
|
||||
|
||||
+16
-11
@@ -1,20 +1,23 @@
|
||||
import logging
|
||||
import jwt
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
|
||||
import jwt
|
||||
import psutil
|
||||
from astrbot.core.config.default import VERSION
|
||||
from quart import Quart, request, jsonify, g
|
||||
from quart import Quart, g, jsonify, request
|
||||
from quart.logging import default_handler
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .routes import *
|
||||
from .routes.route import RouteContext, Response
|
||||
from .routes.session_management import SessionManagementRoute
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
|
||||
from .routes import *
|
||||
from .routes.route import Response, RouteContext
|
||||
from .routes.session_management import SessionManagementRoute
|
||||
|
||||
APP: Quart = None
|
||||
|
||||
@@ -36,7 +39,7 @@ class AstrBotDashboard:
|
||||
) # 将 Flask 允许的最大上传文件体大小设置为 128 MB
|
||||
self.app.json.sort_keys = False
|
||||
self.app.before_request(self.auth_middleware)
|
||||
# token 用于验证请求
|
||||
# token 用于验证请求
|
||||
logging.getLogger(self.app.name).removeHandler(default_handler)
|
||||
self.context = RouteContext(self.config, self.app)
|
||||
self.ur = UpdateRoute(
|
||||
@@ -54,7 +57,9 @@ class AstrBotDashboard:
|
||||
self.tools_root = ToolsRoute(self.context, core_lifecycle)
|
||||
self.conversation_route = ConversationRoute(self.context, db, core_lifecycle)
|
||||
self.file_route = FileRoute(self.context)
|
||||
self.session_management_route = SessionManagementRoute(self.context, db, core_lifecycle)
|
||||
self.session_management_route = SessionManagementRoute(
|
||||
self.context, db, core_lifecycle
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
|
||||
@@ -338,31 +338,35 @@ class Main(star.Star):
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
"""开关文本转语音(会话级别)"""
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
|
||||
|
||||
session_id = event.unified_msg_origin
|
||||
current_status = SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
|
||||
|
||||
# 切换状态
|
||||
new_status = not current_status
|
||||
SessionServiceManager.set_tts_status_for_session(session_id, new_status)
|
||||
|
||||
|
||||
status_text = "已开启" if new_status else "已关闭"
|
||||
event.set_result(MessageEventResult().message(f"{status_text}当前会话的文本转语音。"))
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"{status_text}当前会话的文本转语音。")
|
||||
)
|
||||
|
||||
@filter.command("mcp")
|
||||
async def mcp(self, event: AstrMessageEvent):
|
||||
"""开关MCP工具调用(会话级别)"""
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
|
||||
|
||||
session_id = event.unified_msg_origin
|
||||
current_status = SessionServiceManager.is_mcp_enabled_for_session(session_id)
|
||||
|
||||
|
||||
# 切换状态
|
||||
new_status = not current_status
|
||||
SessionServiceManager.set_mcp_status_for_session(session_id, new_status)
|
||||
|
||||
|
||||
status_text = "已开启" if new_status else "已关闭"
|
||||
event.set_result(MessageEventResult().message(f"{status_text}当前会话的MCP工具调用。"))
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"{status_text}当前会话的MCP工具调用。")
|
||||
)
|
||||
|
||||
@filter.command("sid")
|
||||
async def sid(self, event: AstrMessageEvent):
|
||||
|
||||
Reference in New Issue
Block a user