diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 528a8cab8..e229038d7 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -303,6 +303,7 @@ class WecomPlatformAdapter(Platform): abm.session_id = external_userid abm.type = MessageType.FRIEND_MESSAGE abm.message_id = msg.get("msgid", uuid.uuid4().hex[:8]) + abm.message_str = "" if msgtype == "text": text = msg.get("text", {}).get("content", "").strip() abm.message = [Plain(text=text)] @@ -316,7 +317,29 @@ class WecomPlatformAdapter(Platform): with open(path, "wb") as f: f.write(resp.content) abm.message = [Image(file=path, url=path)] - abm.message_str = "[图片]" + elif msgtype == "voice": + media_id = msg.get("voice", {}).get("media_id", "") + resp: Response = await asyncio.get_event_loop().run_in_executor( + None, self.client.media.download, media_id + ) + + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr") + with open(path, "wb") as f: + f.write(resp.content) + + try: + from pydub import AudioSegment + + path_wav = os.path.join(temp_dir, f"weixinkefu_{media_id}.wav") + audio = AudioSegment.from_file(path) + audio.export(path_wav, format="wav") + except Exception as e: + logger.error(f"转换音频失败: {e}。如果没有安装 ffmpeg 请先安装。") + path_wav = path + return + + abm.message = [Record(file=path_wav, url=path_wav)] else: logger.warning(f"未实现的微信客服消息事件: {msg}") return diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 1c1c09c91..e8078a9ac 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -120,6 +120,30 @@ class WecomPlatformEvent(AstrMessageEvent): self.get_self_id(), response["media_id"], ) + elif isinstance(comp, Record): + record_path = await comp.convert_to_file_path() + # 转成amr + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr") + pydub.AudioSegment.from_wav(record_path).export( + record_path_amr, format="amr" + ) + + with open(record_path_amr, "rb") as f: + try: + response = self.client.media.upload("voice", f) + except Exception as e: + logger.error(f"微信客服上传语音失败: {e}") + await self.send( + MessageChain().message(f"微信客服上传语音失败: {e}") + ) + return + logger.info(f"微信客服上传语音返回: {response}") + kf_message_api.send_voice( + user_id, + self.get_self_id(), + response["media_id"], + ) else: logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。") else: diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index f9309c3eb..8d08b9d53 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -9,6 +9,7 @@ from .chat import ChatRoute from .tools import ToolsRoute # 导入新的ToolsRoute from .conversation import ConversationRoute from .file import FileRoute +from .session_management import SessionManagementRoute __all__ = [ @@ -23,4 +24,5 @@ __all__ = [ "ToolsRoute", "ConversationRoute", "FileRoute", + "SessionManagementRoute", ] diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py new file mode 100644 index 000000000..dd7da8add --- /dev/null +++ b/astrbot/dashboard/routes/session_management.py @@ -0,0 +1,361 @@ +import traceback +from .route import Route, Response, RouteContext +from astrbot.core import logger, sp +from quart import request +from astrbot.core.db import BaseDatabase +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.provider.entities import ProviderType + + +class SessionManagementRoute(Route): + def __init__( + self, + context: RouteContext, + db_helper: BaseDatabase, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: + super().__init__(context) + self.routes = { + "/session/list": ("GET", self.list_sessions), + "/session/update_persona": ("POST", self.update_session_persona), + "/session/update_provider": ("POST", self.update_session_provider), + "/session/get_session_info": ("POST", self.get_session_info), + } + self.db_helper = db_helper + self.core_lifecycle = core_lifecycle + self.register_routes() + + async def list_sessions(self): + """获取所有会话的列表,包括 persona 和 provider 信息""" + try: + # 获取所有会话的对话信息 + conversations = self.db_helper.get_all_conversations() + + # 获取会话对话映射 + session_conversations = sp.get("session_conversation", {}) + + # 获取会话提供商偏好设置 + session_provider_perf = sp.get("session_provider_perf", {}) + + # 获取可用的 personas + personas = self.core_lifecycle.star_context.provider_manager.personas + + # 获取可用的 providers + provider_manager = self.core_lifecycle.star_context.provider_manager + + sessions = [] + + # 构建会话信息 + for session_id, conversation_id in session_conversations.items(): + session_info = { + "session_id": session_id, + "conversation_id": conversation_id, + "persona_id": None, + "persona_name": None, + "chat_provider_id": None, + "chat_provider_name": None, + "stt_provider_id": None, + "stt_provider_name": None, + "tts_provider_id": None, + "tts_provider_name": None, + "platform": session_id.split(":")[0] if ":" in session_id else "unknown", + "message_type": session_id.split(":")[1] if session_id.count(":") >= 1 else "unknown", + "session_name": session_id.split(":")[2] if session_id.count(":") >= 2 else session_id, + } + + # 获取对话信息 + conversation = self.db_helper.get_conversation_by_user_id(session_id, conversation_id) + if conversation: + session_info["persona_id"] = conversation.persona_id + # 查找 persona 名称 + if conversation.persona_id and conversation.persona_id != "[%None]": + for persona in personas: + if persona["name"] == conversation.persona_id: + session_info["persona_name"] = persona["name"] + break + elif conversation.persona_id == "[%None]": + session_info["persona_name"] = "无人格" + else: + # 使用默认人格 + default_persona = provider_manager.selected_default_persona + if default_persona: + session_info["persona_id"] = default_persona["name"] + session_info["persona_name"] = default_persona["name"] + + # 获取会话的 provider 偏好设置 + session_perf = session_provider_perf.get(session_id, {}) + + # Chat completion provider + chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value) + if chat_provider_id: + chat_provider = provider_manager.inst_map.get(chat_provider_id) + if chat_provider: + session_info["chat_provider_id"] = chat_provider_id + session_info["chat_provider_name"] = chat_provider.meta().id + else: + # 使用默认 provider + default_provider = provider_manager.curr_provider_inst + if default_provider: + session_info["chat_provider_id"] = default_provider.meta().id + session_info["chat_provider_name"] = default_provider.meta().id + + # STT provider + stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value) + if stt_provider_id: + stt_provider = provider_manager.inst_map.get(stt_provider_id) + if stt_provider: + session_info["stt_provider_id"] = stt_provider_id + session_info["stt_provider_name"] = stt_provider.meta().id + else: + # 使用默认 STT provider + default_stt_provider = provider_manager.curr_stt_provider_inst + if default_stt_provider: + session_info["stt_provider_id"] = default_stt_provider.meta().id + session_info["stt_provider_name"] = default_stt_provider.meta().id + + # TTS provider + tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value) + if tts_provider_id: + tts_provider = provider_manager.inst_map.get(tts_provider_id) + if tts_provider: + session_info["tts_provider_id"] = tts_provider_id + session_info["tts_provider_name"] = tts_provider.meta().id + else: + # 使用默认 TTS provider + default_tts_provider = provider_manager.curr_tts_provider_inst + if default_tts_provider: + session_info["tts_provider_id"] = default_tts_provider.meta().id + session_info["tts_provider_name"] = default_tts_provider.meta().id + + sessions.append(session_info) + + # 获取可用的 personas 和 providers 列表 + available_personas = [{"name": p["name"], "prompt": p.get("prompt", "")} for p in personas] + + available_chat_providers = [] + for provider in provider_manager.provider_insts: + meta = provider.meta() + available_chat_providers.append({ + "id": meta.id, + "name": meta.id, + "model": meta.model, + "type": meta.type, + }) + + available_stt_providers = [] + for provider in provider_manager.stt_provider_insts: + meta = provider.meta() + available_stt_providers.append({ + "id": meta.id, + "name": meta.id, + "model": meta.model, + "type": meta.type, + }) + + available_tts_providers = [] + for provider in provider_manager.tts_provider_insts: + meta = provider.meta() + available_tts_providers.append({ + "id": meta.id, + "name": meta.id, + "model": meta.model, + "type": meta.type, + }) + + result = { + "sessions": sessions, + "available_personas": available_personas, + "available_chat_providers": available_chat_providers, + "available_stt_providers": available_stt_providers, + "available_tts_providers": available_tts_providers, + } + + return Response().ok(result).__dict__ + + except Exception as e: + error_msg = f"获取会话列表失败: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + return Response().error(f"获取会话列表失败: {str(e)}").__dict__ + + async def update_session_persona(self): + """更新指定会话的 persona""" + try: + data = await request.get_json() + session_id = data.get("session_id") + persona_name = data.get("persona_name") + + if not session_id: + return Response().error("缺少必要参数: session_id").__dict__ + + if persona_name is None: + return Response().error("缺少必要参数: persona_name").__dict__ + + # 获取会话当前的对话 ID + conversation_manager = self.core_lifecycle.star_context.conversation_manager + conversation_id = await conversation_manager.get_curr_conversation_id(session_id) + + if not conversation_id: + # 如果没有对话,创建一个新的对话 + conversation_id = await conversation_manager.new_conversation(session_id) + + # 更新 persona + await conversation_manager.update_conversation_persona_id(session_id, persona_name) + + return Response().ok({"message": f"成功更新会话 {session_id} 的人格为 {persona_name}"}).__dict__ + + except Exception as e: + error_msg = f"更新会话人格失败: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + return Response().error(f"更新会话人格失败: {str(e)}").__dict__ + + async def update_session_provider(self): + """更新指定会话的 provider""" + try: + data = await request.get_json() + session_id = data.get("session_id") + provider_id = data.get("provider_id") + provider_type = data.get("provider_type") # "chat_completion", "speech_to_text", "text_to_speech" + + if not session_id or not provider_id or not provider_type: + return Response().error("缺少必要参数: session_id, provider_id, provider_type").__dict__ + + # 转换 provider_type 字符串为枚举 + try: + if provider_type == "chat_completion": + provider_type_enum = ProviderType.CHAT_COMPLETION + elif provider_type == "speech_to_text": + provider_type_enum = ProviderType.SPEECH_TO_TEXT + elif provider_type == "text_to_speech": + provider_type_enum = ProviderType.TEXT_TO_SPEECH + else: + return Response().error(f"不支持的 provider_type: {provider_type}").__dict__ + except Exception as e: + return Response().error(f"无效的 provider_type: {provider_type}").__dict__ + + # 设置 provider + provider_manager = self.core_lifecycle.star_context.provider_manager + await provider_manager.set_provider( + provider_id=provider_id, + provider_type=provider_type_enum, + umo=session_id, + ) + + return Response().ok({ + "message": f"成功更新会话 {session_id} 的 {provider_type} 提供商为 {provider_id}" + }).__dict__ + + except Exception as e: + error_msg = f"更新会话提供商失败: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + return Response().error(f"更新会话提供商失败: {str(e)}").__dict__ + + async def get_session_info(self): + """获取指定会话的详细信息""" + try: + data = await request.get_json() + session_id = data.get("session_id") + + if not session_id: + return Response().error("缺少必要参数: session_id").__dict__ + + # 获取会话对话信息 + session_conversations = sp.get("session_conversation", {}) + conversation_id = session_conversations.get(session_id) + + if not conversation_id: + return Response().error(f"会话 {session_id} 未找到对话").__dict__ + + session_info = { + "session_id": session_id, + "conversation_id": conversation_id, + "persona_id": None, + "persona_name": None, + "chat_provider_id": None, + "chat_provider_name": None, + "stt_provider_id": None, + "stt_provider_name": None, + "tts_provider_id": None, + "tts_provider_name": None, + "platform": session_id.split(":")[0] if ":" in session_id else "unknown", + "message_type": session_id.split(":")[1] if session_id.count(":") >= 1 else "unknown", + "session_name": session_id.split(":")[2] if session_id.count(":") >= 2 else session_id, + } + + # 获取对话信息 + conversation = self.db_helper.get_conversation_by_user_id(session_id, conversation_id) + if conversation: + session_info["persona_id"] = conversation.persona_id + + # 查找 persona 名称 + provider_manager = self.core_lifecycle.star_context.provider_manager + personas = provider_manager.personas + + if conversation.persona_id and conversation.persona_id != "[%None]": + for persona in personas: + if persona["name"] == conversation.persona_id: + session_info["persona_name"] = persona["name"] + break + elif conversation.persona_id == "[%None]": + session_info["persona_name"] = "无人格" + else: + # 使用默认人格 + default_persona = provider_manager.selected_default_persona + if default_persona: + session_info["persona_id"] = default_persona["name"] + session_info["persona_name"] = default_persona["name"] + + # 获取会话的 provider 偏好设置 + session_provider_perf = sp.get("session_provider_perf", {}) + session_perf = session_provider_perf.get(session_id, {}) + + # 获取 provider 信息 + provider_manager = self.core_lifecycle.star_context.provider_manager + + # Chat completion provider + chat_provider_id = session_perf.get(ProviderType.CHAT_COMPLETION.value) + if chat_provider_id: + chat_provider = provider_manager.inst_map.get(chat_provider_id) + if chat_provider: + session_info["chat_provider_id"] = chat_provider_id + session_info["chat_provider_name"] = chat_provider.meta().id + else: + # 使用默认 provider + default_provider = provider_manager.curr_provider_inst + if default_provider: + session_info["chat_provider_id"] = default_provider.meta().id + session_info["chat_provider_name"] = default_provider.meta().id + + # STT provider + stt_provider_id = session_perf.get(ProviderType.SPEECH_TO_TEXT.value) + if stt_provider_id: + stt_provider = provider_manager.inst_map.get(stt_provider_id) + if stt_provider: + session_info["stt_provider_id"] = stt_provider_id + session_info["stt_provider_name"] = stt_provider.meta().id + else: + # 使用默认 STT provider + default_stt_provider = provider_manager.curr_stt_provider_inst + if default_stt_provider: + session_info["stt_provider_id"] = default_stt_provider.meta().id + session_info["stt_provider_name"] = default_stt_provider.meta().id + + # TTS provider + tts_provider_id = session_perf.get(ProviderType.TEXT_TO_SPEECH.value) + if tts_provider_id: + tts_provider = provider_manager.inst_map.get(tts_provider_id) + if tts_provider: + session_info["tts_provider_id"] = tts_provider_id + session_info["tts_provider_name"] = tts_provider.meta().id + else: + # 使用默认 TTS provider + default_tts_provider = provider_manager.curr_tts_provider_inst + if default_tts_provider: + session_info["tts_provider_id"] = default_tts_provider.meta().id + session_info["tts_provider_name"] = default_tts_provider.meta().id + + return Response().ok(session_info).__dict__ + + except Exception as e: + error_msg = f"获取会话信息失败: {str(e)}\n{traceback.format_exc()}" + logger.error(error_msg) + return Response().error(f"获取会话信息失败: {str(e)}").__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index acdc8a49f..074a095d3 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -10,6 +10,7 @@ from quart.logging import default_handler from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from .routes import * from .routes.route import RouteContext, Response +from .routes.session_management import SessionManagementRoute from astrbot.core import logger, WEBUI_SK from astrbot.core.db import BaseDatabase from astrbot.core.utils.io import get_local_ip_addresses @@ -35,8 +36,7 @@ class AstrBotDashboard: ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB self.app.json.sort_keys = False self.app.before_request(self.auth_middleware) - # token 用于验证请求 - logging.getLogger(self.app.name).removeHandler(default_handler) + # token 用于验证请求 logging.getLogger(self.app.name).removeHandler(default_handler) self.context = RouteContext(self.config, self.app) self.ur = UpdateRoute( self.context, core_lifecycle.astrbot_updator, core_lifecycle @@ -53,6 +53,7 @@ class AstrBotDashboard: self.tools_root = ToolsRoute(self.context, core_lifecycle) self.conversation_route = ConversationRoute(self.context, db, core_lifecycle) self.file_route = FileRoute(self.context) + self.session_management_route = SessionManagementRoute(self.context, db, core_lifecycle) self.app.add_url_rule( "/api/plug/", diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts index e8f49c741..71400f50c 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts +++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts @@ -49,8 +49,7 @@ const sidebarItem: menu[] = [ title: '插件市场', icon: 'mdi-storefront', to: '/extension-marketplace' - }, - { + }, { title: '聊天', icon: 'mdi-chat', to: '/chat' @@ -60,6 +59,11 @@ const sidebarItem: menu[] = [ icon: 'mdi-database', to: '/conversation' }, + { + title: '会话管理', + icon: 'mdi-account-group', + to: '/session-management' + }, { title: '控制台', icon: 'mdi-console', diff --git a/dashboard/src/router/MainRoutes.ts b/dashboard/src/router/MainRoutes.ts index 0c2d9cf03..cadb90860 100644 --- a/dashboard/src/router/MainRoutes.ts +++ b/dashboard/src/router/MainRoutes.ts @@ -45,12 +45,16 @@ const MainRoutes = { name: 'Default', path: '/dashboard/default', component: () => import('@/views/dashboards/default/DefaultDashboard.vue') - }, - { + }, { name: 'Conversation', path: '/conversation', component: () => import('@/views/ConversationPage.vue') }, + { + name: 'SessionManagement', + path: '/session-management', + component: () => import('@/views/SessionManagementPage.vue') + }, { name: 'Console', path: '/console', @@ -81,15 +85,7 @@ const MainRoutes = { { name: 'Chat', path: '/chat', - component: () => import('@/views/ChatPage.vue'), - children: [ - { - path: ':conversationId', - name: 'ChatDetail', - component: () => import('@/views/ChatPage.vue'), - props: true - } - ] + component: () => import('@/views/ChatPage.vue') }, { name: 'Settings', diff --git a/dashboard/src/views/SessionManagementPage.vue b/dashboard/src/views/SessionManagementPage.vue new file mode 100644 index 000000000..b3f8d95ad --- /dev/null +++ b/dashboard/src/views/SessionManagementPage.vue @@ -0,0 +1,594 @@ + + + + +