style: code quality
This commit is contained in:
@@ -3,67 +3,69 @@
|
||||
"""
|
||||
|
||||
from astrbot.core import sp, logger
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
class SessionServiceManager:
|
||||
"""管理会话级别的服务启停状态,包括LLM和TTS"""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# LLM 相关方法
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@staticmethod
|
||||
def is_llm_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查LLM是否在指定会话中启用
|
||||
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {})
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
|
||||
|
||||
# 如果配置了该会话的LLM状态,返回该状态
|
||||
llm_enabled = session_services.get("llm_enabled")
|
||||
if llm_enabled is not None:
|
||||
return llm_enabled
|
||||
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
|
||||
@staticmethod
|
||||
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置LLM在指定会话中的启停状态
|
||||
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {})
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
|
||||
# 设置LLM状态
|
||||
session_config[session_id]["llm_enabled"] = enabled
|
||||
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理LLM请求
|
||||
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
@@ -73,57 +75,59 @@ class SessionServiceManager:
|
||||
# =============================================================================
|
||||
# TTS 相关方法
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@staticmethod
|
||||
def is_tts_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查TTS是否在指定会话中启用
|
||||
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {})
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
|
||||
|
||||
# 如果配置了该会话的TTS状态,返回该状态
|
||||
tts_enabled = session_services.get("tts_enabled")
|
||||
if tts_enabled is not None:
|
||||
return tts_enabled
|
||||
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
|
||||
@staticmethod
|
||||
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置TTS在指定会话中的启停状态
|
||||
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {})
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
|
||||
# 设置TTS状态
|
||||
session_config[session_id]["tts_enabled"] = enabled
|
||||
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的TTS状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_tts_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理TTS请求
|
||||
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
@@ -133,57 +137,59 @@ class SessionServiceManager:
|
||||
# =============================================================================
|
||||
# MCP 相关方法
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@staticmethod
|
||||
def is_mcp_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查MCP是否在指定会话中启用
|
||||
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_config = sp.get("session_service_config", {})
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
session_services = session_config.get(session_id, {})
|
||||
|
||||
|
||||
# 如果配置了该会话的MCP状态,返回该状态
|
||||
mcp_enabled = session_services.get("mcp_enabled")
|
||||
if mcp_enabled is not None:
|
||||
return mcp_enabled
|
||||
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
|
||||
@staticmethod
|
||||
def set_mcp_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置MCP在指定会话中的启停状态
|
||||
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_config = sp.get("session_service_config", {})
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
if session_id not in session_config:
|
||||
session_config[session_id] = {}
|
||||
|
||||
|
||||
# 设置MCP状态
|
||||
session_config[session_id]["mcp_enabled"] = enabled
|
||||
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_service_config", session_config)
|
||||
|
||||
logger.info(f"会话 {session_id} 的MCP状态已更新为: {'启用' if enabled else '禁用'}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的MCP状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_mcp_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理MCP请求
|
||||
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
@@ -193,29 +199,32 @@ class SessionServiceManager:
|
||||
# =============================================================================
|
||||
# 通用配置方法
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_session_service_config(session_id: str) -> Dict[str, bool]:
|
||||
"""获取指定会话的服务配置
|
||||
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: 包含llm_enabled、tts_enabled、mcp_enabled的字典
|
||||
"""
|
||||
session_config = sp.get("session_service_config", {})
|
||||
return session_config.get(session_id, {
|
||||
"llm_enabled": True, # 默认启用
|
||||
"tts_enabled": True, # 默认启用
|
||||
"mcp_enabled": True # 默认启用
|
||||
})
|
||||
|
||||
session_config = sp.get("session_service_config", {}) or {}
|
||||
return session_config.get(
|
||||
session_id,
|
||||
{
|
||||
"llm_enabled": True, # 默认启用
|
||||
"tts_enabled": True, # 默认启用
|
||||
"mcp_enabled": True, # 默认启用
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_all_session_configs() -> Dict[str, Dict[str, bool]]:
|
||||
"""获取所有会话的服务配置
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Dict[str, bool]]: 所有会话的服务配置
|
||||
"""
|
||||
return sp.get("session_service_config", {})
|
||||
return sp.get("session_service_config", {}) or {}
|
||||
|
||||
@@ -3,63 +3,65 @@
|
||||
"""
|
||||
|
||||
from astrbot.core import sp, logger
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
class SessionPluginManager:
|
||||
"""管理会话级别的插件启停状态"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
|
||||
"""检查插件是否在指定会话中启用
|
||||
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
plugin_name: 插件名称
|
||||
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话插件配置
|
||||
session_plugin_config = sp.get("session_plugin_config", {})
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
session_config = session_plugin_config.get(session_id, {})
|
||||
|
||||
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
|
||||
# 如果插件在禁用列表中,返回False
|
||||
if plugin_name in disabled_plugins:
|
||||
return False
|
||||
|
||||
|
||||
# 如果插件在启用列表中,返回True
|
||||
if plugin_name in enabled_plugins:
|
||||
return True
|
||||
|
||||
|
||||
# 如果都没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
|
||||
@staticmethod
|
||||
def set_plugin_status_for_session(session_id: str, plugin_name: str, enabled: bool) -> None:
|
||||
def set_plugin_status_for_session(
|
||||
session_id: str, plugin_name: str, enabled: bool
|
||||
) -> None:
|
||||
"""设置插件在指定会话中的启停状态
|
||||
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
plugin_name: 插件名称
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_plugin_config = sp.get("session_plugin_config", {})
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
if session_id not in session_plugin_config:
|
||||
session_plugin_config[session_id] = {
|
||||
"enabled_plugins": [],
|
||||
"disabled_plugins": []
|
||||
"disabled_plugins": [],
|
||||
}
|
||||
|
||||
|
||||
session_config = session_plugin_config[session_id]
|
||||
enabled_plugins = session_config.get("enabled_plugins", [])
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
|
||||
if enabled:
|
||||
# 启用插件
|
||||
if plugin_name in disabled_plugins:
|
||||
@@ -72,47 +74,48 @@ class SessionPluginManager:
|
||||
enabled_plugins.remove(plugin_name)
|
||||
if plugin_name not in disabled_plugins:
|
||||
disabled_plugins.append(plugin_name)
|
||||
|
||||
|
||||
# 保存配置
|
||||
session_config["enabled_plugins"] = enabled_plugins
|
||||
session_config["disabled_plugins"] = disabled_plugins
|
||||
session_plugin_config[session_id] = session_config
|
||||
sp.put("session_plugin_config", session_plugin_config)
|
||||
|
||||
logger.info(f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_session_plugin_config(session_id: str) -> Dict[str, List[str]]:
|
||||
"""获取指定会话的插件配置
|
||||
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
|
||||
"""
|
||||
session_plugin_config = sp.get("session_plugin_config", {})
|
||||
return session_plugin_config.get(session_id, {
|
||||
"enabled_plugins": [],
|
||||
"disabled_plugins": []
|
||||
})
|
||||
|
||||
session_plugin_config = sp.get("session_plugin_config", {}) or {}
|
||||
return session_plugin_config.get(
|
||||
session_id, {"enabled_plugins": [], "disabled_plugins": []}
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def filter_handlers_by_session(event: AstrMessageEvent, handlers: List) -> List:
|
||||
"""根据会话配置过滤处理器列表
|
||||
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
handlers: 原始处理器列表
|
||||
|
||||
|
||||
Returns:
|
||||
List: 过滤后的处理器列表
|
||||
"""
|
||||
from astrbot.core.star.star import star_map
|
||||
|
||||
|
||||
session_id = event.unified_msg_origin
|
||||
filtered_handlers = []
|
||||
|
||||
|
||||
for handler in handlers:
|
||||
# 获取处理器对应的插件
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
@@ -120,16 +123,20 @@ class SessionPluginManager:
|
||||
# 如果找不到插件元数据,允许执行(可能是系统插件)
|
||||
filtered_handlers.append(handler)
|
||||
continue
|
||||
|
||||
|
||||
# 跳过保留插件(系统插件)
|
||||
if plugin.reserved:
|
||||
filtered_handlers.append(handler)
|
||||
continue
|
||||
|
||||
|
||||
# 检查插件是否在当前会话中启用
|
||||
if SessionPluginManager.is_plugin_enabled_for_session(session_id, plugin.name):
|
||||
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||
session_id, plugin.name
|
||||
):
|
||||
filtered_handlers.append(handler)
|
||||
else:
|
||||
logger.debug(f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}")
|
||||
|
||||
logger.debug(
|
||||
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}"
|
||||
)
|
||||
|
||||
return filtered_handlers
|
||||
|
||||
@@ -45,7 +45,8 @@ const MainRoutes = {
|
||||
name: 'Default',
|
||||
path: '/dashboard/default',
|
||||
component: () => import('@/views/dashboards/default/DefaultDashboard.vue')
|
||||
}, {
|
||||
},
|
||||
{
|
||||
name: 'Conversation',
|
||||
path: '/conversation',
|
||||
component: () => import('@/views/ConversationPage.vue')
|
||||
|
||||
@@ -173,11 +173,6 @@
|
||||
color="primary"
|
||||
inset
|
||||
>
|
||||
<template v-slot:label>
|
||||
<span class="text-caption">
|
||||
{{ item.llm_enabled ? tm('status.enabled') : tm('status.disabled') }}
|
||||
</span>
|
||||
</template>
|
||||
</v-switch>
|
||||
</template>
|
||||
|
||||
@@ -192,11 +187,6 @@
|
||||
color="secondary"
|
||||
inset
|
||||
>
|
||||
<template v-slot:label>
|
||||
<span class="text-caption">
|
||||
{{ item.tts_enabled ? tm('status.enabled') : tm('status.disabled') }}
|
||||
</span>
|
||||
</template>
|
||||
</v-switch>
|
||||
</template>
|
||||
|
||||
@@ -211,11 +201,6 @@
|
||||
color="info"
|
||||
inset
|
||||
>
|
||||
<template v-slot:label>
|
||||
<span class="text-caption">
|
||||
{{ item.mcp_enabled ? tm('status.enabled') : tm('status.disabled') }}
|
||||
</span>
|
||||
</template>
|
||||
</v-switch>
|
||||
</template>
|
||||
|
||||
|
||||
Reference in New Issue
Block a user