diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 104a9edb6..16f108ece 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -28,5 +28,3 @@ pip_installer = PipInstaller( astrbot_config.get("pip_install_arg", ""), astrbot_config.get("pypi_index_url", None), ) -web_chat_queue = asyncio.Queue(maxsize=32) -web_chat_back_queue = asyncio.Queue(maxsize=32) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 2ebe4bd42..961463c7a 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -23,7 +23,6 @@ from astrbot.core.provider.entities import ( LLMResponse, ) from astrbot.core.star.star_handler import EventType -from astrbot.core import web_chat_back_queue from ..agent_runner.tool_loop_agent import ToolLoopAgent @@ -283,13 +282,6 @@ class LLMRequestSubStage(Stage): cid=cid, title=title, ) - web_chat_back_queue.put_nowait( - { - "type": "update_title", - "cid": cid, - "data": title, - } - ) async def _save_to_history( self, diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index fa384ed99..41d3e9418 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -2,7 +2,7 @@ import time import asyncio import uuid import os -from typing import Awaitable, Any +from typing import Awaitable, Any, Callable from astrbot.core.platform import ( Platform, AstrBotMessage, @@ -13,7 +13,7 @@ from astrbot.core.platform import ( from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.components import Plain, Image, Record # noqa: F403 from astrbot import logger -from astrbot.core import web_chat_queue +from .webchat_queue_mgr import webchat_queue_mgr, WebChatQueueMgr from .webchat_event import WebChatMessageEvent from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter @@ -21,14 +21,46 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path class QueueListener: - def __init__(self, queue: asyncio.Queue, callback: callable) -> None: - self.queue = queue + def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None: + self.webchat_queue_mgr = webchat_queue_mgr self.callback = callback + self.running_tasks = set() + + async def listen_to_queue(self, conversation_id: str): + """Listen to a specific conversation queue""" + queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id) + while True: + try: + data = await queue.get() + await self.callback(data) + except Exception as e: + logger.error( + f"Error processing message from conversation {conversation_id}: {e}" + ) + break async def run(self): + """Monitor for new conversation queues and start listeners""" + monitored_conversations = set() + while True: - data = await self.queue.get() - await self.callback(data) + # Check for new conversations + current_conversations = set(self.webchat_queue_mgr.queues.keys()) + new_conversations = current_conversations - monitored_conversations + + # Start listeners for new conversations + for conversation_id in new_conversations: + task = asyncio.create_task(self.listen_to_queue(conversation_id)) + self.running_tasks.add(task) + task.add_done_callback(self.running_tasks.discard) + monitored_conversations.add(conversation_id) + logger.debug(f"Started listener for conversation: {conversation_id}") + + # Clean up monitored conversations that no longer exist + removed_conversations = monitored_conversations - current_conversations + monitored_conversations -= removed_conversations + + await asyncio.sleep(1) # Check for new conversations every second @register_platform_adapter("webchat", "webchat") @@ -45,7 +77,7 @@ class WebChatAdapter(Platform): os.makedirs(self.imgs_dir, exist_ok=True) self.metadata = PlatformMetadata( - name="webchat", description="webchat", id=self.config.get("id") + name="webchat", description="webchat", id=self.config.get("id", "") ) async def send_by_session( @@ -105,7 +137,7 @@ class WebChatAdapter(Platform): abm = await self.convert_message(data) await self.handle_msg(abm) - bot = QueueListener(web_chat_queue, callback) + bot = QueueListener(webchat_queue_mgr, callback) return bot.run() def meta(self) -> PlatformMetadata: diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 111027a5c..c4e5d63c0 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -5,8 +5,8 @@ from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import Plain, Image, Record from astrbot.core.utils.io import download_image_by_url -from astrbot.core import web_chat_back_queue from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from .webchat_queue_mgr import webchat_queue_mgr imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") @@ -18,13 +18,14 @@ class WebChatMessageEvent(AstrMessageEvent): @staticmethod async def _send(message: MessageChain, session_id: str, streaming: bool = False): + cid = session_id.split("!")[-1] + web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) if not message: await web_chat_back_queue.put( {"type": "end", "data": "", "streaming": False} ) return "" - cid = session_id.split("!")[-1] data = "" for comp in message.chain: if isinstance(comp, Plain): @@ -98,18 +99,22 @@ class WebChatMessageEvent(AstrMessageEvent): async def send(self, message: MessageChain): await WebChatMessageEvent._send(message, session_id=self.session_id) + cid = self.session_id.split("!")[-1] + web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) await web_chat_back_queue.put( { "type": "end", "data": "", "streaming": False, - "cid": self.session_id.split("!")[-1], + "cid": cid, } ) await super().send(message) async def send_streaming(self, generator, use_fallback: bool = False): final_data = "" + cid = self.session_id.split("!")[-1] + web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) async for chain in generator: if chain.type == "break" and final_data: # 分割符 @@ -118,7 +123,7 @@ class WebChatMessageEvent(AstrMessageEvent): "type": "end", "data": final_data, "streaming": True, - "cid": self.session_id.split("!")[-1], + "cid": cid, } ) final_data = "" @@ -132,7 +137,7 @@ class WebChatMessageEvent(AstrMessageEvent): "type": "end", "data": final_data, "streaming": True, - "cid": self.session_id.split("!")[-1], + "cid": cid, } ) await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py new file mode 100644 index 000000000..96e172212 --- /dev/null +++ b/astrbot/core/platform/sources/webchat/webchat_queue_mgr.py @@ -0,0 +1,33 @@ +import asyncio + +class WebChatQueueMgr: + def __init__(self) -> None: + self.queues = {} + """Conversation ID to asyncio.Queue mapping""" + self.back_queues = {} + """Conversation ID to asyncio.Queue mapping for responses""" + + def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue: + """Get or create a queue for the given conversation ID""" + if conversation_id not in self.queues: + self.queues[conversation_id] = asyncio.Queue() + return self.queues[conversation_id] + + def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue: + """Get or create a back queue for the given conversation ID""" + if conversation_id not in self.back_queues: + self.back_queues[conversation_id] = asyncio.Queue() + return self.back_queues[conversation_id] + + def remove_queues(self, conversation_id: str): + """Remove queues for the given conversation ID""" + if conversation_id in self.queues: + del self.queues[conversation_id] + if conversation_id in self.back_queues: + del self.back_queues[conversation_id] + + def has_queue(self, conversation_id: str) -> bool: + """Check if a queue exists for the given conversation ID""" + return conversation_id in self.queues + +webchat_queue_mgr = WebChatQueueMgr() diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 270c92b44..a273bccdc 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -2,7 +2,7 @@ import uuid import json import os from .route import Route, Response, RouteContext -from astrbot.core import web_chat_queue, web_chat_back_queue +from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr from quart import request, Response as QuartResponse, g, make_response from astrbot.core.db import BaseDatabase import asyncio @@ -21,7 +21,6 @@ class ChatRoute(Route): super().__init__(context) self.routes = { "/chat/send": ("POST", self.chat), - "/chat/listen": ("GET", self.listener), "/chat/new_conversation": ("GET", self.new_conversation), "/chat/conversations": ("GET", self.get_conversations), "/chat/get_conversation": ("GET", self.get_conversation), @@ -40,9 +39,6 @@ class ChatRoute(Route): self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] - self.curr_user_cid = {} - self.curr_chat_sse = {} - async def status(self): has_llm_enabled = ( self.core_lifecycle.provider_manager.curr_provider_inst is not None @@ -133,21 +129,10 @@ class ChatRoute(Route): if not conversation_id: return Response().error("conversation_id is empty").__dict__ - self.curr_user_cid[username] = conversation_id + # Get conversation-specific queues + back_queue = webchat_queue_mgr.get_or_create_back_queue(conversation_id) - await web_chat_queue.put( - ( - username, - conversation_id, - { - "message": message, - "image_url": image_url, # list - "audio_url": audio_url, - }, - ) - ) - - # 持久化 + # append user message conversation = self.db.get_conversation_by_user_id(username, conversation_id) try: history = json.loads(conversation.history) @@ -164,30 +149,12 @@ class ChatRoute(Route): username, conversation_id, history=json.dumps(history) ) - return Response().ok().__dict__ - - async def listener(self): - """一直保持长连接""" - - username = g.get("username", "guest") - - if username in self.curr_chat_sse: - return Response().error("Already connected").__dict__ - - self.curr_chat_sse[username] = None - - heartbeat = json.dumps({"type": "heartbeat", "data": "ping"}) - async def stream(): try: - yield f"data: {heartbeat}\n\n" # 心跳包 while True: try: - result = await asyncio.wait_for( - web_chat_back_queue.get(), timeout=10 - ) # 设置超时时间为5秒 + result = await asyncio.wait_for(back_queue.get(), timeout=10) except asyncio.TimeoutError: - yield f"data: {heartbeat}\n\n" # 心跳包 continue if not result: @@ -197,9 +164,6 @@ class ChatRoute(Route): type = result.get("type") cid = result.get("cid") streaming = result.get("streaming", False) - if cid != self.curr_user_cid.get(username): - # 丢弃 - continue yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" await asyncio.sleep(0.05) @@ -210,6 +174,7 @@ class ChatRoute(Route): continue if result_text: + # append bot message conversation = self.db.get_conversation_by_user_id( username, cid ) @@ -222,11 +187,25 @@ class ChatRoute(Route): self.db.update_conversation( username, cid, history=json.dumps(history) ) + break except BaseException as _: logger.debug(f"用户 {username} 断开聊天长连接。") - self.curr_chat_sse.pop(username) return + # Put message to conversation-specific queue + chat_queue = webchat_queue_mgr.get_or_create_queue(conversation_id) + await chat_queue.put( + ( + username, + conversation_id, + { + "message": message, + "image_url": image_url, # list + "audio_url": audio_url, + }, + ) + ) + response = await make_response( stream(), { @@ -236,7 +215,6 @@ class ChatRoute(Route): "Connection": "keep-alive", }, ) - response.timeout = None return response async def delete_conversation(self): @@ -245,6 +223,8 @@ class ChatRoute(Route): if not conversation_id: return Response().error("Missing key: conversation_id").__dict__ + # Clean up queues when deleting conversation + webchat_queue_mgr.remove_queues(conversation_id) self.db.delete_conversation(username, conversation_id) return Response().ok().__dict__ @@ -279,6 +259,4 @@ class ChatRoute(Route): conversation = self.db.get_conversation_by_user_id(username, conversation_id) - self.curr_user_cid[username] = conversation_id - return Response().ok(data=conversation).__dict__ diff --git a/astrbot/dashboard/routes/multi_user_chat.py b/astrbot/dashboard/routes/multi_user_chat.py new file mode 100644 index 000000000..e69de29bb diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index 2954153ae..8c68f083e 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -31,7 +31,7 @@ elevation="0">
- +
@@ -49,8 +49,12 @@ }} --> @@ -65,22 +69,6 @@ -
- -
-
- -
- - mdi-delete - {{ tm('actions.deleteChat') }} - -
-
-
@@ -112,7 +100,7 @@