From e6c985ce4eff48bf68958c26571bf61d0ea43c70 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 13 Jan 2025 12:42:32 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96WebChat=E9=95=BF?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../platform/sources/webchat/webchat_event.py | 6 +- astrbot/dashboard/routes/chat.py | 133 +++++++++++------- dashboard/src/views/ChatPage.vue | 129 ++++++++++++----- 3 files changed, 185 insertions(+), 83 deletions(-) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 0ef57ed5f..b19526000 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -16,9 +16,11 @@ class WebChatMessageEvent(AstrMessageEvent): web_chat_back_queue.put_nowait(None) return + cid = self.session_id.split("!")[-1] + for comp in message.chain: if isinstance(comp, Plain): - web_chat_back_queue.put_nowait(comp.text) + web_chat_back_queue.put_nowait((comp.text, cid)) elif isinstance(comp, Image): # save image to local filename = str(uuid.uuid4()) + ".jpg" @@ -30,6 +32,6 @@ class WebChatMessageEvent(AstrMessageEvent): f.write(f2.read()) elif comp.file and comp.file.startswith("http"): await download_image_by_url(comp.file, path=path) - web_chat_back_queue.put_nowait(f"[IMAGE]{filename}") + web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid)) web_chat_back_queue.put_nowait(None) await super().send(message) \ No newline at end of file diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 639c7cafa..382132121 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -3,9 +3,10 @@ import json import os from .route import Route, Response, RouteContext from astrbot.core import web_chat_queue, web_chat_back_queue -from quart import request, Response as QuartResponse, g +from quart import request, Response as QuartResponse, g, make_response from astrbot.core.db import BaseDatabase import asyncio +from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -14,6 +15,7 @@ class ChatRoute(Route): super().__init__(context) self.routes = { '/chat/send': ('POST', self.chat), + '/chat/listen': ('GET', self.listener), '/chat/new_conversation': ('GET', self.new_conversation), '/chat/conversations': ('GET', self.get_conversations), '/chat/get_conversation': ('GET', self.get_conversation), @@ -30,6 +32,9 @@ class ChatRoute(Route): self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp'] + self.curr_user_cid = {} + self.curr_chat_sse = {} + async def status(self): has_llm_enabled = self.core_lifecycle.provider_manager.curr_provider_inst is not None has_stt_enabled = self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None @@ -107,63 +112,92 @@ class ChatRoute(Route): if not conversation_id: return Response().error("conversation_id is empty").__dict__ + self.curr_user_cid[username] = conversation_id + await web_chat_queue.put((username, conversation_id, { 'message': message, 'image_url': image_url, # list 'audio_url': audio_url })) - async def stream(): - ret = [] - while True: - try: - result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=30) # 设置超时时间为5秒 - except asyncio.TimeoutError: - yield '[Error] 30 秒内没有返回数据,已放弃。\n' - return - - if result is None: - break - - ret.append(result) - - yield result + '\n' - - await asyncio.sleep(0.5) - - conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id) - try: - history = json.loads(conversation.history) - except BaseException as e: - print(e) - history = [] - - new_his = { - 'type': 'user', - 'message': message - } - if image_url: - new_his['image_url'] = image_url - if audio_url: - new_his['audio_url'] = audio_url - history.append(new_his) - for r in ret: - history.append({ - 'type': 'bot', - 'message': r - }) - self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history)) + # 持久化 + conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id) + try: + history = json.loads(conversation.history) + except BaseException as e: + print(e) + history = [] + new_his = { + 'type': 'user', + 'message': message + } + if image_url: + new_his['image_url'] = image_url + if audio_url: + new_his['audio_url'] = audio_url + history.append(new_his) + self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history)) - return QuartResponse( + return Response().ok().__dict__ + + async def listener(self): + '''一直保持长连接''' + + username = g.get('username', 'guest') + + if username in self.curr_chat_sse: + return "[ERROR]\n" + + self.curr_chat_sse[username] = None + + async def stream(): + try: + yield '[HB]\n' + while True: + try: + result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=10) # 设置超时时间为5秒 + except asyncio.TimeoutError: + yield '[HB]\n' # 心跳包 + continue + + if not result: + continue + result_text, cid = result + if cid != self.curr_user_cid.get(username): + # 丢弃 + continue + yield result_text + '\n' + + conversation = self.db.get_webchat_conversation_by_user_id(username, cid) + try: + history = json.loads(conversation.history) + except BaseException as e: + print(e) + history = [] + history.append({ + 'type': 'bot', + 'message': result_text + }) + self.db.update_webchat_conversation(username, cid, history=json.dumps(history)) + + await asyncio.sleep(0.5) + except BaseException as e: + logger.error(e) + logger.error(f"与用户 {username} 断开聊天长连接。") + self.curr_chat_sse.pop(username) + return + + response = await make_response( stream(), - mimetype="text/event-stream", - headers={ - "Content-Type": "text/event-stream", - "Transfer-Encoding": "chunked", - "Connection": "keep-alive", - "Access-Control-Allow-Origin": "*" # 如果是跨域请求 + { + 'Content-Type': 'text/event-stream', + 'Cache-Control': 'no-cache', + 'Transfer-Encoding': 'chunked', + 'Connection': 'keep-alive' } ) + response.timeout = None + return response async def delete_conversation(self): username = g.get('username', 'guest') @@ -194,4 +228,7 @@ class ChatRoute(Route): return Response().error("Missing key: conversation_id").__dict__ conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id) + + self.curr_user_cid[username] = conversation_id + return Response().ok(data=conversation).__dict__ \ No newline at end of file diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index cfe96e544..3ea9c9a2e 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -1,9 +1,7 @@