feat: 初步实现 webchat 页面
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import asyncio
|
||||
from .log import LogManager, LogBroker
|
||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
@@ -19,4 +20,6 @@ if os.environ.get('TESTING', ""):
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
sp = SharedPreferences() # 简单的偏好设置存储
|
||||
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
|
||||
web_chat_queue = asyncio.Queue()
|
||||
web_chat_back_queue = asyncio.Queue()
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
|
||||
|
||||
@dataclass
|
||||
class BaseDatabase(abc.ABC):
|
||||
@@ -76,4 +76,28 @@ class BaseDatabase(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
|
||||
'''通过 url 或 path 获取 ATRI 视觉数据'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
|
||||
'''通过 user_id 和 cid 获取 WebChatConversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def webchat_new_conversation(self, user_id: str, cid: str):
|
||||
'''新建 WebChatConversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新 WebChatConversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_webchat_conversation(self, user_id: str, cid: str):
|
||||
'''删除 WebChatConversation'''
|
||||
raise NotImplementedError
|
||||
+12
-1
@@ -51,4 +51,15 @@ class ATRIVision():
|
||||
platform_name: str
|
||||
session_id: str
|
||||
sender_nickname: str
|
||||
timestamp: int = -1
|
||||
timestamp: int = -1
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebChatConversation():
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
|
||||
@@ -5,7 +5,8 @@ from astrbot.core.db.po import (
|
||||
Platform,
|
||||
Stats,
|
||||
LLMHistory,
|
||||
ATRIVision
|
||||
ATRIVision,
|
||||
WebChatConversation
|
||||
)
|
||||
from . import BaseDatabase
|
||||
from typing import Tuple
|
||||
@@ -199,6 +200,69 @@ class SQLiteDatabase(BaseDatabase):
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
|
||||
|
||||
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
''', (user_id, cid)
|
||||
)
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
return WebChatConversation(*res)
|
||||
|
||||
def webchat_new_conversation(self, user_id: str, cid: str):
|
||||
history = "[]"
|
||||
updated_at = int(time.time())
|
||||
created_at = updated_at
|
||||
self._exec_sql(
|
||||
'''
|
||||
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
|
||||
''', (user_id, cid, history, updated_at, created_at)
|
||||
)
|
||||
|
||||
def get_webchat_conversations(self, user_id: str) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
''', (user_id,)
|
||||
)
|
||||
|
||||
res = c.fetchall()
|
||||
c.close()
|
||||
conversations = []
|
||||
for row in res:
|
||||
cid = row[0]
|
||||
created_at = row[1]
|
||||
updated_at = row[2]
|
||||
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
|
||||
return conversations
|
||||
|
||||
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
|
||||
''', (history, user_id, cid)
|
||||
)
|
||||
|
||||
def delete_webchat_conversation(self, user_id: str, cid: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
''', (user_id, cid)
|
||||
)
|
||||
|
||||
|
||||
def insert_atri_vision_data(self, vision: ATRIVision):
|
||||
|
||||
@@ -35,4 +35,12 @@ CREATE TABLE IF NOT EXISTS atri_vision(
|
||||
session_id VARCHAR(32),
|
||||
sender_nickname VARCHAR(32),
|
||||
timestamp INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
user_id TEXT,
|
||||
cid TEXT,
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER
|
||||
);
|
||||
@@ -20,6 +20,10 @@ class WhitelistCheckStage(Stage):
|
||||
if not self.enable_whitelist_check:
|
||||
return
|
||||
|
||||
if event.get_platform_name() == 'webchat':
|
||||
# WebChat 豁免
|
||||
return
|
||||
|
||||
# 检查是否在白名单
|
||||
if self.wl_ignore_admin_on_group:
|
||||
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List
|
||||
from asyncio import Queue
|
||||
from .register import platform_cls_map
|
||||
from astrbot.core import logger
|
||||
|
||||
from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||
|
||||
class PlatformManager():
|
||||
def __init__(self, config: AstrBotConfig, event_queue: Queue):
|
||||
@@ -25,6 +25,7 @@ class PlatformManager():
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
||||
case "vchat":
|
||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
for platform in self.platforms_config:
|
||||
@@ -37,6 +38,8 @@ class PlatformManager():
|
||||
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
|
||||
inst = cls_type(platform, self.settings, self.event_queue)
|
||||
self.platform_insts.append(inst)
|
||||
|
||||
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
|
||||
|
||||
def get_insts(self):
|
||||
return self.platform_insts
|
||||
@@ -0,0 +1,88 @@
|
||||
import time
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Awaitable, Any
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import * # noqa: F403
|
||||
from astrbot.api import logger
|
||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
||||
from .webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
class QueueListener:
|
||||
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
|
||||
self.queue = queue
|
||||
self.callback = callback
|
||||
|
||||
async def run(self):
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
await self.callback(data)
|
||||
|
||||
@register_platform_adapter("webchat", "webchat")
|
||||
class WebChatAdapter(Platform):
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings['unique_session']
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
"webchat",
|
||||
"webchat",
|
||||
)
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
plain = ""
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain += comp.text
|
||||
web_chat_back_queue.put_nowait(plain)
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
||||
username, cid, message = data
|
||||
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = "webchat"
|
||||
abm.tag = "webchat"
|
||||
abm.sender = MessageMember(username, username)
|
||||
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
abm.session_id = f"webchat!{username}!{cid}"
|
||||
|
||||
abm.message_id = str(uuid.uuid4())
|
||||
abm.message = [Plain(message)]
|
||||
message_str = message
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
abm.raw_message = data
|
||||
return abm
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
async def callback(data: tuple):
|
||||
abm = await self.convert_message(data)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
bot = QueueListener(web_chat_queue, callback)
|
||||
return bot.run()
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
|
||||
message_event = WebChatMessageEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
@@ -0,0 +1,31 @@
|
||||
import os
|
||||
import uuid
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.core import web_chat_back_queue
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str, message_obj, platform_meta, session_id):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await web_chat_back_queue.put(comp.text)
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
path = os.path.join(self.imgs_dir, filename)
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
ph = comp.file[8:]
|
||||
with open(path, "wb") as f:
|
||||
with open(ph, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
await web_chat_back_queue.put(f"[IMAGE]{filename}")
|
||||
await web_chat_back_queue.put(None)
|
||||
await super().send(message)
|
||||
@@ -65,7 +65,7 @@ def save_temp_img(img: Image) -> str:
|
||||
f.write(img)
|
||||
return p
|
||||
|
||||
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str:
|
||||
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None, path = None) -> str:
|
||||
'''
|
||||
下载图片, 返回 path
|
||||
'''
|
||||
@@ -73,10 +73,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if post:
|
||||
async with session.post(url, json=post_data) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
else:
|
||||
async with session.get(url) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
except aiohttp.client_exceptions.ClientConnectorSSLError:
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
|
||||
@@ -5,6 +5,7 @@ from .update import UpdateRoute
|
||||
from .stat import StatRoute
|
||||
from .log import LogRoute
|
||||
from .static_file import StaticFileRoute
|
||||
from .chat import ChatRoute
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -14,6 +15,7 @@ __all__ = [
|
||||
"UpdateRoute",
|
||||
"StatRoute",
|
||||
"LogRoute",
|
||||
"StaticFileRoute"
|
||||
"StaticFileRoute",
|
||||
"ChatRoute",
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
||||
from quart import request, Response as QuartResponse, g
|
||||
from astrbot.core.db import BaseDatabase
|
||||
import asyncio
|
||||
|
||||
class ChatRoute(Route):
|
||||
def __init__(self, context: RouteContext, db: BaseDatabase) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
'/chat/send': ('POST', self.chat),
|
||||
'/chat/new_conversation': ('GET', self.new_conversation),
|
||||
'/chat/conversations': ('GET', self.get_conversations),
|
||||
'/chat/get_conversation': ('GET', self.get_conversation),
|
||||
'/chat/delete_conversation': ('GET', self.delete_conversation),
|
||||
'/chat/get_file': ('GET', self.get_file)
|
||||
}
|
||||
self.db = db
|
||||
self.register_routes()
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
|
||||
async def get_file(self):
|
||||
filename = request.args.get('filename')
|
||||
if not filename:
|
||||
return Response().error("Missing key: filename").__dict__
|
||||
|
||||
try:
|
||||
with open(os.path.join(self.imgs_dir, filename), "rb") as f:
|
||||
return QuartResponse(f.read(), mimetype="image/jpeg")
|
||||
except FileNotFoundError:
|
||||
return Response().error("File not found").__dict__
|
||||
|
||||
async def chat(self):
|
||||
username = g.get('username', 'guest')
|
||||
|
||||
post_data = await request.json
|
||||
if 'message' not in post_data:
|
||||
return Response().error("Missing key: message").__dict__
|
||||
|
||||
if 'conversation_id' not in post_data:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
message = post_data['message']
|
||||
conversation_id = post_data['conversation_id']
|
||||
if not message:
|
||||
return Response().error("Message is empty").__dict__
|
||||
if not conversation_id:
|
||||
return Response().error("conversation_id is empty").__dict__
|
||||
|
||||
await web_chat_queue.put((username, conversation_id, message))
|
||||
|
||||
async def stream():
|
||||
ret = []
|
||||
while True:
|
||||
result = await web_chat_back_queue.get()
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
ret.append(result)
|
||||
|
||||
yield result + '\n'
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
try:
|
||||
history = json.loads(conversation['history'])
|
||||
except BaseException:
|
||||
history = []
|
||||
history.append({
|
||||
'type': 'user',
|
||||
'message': message
|
||||
})
|
||||
# history.append({
|
||||
# 'type': 'bot',
|
||||
# 'message': ret
|
||||
# })
|
||||
for r in ret:
|
||||
history.append({
|
||||
'type': 'bot',
|
||||
'message': r
|
||||
})
|
||||
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
|
||||
|
||||
return QuartResponse(
|
||||
stream(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Content-Type": "text/event-stream",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*" # 如果是跨域请求
|
||||
}
|
||||
)
|
||||
|
||||
async def delete_conversation(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversation_id = request.args.get('conversation_id')
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
self.db.delete_webchat_conversation(username, conversation_id)
|
||||
return Response().ok().__dict__
|
||||
|
||||
async def new_conversation(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.webchat_new_conversation(username, conversation_id)
|
||||
return Response().ok(data={
|
||||
'conversation_id': conversation_id
|
||||
}).__dict__
|
||||
|
||||
async def get_conversations(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversations = self.db.get_webchat_conversations(username)
|
||||
return Response().ok(data=conversations).__dict__
|
||||
|
||||
async def get_conversation(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversation_id = request.args.get('conversation_id')
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
return Response().ok(data=conversation).__dict__
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import jwt
|
||||
import asyncio
|
||||
import os
|
||||
from quart import Quart, request, jsonify
|
||||
from quart import Quart, request, jsonify, g
|
||||
from quart.logging import default_handler
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .routes import *
|
||||
@@ -31,12 +31,15 @@ class AstrBotDashboard():
|
||||
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
|
||||
self.sfr = StaticFileRoute(self.context)
|
||||
self.ar = AuthRoute(self.context)
|
||||
self.chat_route = ChatRoute(self.context, db)
|
||||
|
||||
async def auth_middleware(self):
|
||||
if not request.path.startswith("/api"):
|
||||
return
|
||||
if request.path == "/api/auth/login":
|
||||
return
|
||||
if request.path == "/api/chat/get_file":
|
||||
return
|
||||
# claim jwt
|
||||
token = request.headers.get("Authorization")
|
||||
if not token:
|
||||
@@ -46,7 +49,8 @@ class AstrBotDashboard():
|
||||
if token.startswith("Bearer "):
|
||||
token = token[7:]
|
||||
try:
|
||||
jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
|
||||
payload = jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
|
||||
g.username = payload["username"]
|
||||
except jwt.ExpiredSignatureError:
|
||||
r = jsonify(Response().error("Token 过期").__dict__)
|
||||
r.status_code = 401
|
||||
|
||||
Generated
+17
@@ -18,6 +18,7 @@
|
||||
"date-fns": "2.30.0",
|
||||
"js-md5": "^0.8.3",
|
||||
"lodash": "4.17.21",
|
||||
"marked": "^15.0.6",
|
||||
"pinia": "2.1.6",
|
||||
"remixicon": "3.5.0",
|
||||
"vee-validate": "4.11.3",
|
||||
@@ -3702,6 +3703,17 @@
|
||||
"markdown-it": "bin/markdown-it.js"
|
||||
}
|
||||
},
|
||||
"node_modules/marked": {
|
||||
"version": "15.0.6",
|
||||
"resolved": "https://registry.npmjs.org/marked/-/marked-15.0.6.tgz",
|
||||
"integrity": "sha512-Y07CUOE+HQXbVDCGl3LXggqJDbXDP2pArc2C1N1RRMN0ONiShoSsIInMd5Gsxupe7fKLpgimTV+HOJ9r7bA+pg==",
|
||||
"bin": {
|
||||
"marked": "bin/marked.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 18"
|
||||
}
|
||||
},
|
||||
"node_modules/mdurl": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/mdurl/-/mdurl-1.0.1.tgz",
|
||||
@@ -8460,6 +8472,11 @@
|
||||
"uc.micro": "^1.0.5"
|
||||
}
|
||||
},
|
||||
"marked": {
|
||||
"version": "15.0.6",
|
||||
"resolved": "https://registry.npmjs.org/marked/-/marked-15.0.6.tgz",
|
||||
"integrity": "sha512-Y07CUOE+HQXbVDCGl3LXggqJDbXDP2pArc2C1N1RRMN0ONiShoSsIInMd5Gsxupe7fKLpgimTV+HOJ9r7bA+pg=="
|
||||
},
|
||||
"mdurl": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/mdurl/-/mdurl-1.0.1.tgz",
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
"date-fns": "2.30.0",
|
||||
"js-md5": "^0.8.3",
|
||||
"lodash": "4.17.21",
|
||||
"marked": "^15.0.6",
|
||||
"pinia": "2.1.6",
|
||||
"remixicon": "3.5.0",
|
||||
"vee-validate": "4.11.3",
|
||||
|
||||
@@ -30,6 +30,11 @@ const sidebarItem: menu[] = [
|
||||
icon: 'mdi-puzzle',
|
||||
to: '/extension'
|
||||
},
|
||||
{
|
||||
title: '聊天',
|
||||
icon: 'mdi-chat',
|
||||
to: '/chat'
|
||||
},
|
||||
{
|
||||
title: '控制台',
|
||||
icon: 'mdi-console',
|
||||
|
||||
@@ -36,6 +36,11 @@ const MainRoutes = {
|
||||
name: 'Project ATRI',
|
||||
path: '/project-atri',
|
||||
component: () => import('@/views/ATRIProject.vue')
|
||||
},
|
||||
{
|
||||
name: 'Chat',
|
||||
path: '/chat',
|
||||
component: () => import('@/views/ChatPage.vue')
|
||||
}
|
||||
]
|
||||
};
|
||||
|
||||
@@ -0,0 +1,285 @@
|
||||
<script setup>
|
||||
import axios from 'axios';
|
||||
import { ref } from 'vue';
|
||||
import { marked } from 'marked';
|
||||
|
||||
|
||||
marked.setOptions({
|
||||
breaks: true
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
|
||||
<v-card style="margin-bottom: 16px; width: 100%; background-color: #fff; height: 100%;">
|
||||
<v-card-text style="width: 100%; height: calc(100vh - 120px);">
|
||||
<div style="height: 100%; display: flex; gap: 16px;">
|
||||
<div style="max-width: 120px;">
|
||||
<!-- conversation -->
|
||||
<v-btn style="margin-bottom: 16px;" @click="newC">创建对话</v-btn>
|
||||
<v-btn style="margin-bottom: 16px;" v-if="currCid" @click="deleteConversation(currCid)" color="error">删除对话</v-btn>
|
||||
<v-card class="mx-auto" min-width="200">
|
||||
<v-list dense nav
|
||||
v-if="conversations.length > 0"
|
||||
@update:selected="getConversationMessages"
|
||||
>
|
||||
<v-list-item
|
||||
v-for="(item, i) in conversations"
|
||||
:key="item.cid"
|
||||
:value="item.cid"
|
||||
color="primary"
|
||||
rounded="xl"
|
||||
>
|
||||
<v-list-item-title>新对话</v-list-item-title>
|
||||
<v-list-item-subtitle>{{ formatDate(item.updated_at) }}</v-list-item-subtitle>
|
||||
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-card>
|
||||
</div>
|
||||
<div style="height: 100%; width: 100%;">
|
||||
<div style="height: calc(100% - 64px); overflow-y: auto; padding: 16px; " ref="messageContainer">
|
||||
<div class="fade-in" v-if="messages.length == 0"
|
||||
style="height: 100%; display: flex; justify-content: center; align-items: center;">
|
||||
<span style="font-size: 28px;">Hello, I'm</span>
|
||||
<span style="font-weight: 1000; font-size: 28px; margin-left: 8px;">AstrBot ⭐</span>
|
||||
</div>
|
||||
<div v-else style="max-height: 100%; padding: 16px; max-width: 700px; margin: 0 auto;">
|
||||
<div class="fade-in" v-for="(msg, index) in messages" :key="index"
|
||||
style="margin-bottom: 16px;">
|
||||
<div v-if="msg.type == 'user'" style="display: flex; justify-content: flex-end;">
|
||||
<div
|
||||
style="padding: 12px; border-radius: 8px; background-color: rgba(94, 53, 177, 0.15)">
|
||||
<span>{{ msg.message }}</span>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else style="display: flex; justify-content: flex-start; gap: 16px;">
|
||||
<span style="font-size: 32px;">✨</span>
|
||||
<div v-html="marked(msg.message)" class="mc" style="font-family: inherit;"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="fade-in" style="bottom: 16px; width: 100%; padding: 16px; ">
|
||||
<div style="width: 100%; justify-content: center; align-items: center; display: flex; ">
|
||||
<v-text-field variant="outlined" v-model="prompt" label="聊天吧!" placeholder="Start typing..."
|
||||
loading clear-icon="mdi-close-circle" clearable @click:clear="clearMessage"
|
||||
@keyup.enter="sendMessage" style="width: 100%; max-width: 930px;">
|
||||
<template v-slot:loader>
|
||||
|
||||
</template>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-icon @click="sendMessage" size="35" icon="mdi-arrow-up-circle" />
|
||||
</template>
|
||||
</v-text-field>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</template>
|
||||
<script>
|
||||
export default {
|
||||
name: 'ChatPage',
|
||||
components: {
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
prompt: '',
|
||||
messages: [],
|
||||
conversations: [],
|
||||
currCid: ''
|
||||
}
|
||||
},
|
||||
|
||||
mounted() {
|
||||
this.getConversations();
|
||||
},
|
||||
|
||||
methods: {
|
||||
clearMessage() {
|
||||
this.prompt = '';
|
||||
},
|
||||
getConversations() {
|
||||
axios.get('/api/chat/conversations').then(response => {
|
||||
this.conversations = response.data.data;
|
||||
}).catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
getConversationMessages(cid) {
|
||||
if (!cid[0])
|
||||
return;
|
||||
axios.get('/api/chat/get_conversation?conversation_id='+cid[0]).then(response => {
|
||||
this.currCid = cid[0];
|
||||
let message = JSON.parse(response.data.data.history);
|
||||
for (let i = 0; i < message.length; i++) {
|
||||
if (message[i].message.startsWith('[IMAGE]')) {
|
||||
let img = message[i].message.replace('[IMAGE]', '');
|
||||
message[i].message = `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
}
|
||||
this.messages = message;
|
||||
}).catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
async newConversation() {
|
||||
await axios.get('/api/chat/new_conversation').then(response => {
|
||||
this.currCid = response.data.data.conversation_id;
|
||||
this.getConversations();
|
||||
}).catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
|
||||
newC() {
|
||||
this.currCid = '';
|
||||
this.messages = [];
|
||||
},
|
||||
|
||||
formatDate(timestamp) {
|
||||
const date = new Date(timestamp * 1000); // 假设时间戳是以秒为单位
|
||||
const options = {
|
||||
year: 'numeric',
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
second: '2-digit',
|
||||
hour12: false
|
||||
};
|
||||
return date.toLocaleString('zh-CN', options).replace(/\//g, '-').replace(/, /g, ' ');
|
||||
},
|
||||
|
||||
deleteConversation(cid) {
|
||||
axios.get('/api/chat/delete_conversation?conversation_id='+cid).then(response => {
|
||||
this.getConversations();
|
||||
this.currCid = '';
|
||||
this.messages = [];
|
||||
}).catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
|
||||
async sendMessage() {
|
||||
if (this.currCid == '') {
|
||||
await this.newConversation();
|
||||
}
|
||||
|
||||
this.messages.push({
|
||||
type: 'user',
|
||||
message: this.prompt
|
||||
});
|
||||
|
||||
// let bot_resp = {
|
||||
// type: 'bot',
|
||||
// message: ref('')
|
||||
// }
|
||||
|
||||
// this.messages.push(bot_resp);
|
||||
|
||||
this.scrollToBottom();
|
||||
|
||||
|
||||
fetch('/api/chat/send', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||
},
|
||||
body: JSON.stringify({ message: this.prompt, conversation_id: this.currCid }) // 发送请求体
|
||||
})
|
||||
.then(response => {
|
||||
this.prompt = '';
|
||||
|
||||
const reader = response.body.getReader(); // 获取流的 Reader
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
const readStream = async () => {
|
||||
const { done, value } = await reader.read(); // 读取流中的数据
|
||||
if (done) {
|
||||
console.log("Stream finished.");
|
||||
return;
|
||||
}
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
// bot_resp.message.value += chunk;
|
||||
|
||||
console.log("!!!!", chunk);
|
||||
if (chunk.startsWith('[IMAGE]')) {
|
||||
let img = chunk.replace('[IMAGE]', '');
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else {
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: chunk
|
||||
}
|
||||
|
||||
this.messages.push(bot_resp);
|
||||
}
|
||||
|
||||
this.scrollToBottom();
|
||||
readStream(); // 递归读取流
|
||||
};
|
||||
|
||||
readStream();
|
||||
})
|
||||
.catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
scrollToBottom() {
|
||||
this.$nextTick(() => {
|
||||
const container = this.$refs.messageContainer;
|
||||
container.scrollTop = container.scrollHeight;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<style>
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
animation: fadeIn 0.2s ease-in-out;
|
||||
}
|
||||
|
||||
.mc h1,
|
||||
.mc h2,
|
||||
.mc h3,
|
||||
.mc h4,
|
||||
.mc h5,
|
||||
.mc h6 {
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.mc li {
|
||||
margin-left: 16px;
|
||||
}
|
||||
|
||||
|
||||
.mc p {
|
||||
margin-top: 10px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
</style>
|
||||
Reference in New Issue
Block a user