feat: 优化WebChat长连接的逻辑
This commit is contained in:
@@ -16,9 +16,11 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
return
|
||||
|
||||
cid = self.session_id.split("!")[-1]
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
web_chat_back_queue.put_nowait(comp.text)
|
||||
web_chat_back_queue.put_nowait((comp.text, cid))
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
@@ -30,6 +32,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
web_chat_back_queue.put_nowait(f"[IMAGE]{filename}")
|
||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
await super().send(message)
|
||||
@@ -3,9 +3,10 @@ import json
|
||||
import os
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
||||
from quart import request, Response as QuartResponse, g
|
||||
from quart import request, Response as QuartResponse, g, make_response
|
||||
from astrbot.core.db import BaseDatabase
|
||||
import asyncio
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
|
||||
@@ -14,6 +15,7 @@ class ChatRoute(Route):
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
'/chat/send': ('POST', self.chat),
|
||||
'/chat/listen': ('GET', self.listener),
|
||||
'/chat/new_conversation': ('GET', self.new_conversation),
|
||||
'/chat/conversations': ('GET', self.get_conversations),
|
||||
'/chat/get_conversation': ('GET', self.get_conversation),
|
||||
@@ -30,6 +32,9 @@ class ChatRoute(Route):
|
||||
|
||||
self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp']
|
||||
|
||||
self.curr_user_cid = {}
|
||||
self.curr_chat_sse = {}
|
||||
|
||||
async def status(self):
|
||||
has_llm_enabled = self.core_lifecycle.provider_manager.curr_provider_inst is not None
|
||||
has_stt_enabled = self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None
|
||||
@@ -107,63 +112,92 @@ class ChatRoute(Route):
|
||||
if not conversation_id:
|
||||
return Response().error("conversation_id is empty").__dict__
|
||||
|
||||
self.curr_user_cid[username] = conversation_id
|
||||
|
||||
await web_chat_queue.put((username, conversation_id, {
|
||||
'message': message,
|
||||
'image_url': image_url, # list
|
||||
'audio_url': audio_url
|
||||
}))
|
||||
|
||||
async def stream():
|
||||
ret = []
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=30) # 设置超时时间为5秒
|
||||
except asyncio.TimeoutError:
|
||||
yield '[Error] 30 秒内没有返回数据,已放弃。\n'
|
||||
return
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
ret.append(result)
|
||||
|
||||
yield result + '\n'
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
|
||||
new_his = {
|
||||
'type': 'user',
|
||||
'message': message
|
||||
}
|
||||
if image_url:
|
||||
new_his['image_url'] = image_url
|
||||
if audio_url:
|
||||
new_his['audio_url'] = audio_url
|
||||
history.append(new_his)
|
||||
for r in ret:
|
||||
history.append({
|
||||
'type': 'bot',
|
||||
'message': r
|
||||
})
|
||||
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
|
||||
# 持久化
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
new_his = {
|
||||
'type': 'user',
|
||||
'message': message
|
||||
}
|
||||
if image_url:
|
||||
new_his['image_url'] = image_url
|
||||
if audio_url:
|
||||
new_his['audio_url'] = audio_url
|
||||
history.append(new_his)
|
||||
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
|
||||
|
||||
return QuartResponse(
|
||||
return Response().ok().__dict__
|
||||
|
||||
async def listener(self):
|
||||
'''一直保持长连接'''
|
||||
|
||||
username = g.get('username', 'guest')
|
||||
|
||||
if username in self.curr_chat_sse:
|
||||
return "[ERROR]\n"
|
||||
|
||||
self.curr_chat_sse[username] = None
|
||||
|
||||
async def stream():
|
||||
try:
|
||||
yield '[HB]\n'
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=10) # 设置超时时间为5秒
|
||||
except asyncio.TimeoutError:
|
||||
yield '[HB]\n' # 心跳包
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
result_text, cid = result
|
||||
if cid != self.curr_user_cid.get(username):
|
||||
# 丢弃
|
||||
continue
|
||||
yield result_text + '\n'
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, cid)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
history.append({
|
||||
'type': 'bot',
|
||||
'message': result_text
|
||||
})
|
||||
self.db.update_webchat_conversation(username, cid, history=json.dumps(history))
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
logger.error(f"与用户 {username} 断开聊天长连接。")
|
||||
self.curr_chat_sse.pop(username)
|
||||
return
|
||||
|
||||
response = await make_response(
|
||||
stream(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Content-Type": "text/event-stream",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*" # 如果是跨域请求
|
||||
{
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Transfer-Encoding': 'chunked',
|
||||
'Connection': 'keep-alive'
|
||||
}
|
||||
)
|
||||
response.timeout = None
|
||||
return response
|
||||
|
||||
async def delete_conversation(self):
|
||||
username = g.get('username', 'guest')
|
||||
@@ -194,4 +228,7 @@ class ChatRoute(Route):
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
|
||||
self.curr_user_cid[username] = conversation_id
|
||||
|
||||
return Response().ok(data=conversation).__dict__
|
||||
@@ -1,9 +1,7 @@
|
||||
<script setup>
|
||||
import axios from 'axios';
|
||||
import { ref } from 'vue';
|
||||
import { marked } from 'marked';
|
||||
|
||||
|
||||
marked.setOptions({
|
||||
breaks: true
|
||||
});
|
||||
@@ -183,11 +181,14 @@ export default {
|
||||
mediaRecorder: null,
|
||||
|
||||
status: {},
|
||||
statusText: ''
|
||||
statusText: '',
|
||||
|
||||
eventSource: null
|
||||
}
|
||||
},
|
||||
|
||||
mounted() {
|
||||
this.startListeningEvent();
|
||||
this.checkStatus();
|
||||
this.getConversations();
|
||||
let inputField = document.getElementById('input-field');
|
||||
@@ -205,8 +206,70 @@ export default {
|
||||
}.bind(this));
|
||||
},
|
||||
|
||||
beforeUnmount() {
|
||||
console.log("111")
|
||||
if (this.eventSource) {
|
||||
this.eventSource.cancel();
|
||||
console.log('SSE连接已断开');
|
||||
}
|
||||
},
|
||||
|
||||
methods: {
|
||||
|
||||
async startListeningEvent() {
|
||||
const response = await fetch('/api/chat/listen', {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||
}
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
console.error('SSE连接失败:', response.statusText);
|
||||
return;
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
this.eventSource = reader
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) {
|
||||
console.log('SSE连接关闭');
|
||||
break;
|
||||
}
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
console.log("!!!!", chunk);
|
||||
|
||||
if (chunk === '[HB]\n') {
|
||||
continue; // 心跳包
|
||||
}
|
||||
if (chunk === '[ERROR]\n') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (chunk.startsWith('[IMAGE]')) {
|
||||
let img = chunk.replace('[IMAGE]', '');
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else {
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: chunk
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
}
|
||||
this.scrollToBottom();
|
||||
}
|
||||
},
|
||||
|
||||
removeAudio() {
|
||||
this.stagedAudioUrl = null;
|
||||
},
|
||||
@@ -417,41 +480,41 @@ export default {
|
||||
|
||||
this.loadingChat = false;
|
||||
|
||||
const reader = response.body.getReader(); // 获取流的 Reader
|
||||
const decoder = new TextDecoder();
|
||||
// const reader = response.body.getReader(); // 获取流的 Reader
|
||||
// const decoder = new TextDecoder();
|
||||
|
||||
const readStream = async () => {
|
||||
const { done, value } = await reader.read(); // 读取流中的数据
|
||||
if (done) {
|
||||
console.log("Stream finished.");
|
||||
return;
|
||||
}
|
||||
// const readStream = async () => {
|
||||
// const { done, value } = await reader.read(); // 读取流中的数据
|
||||
// if (done) {
|
||||
// console.log("Stream finished.");
|
||||
// return;
|
||||
// }
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
// bot_resp.message.value += chunk;
|
||||
// const chunk = decoder.decode(value, { stream: true });
|
||||
// // bot_resp.message.value += chunk;
|
||||
|
||||
console.log("!!!!", chunk);
|
||||
if (chunk.startsWith('[IMAGE]')) {
|
||||
let img = chunk.replace('[IMAGE]', '');
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else {
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: chunk
|
||||
}
|
||||
// console.log("!!!!", chunk);
|
||||
// if (chunk.startsWith('[IMAGE]')) {
|
||||
// let img = chunk.replace('[IMAGE]', '');
|
||||
// let bot_resp = {
|
||||
// type: 'bot',
|
||||
// message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
// }
|
||||
// this.messages.push(bot_resp);
|
||||
// } else {
|
||||
// let bot_resp = {
|
||||
// type: 'bot',
|
||||
// message: chunk
|
||||
// }
|
||||
|
||||
this.messages.push(bot_resp);
|
||||
}
|
||||
// this.messages.push(bot_resp);
|
||||
// }
|
||||
|
||||
this.scrollToBottom();
|
||||
readStream(); // 递归读取流
|
||||
};
|
||||
// this.scrollToBottom();
|
||||
// readStream(); // 递归读取流
|
||||
// };
|
||||
|
||||
readStream();
|
||||
// readStream();
|
||||
})
|
||||
.catch(err => {
|
||||
console.error(err);
|
||||
@@ -463,7 +526,7 @@ export default {
|
||||
container.scrollTop = container.scrollHeight;
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
Reference in New Issue
Block a user