Compare commits

..

10 Commits

Author SHA1 Message Date
Soulter 29374f8d8a fix: 修复 /dashbord_update 指令 2025-01-11 00:25:02 +08:00
Soulter 359b971103 Merge pull request #235 from Soulter/feat-webchat
WebChat 支持
2025-01-11 00:17:18 +08:00
Soulter fbdb1ae208 chore: bump to v3.4.4 2025-01-11 00:14:08 +08:00
Soulter 22c13c1eff perf: webchat支持传图 2025-01-11 00:06:19 +08:00
Soulter 5fc63aeaf1 perf: ui 2025-01-10 22:45:14 +08:00
Soulter d4f32673ab fix: 修复持久化问题 2025-01-10 22:08:43 +08:00
Soulter 480dffb51b feat: 初步实现 webchat 页面 2025-01-10 21:48:15 +08:00
Soulter 966df00124 feat: 支持从管理面板(控制台页)手动安装 pip 库 2025-01-10 15:35:57 +08:00
Soulter 3e2b4bc727 feat: 支持动态设置会话变量以适用 Dify 输入变量 2025-01-10 12:32:20 +08:00
Soulter 5929a8d42b Update README.md 2025-01-09 23:11:11 +08:00
34 changed files with 1074 additions and 52 deletions
+2 -1
View File
@@ -20,4 +20,5 @@ chroma
node_modules/
.DS_Store
package-lock.json
package.json
package.json
venv/*
+21 -1
View File
@@ -45,7 +45,27 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
## ✨ 支持 Dify
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流!
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流![接入 Dify - AstrBot 文档](https://astrbot.lwl.lol/others/dify.html)
## ✨ 代码执行器(Beta)
基于 Docker 的沙箱化代码执行器(Beta 测试中)
> [!NOTE]
> 文件输入/输出目前仅支持 Napcat(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/700a545e-7450-4f23-90ff-af6d0d60e501" height=300>
<img src="https://github.com/user-attachments/assets/0b0c5344-e98b-4902-92ad-fe9f0bb10c2a" height=300>
<img src="https://github.com/user-attachments/assets/b9b98ff4-8630-46fb-9a39-ecbad9d601ae" height=300>
<img src="https://github.com/user-attachments/assets/9fe6e44c-e4f6-4347-9d5f-281677d47feb" height=300>
</div>
## ✨ Demo
+7
View File
@@ -1,12 +1,16 @@
import os
import asyncio
from .log import LogManager, LogBroker
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
os.makedirs("data", exist_ok=True)
astrbot_config = AstrBotConfig()
html_renderer = HtmlRenderer()
logger = LogManager.GetLogger(log_name='astrbot')
@@ -15,4 +19,7 @@ if os.environ.get('TESTING', ""):
db_helper = SQLiteDatabase(DB_PATH)
sp = SharedPreferences() # 简单的偏好设置存储
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
web_chat_queue = asyncio.Queue()
web_chat_back_queue = asyncio.Queue()
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
+1 -1
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.3"
VERSION = "3.4.4"
DB_PATH = "data/data_v3.db"
# 默认配置
+2 -1
View File
@@ -3,6 +3,7 @@ import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -21,7 +22,7 @@ from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
self.astrbot_config = AstrBotConfig()
self.astrbot_config = astrbot_config
self.db = db
if self.astrbot_config['http_proxy']:
+25 -1
View File
@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
@dataclass
class BaseDatabase(abc.ABC):
@@ -76,4 +76,28 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
'''通过 url 或 path 获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
'''通过 user_id 和 cid 获取 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def webchat_new_conversation(self, user_id: str, cid: str):
'''新建 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
raise NotImplementedError
@abc.abstractmethod
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
'''更新 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def delete_webchat_conversation(self, user_id: str, cid: str):
'''删除 WebChatConversation'''
raise NotImplementedError
+12 -1
View File
@@ -51,4 +51,15 @@ class ATRIVision():
platform_name: str
session_id: str
sender_nickname: str
timestamp: int = -1
timestamp: int = -1
@dataclass
class WebChatConversation():
user_id: str
cid: str
history: str = ""
created_at: int = 0
updated_at: int = 0
+65 -1
View File
@@ -5,7 +5,8 @@ from astrbot.core.db.po import (
Platform,
Stats,
LLMHistory,
ATRIVision
ATRIVision,
WebChatConversation
)
from . import BaseDatabase
from typing import Tuple
@@ -199,6 +200,69 @@ class SQLiteDatabase(BaseDatabase):
c.close()
return Stats(platform, [], [])
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
res = c.fetchone()
c.close()
return WebChatConversation(*res)
def webchat_new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
'''
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
''', (user_id, cid, history, updated_at, created_at)
)
def get_webchat_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
''', (user_id,)
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
return conversations
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
''', (history, user_id, cid)
)
def delete_webchat_conversation(self, user_id: str, cid: str):
self._exec_sql(
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
def insert_atri_vision_data(self, vision: ATRIVision):
+8
View File
@@ -35,4 +35,12 @@ CREATE TABLE IF NOT EXISTS atri_vision(
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT,
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER
);
@@ -20,6 +20,10 @@ class WhitelistCheckStage(Stage):
if not self.enable_whitelist_check:
return
if event.get_platform_name() == 'webchat':
# WebChat 豁免
return
# 检查是否在白名单
if self.wl_ignore_admin_on_group:
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
+4 -1
View File
@@ -4,7 +4,7 @@ from typing import List
from asyncio import Queue
from .register import platform_cls_map
from astrbot.core import logger
from .sources.webchat.webchat_adapter import WebChatAdapter
class PlatformManager():
def __init__(self, config: AstrBotConfig, event_queue: Queue):
@@ -25,6 +25,7 @@ class PlatformManager():
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "vchat":
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
async def initialize(self):
for platform in self.platforms_config:
@@ -37,6 +38,8 @@ class PlatformManager():
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
inst = cls_type(platform, self.settings, self.event_queue)
self.platform_insts.append(inst)
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
def get_insts(self):
return self.platform_insts
@@ -0,0 +1,102 @@
import time
import asyncio
import uuid
import os
from typing import Awaitable, Any
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image # noqa: F403
from astrbot.api import logger
from astrbot.core import web_chat_queue, web_chat_back_queue
from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
class QueueListener:
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
self.queue = queue
self.callback = callback
async def run(self):
while True:
data = await self.queue.get()
await self.callback(data)
@register_platform_adapter("webchat", "webchat")
class WebChatAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings['unique_session']
self.imgs_dir = "data/webchat/imgs"
self.metadata = PlatformMetadata(
"webchat",
"webchat",
)
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
plain = ""
for comp in message_chain.chain:
if isinstance(comp, Plain):
plain += comp.text
web_chat_back_queue.put_nowait(plain)
await super().send_by_session(session, message_chain)
async def convert_message(self, data: tuple) -> AstrBotMessage:
username, cid, payload = data
abm = AstrBotMessage()
abm.self_id = "webchat"
abm.tag = "webchat"
abm.sender = MessageMember(username, username)
abm.type = MessageType.FRIEND_MESSAGE
abm.session_id = f"webchat!{username}!{cid}"
abm.message_id = str(uuid.uuid4())
abm.message = []
if payload['message']:
abm.message.append(Plain(payload['message']))
if payload['image_url']:
if isinstance(payload['image_url'], list):
for img in payload['image_url']:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
else:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
logger.debug(f"WebChatAdapter: {abm.message}")
message_str = payload['message']
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.raw_message = data
return abm
def run(self) -> Awaitable[Any]:
async def callback(data: tuple):
abm = await self.convert_message(data)
await self.handle_msg(abm)
bot = QueueListener(web_chat_queue, callback)
return bot.run()
def meta(self) -> PlatformMetadata:
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
message_event = WebChatMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id
)
self.commit_event(message_event)
@@ -0,0 +1,31 @@
import os
import uuid
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.core import web_chat_back_queue
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.imgs_dir = "data/webchat/imgs"
os.makedirs(self.imgs_dir, exist_ok=True)
async def send(self, message: MessageChain):
for comp in message.chain:
if isinstance(comp, Plain):
await web_chat_back_queue.put(comp.text)
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
await web_chat_back_queue.put(f"[IMAGE]{filename}")
await web_chat_back_queue.put(None)
await super().send(message)
+10 -4
View File
@@ -6,8 +6,7 @@ from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import logger
from astrbot.core import logger, sp
@register_provider_adapter("dify", "Dify APP 适配器。")
class ProviderDify(Provider):
@@ -67,10 +66,16 @@ class ProviderDify(Provider):
logger.debug(files_payload)
# 获得会话变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
match self.api_type:
case "chat" | "agent":
async for chunk in self.api_client.chat_messages(
inputs={},
inputs={
**session_var
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
@@ -88,7 +93,8 @@ class ProviderDify(Provider):
async for chunk in self.api_client.workflow_run(
inputs={
"astrbot_text_query": prompt,
"astrbot_session_id": session_id
"astrbot_session_id": session_id,
**session_var
},
user=session_id,
files=files_payload
@@ -209,6 +209,8 @@ class ProviderOpenAIOfficial(Provider):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
else:
if image_url.startswith("file:///"):
image_url = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
+3 -12
View File
@@ -9,7 +9,7 @@ from types import ModuleType
from typing import List
from pip import main as pip_main
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger, sp
from astrbot.core import logger, sp, pip_installer
from .context import Context
from . import StarMetadata
from .updator import PluginUpdator
@@ -92,21 +92,12 @@ class PluginManager:
plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在检查插件 {p} 的依赖: {pth}")
logger.info(f"正在安装插件 {p} 所需的依赖: {pth}")
try:
self._update_plugin_dept(os.path.join(plugin_path, "requirements.txt"))
pip_installer.install(requirements_path=pth)
except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
def _update_plugin_dept(self, path):
'''更新插件的依赖'''
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
if self.config.pip_install_arg:
args.extend([self.config.pip_install_arg])
result_code = pip_main(args)
if result_code != 0:
raise Exception(str(result_code))
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata:
'''v3.4.0 以前的方式载入插件元数据
+21 -9
View File
@@ -1,4 +1,5 @@
import json
from astrbot.core import logger
from aiohttp import ClientSession
from typing import Dict, List, Any, AsyncGenerator
@@ -29,11 +30,18 @@ class DifyAPIClient:
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
async for data in resp.content:
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip():
continue
if data.startswith(b"data:"):
yield json.loads(data[5:])
elif data.startswith(b"data:"):
try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def workflow_run(
self,
@@ -50,11 +58,18 @@ class DifyAPIClient:
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
async for data in resp.content:
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip():
continue
if data.startswith(b"data:"):
yield json.loads(data[5:])
elif data.startswith(b"data:"):
try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def file_upload(
self,
@@ -70,9 +85,6 @@ class DifyAPIClient:
url, data=payload, headers=self.headers
) as resp:
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()
+13 -3
View File
@@ -65,7 +65,7 @@ def save_temp_img(img: Image) -> str:
f.write(img)
return p
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str:
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None, path = None) -> str:
'''
下载图片, 返回 path
'''
@@ -73,10 +73,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
async with aiohttp.ClientSession() as session:
if post:
async with session.post(url, json=post_data) as resp:
return save_temp_img(await resp.read())
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
else:
async with session.get(url) as resp:
return save_temp_img(await resp.read())
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
except aiohttp.client_exceptions.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
+33
View File
@@ -0,0 +1,33 @@
import logging
from pip import main as pip_main
class PipInstaller():
def __init__(self, pip_install_arg: str):
self.pip_install_arg = pip_install_arg
def install(self, package_name: str = None, requirements_path: str = None, mirror: str = None):
args = ['install']
if package_name:
args.append(package_name)
elif requirements_path:
args.extend(['-r', requirements_path])
if not mirror:
mirror = 'https://mirrors.aliyun.com/pypi/simple/'
args.extend(['--trusted-host', 'mirrors.aliyun.com', '-i', mirror])
if self.pip_install_arg:
args.extend(self.pip_install_arg.split())
print(f"Pip 包管理器: {' '.join(args)}")
result_code = pip_main(args)
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if result_code != 0:
raise Exception(f"安装失败,错误码:{result_code}")
+3 -1
View File
@@ -5,6 +5,7 @@ from .update import UpdateRoute
from .stat import StatRoute
from .log import LogRoute
from .static_file import StaticFileRoute
from .chat import ChatRoute
__all__ = [
@@ -14,6 +15,7 @@ __all__ = [
"UpdateRoute",
"StatRoute",
"LogRoute",
"StaticFileRoute"
"StaticFileRoute",
"ChatRoute",
]
+149
View File
@@ -0,0 +1,149 @@
import uuid
import json
import os
from .route import Route, Response, RouteContext
from astrbot.core import web_chat_queue, web_chat_back_queue
from quart import request, Response as QuartResponse, g
from astrbot.core.db import BaseDatabase
import asyncio
class ChatRoute(Route):
def __init__(self, context: RouteContext, db: BaseDatabase) -> None:
super().__init__(context)
self.routes = {
'/chat/send': ('POST', self.chat),
'/chat/new_conversation': ('GET', self.new_conversation),
'/chat/conversations': ('GET', self.get_conversations),
'/chat/get_conversation': ('GET', self.get_conversation),
'/chat/delete_conversation': ('GET', self.delete_conversation),
'/chat/get_file': ('GET', self.get_file),
'/chat/post_image': ('POST', self.post_image)
}
self.db = db
self.register_routes()
self.imgs_dir = "data/webchat/imgs"
async def get_file(self):
filename = request.args.get('filename')
if not filename:
return Response().error("Missing key: filename").__dict__
try:
with open(os.path.join(self.imgs_dir, filename), "rb") as f:
return QuartResponse(f.read(), mimetype="image/jpeg")
except FileNotFoundError:
return Response().error("File not found").__dict__
async def post_image(self):
post_data = await request.files
if 'file' not in post_data:
return Response().error("Missing key: file").__dict__
file = post_data['file']
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
await file.save(path)
return Response().ok(data={
'filename': filename
}).__dict__
async def chat(self):
username = g.get('username', 'guest')
post_data = await request.json
if 'message' not in post_data and 'image_url' not in post_data:
return Response().error("Missing key: message or image_url").__dict__
if 'conversation_id' not in post_data:
return Response().error("Missing key: conversation_id").__dict__
message = post_data['message']
conversation_id = post_data['conversation_id']
image_url = post_data.get('image_url')
if not message and not image_url:
return Response().error("Message and image_url are empty").__dict__
if not conversation_id:
return Response().error("conversation_id is empty").__dict__
await web_chat_queue.put((username, conversation_id, {
'message': message,
'image_url': image_url # list
}))
async def stream():
ret = []
while True:
result = await web_chat_back_queue.get()
if result is None:
break
ret.append(result)
yield result + '\n'
await asyncio.sleep(0.5)
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
new_his = {
'type': 'user',
'message': message
}
if image_url:
new_his['image_url'] = image_url
history.append(new_his)
for r in ret:
history.append({
'type': 'bot',
'message': r
})
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
return QuartResponse(
stream(),
mimetype="text/event-stream",
headers={
"Content-Type": "text/event-stream",
"Transfer-Encoding": "chunked",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*" # 如果是跨域请求
}
)
async def delete_conversation(self):
username = g.get('username', 'guest')
conversation_id = request.args.get('conversation_id')
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
self.db.delete_webchat_conversation(username, conversation_id)
return Response().ok().__dict__
async def new_conversation(self):
username = g.get('username', 'guest')
conversation_id = str(uuid.uuid4())
self.db.webchat_new_conversation(username, conversation_id)
return Response().ok(data={
'conversation_id': conversation_id
}).__dict__
async def get_conversations(self):
username = g.get('username', 'guest')
conversations = self.db.get_webchat_conversations(username)
return Response().ok(data=conversations).__dict__
async def get_conversation(self):
username = g.get('username', 'guest')
conversation_id = request.args.get('conversation_id')
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
return Response().ok(data=conversation).__dict__
+14 -1
View File
@@ -3,7 +3,7 @@ import traceback
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core import logger, pip_installer
class UpdateRoute(Route):
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
@@ -11,6 +11,7 @@ class UpdateRoute(Route):
self.routes = {
'/update/check': ('GET', self.check_update),
'/update/do': ('POST', self.update_project),
'/update/pip-install': ('POST', self.install_pip_package)
}
self.astrbot_updator = astrbot_updator
self.register_routes()
@@ -47,4 +48,16 @@ class UpdateRoute(Route):
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
except Exception as e:
logger.error(f"/api/update_project: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
async def install_pip_package(self):
data = await request.json
package = data.get('package', '')
if not package:
return Response().error("缺少参数 package 或不合法。").__dict__
try:
pip_installer.install(package)
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(f"/api/update_pip: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
+6 -2
View File
@@ -2,7 +2,7 @@ import logging
import jwt
import asyncio
import os
from quart import Quart, request, jsonify
from quart import Quart, request, jsonify, g
from quart.logging import default_handler
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .routes import *
@@ -31,12 +31,15 @@ class AstrBotDashboard():
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.context)
self.ar = AuthRoute(self.context)
self.chat_route = ChatRoute(self.context, db)
async def auth_middleware(self):
if not request.path.startswith("/api"):
return
if request.path == "/api/auth/login":
return
if request.path == "/api/chat/get_file":
return
# claim jwt
token = request.headers.get("Authorization")
if not token:
@@ -46,7 +49,8 @@ class AstrBotDashboard():
if token.startswith("Bearer "):
token = token[7:]
try:
jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
payload = jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
g.username = payload["username"]
except jwt.ExpiredSignatureError:
r = jsonify(Response().error("Token 过期").__dict__)
r.status_code = 401
+6
View File
@@ -0,0 +1,6 @@
# What's Changed
1. 支持通过 /set <k> <v> 设置持久化的会话变量, 方便 Dify App 输入变量
2. 管理面板支持 Web Chat
3. 管理面板支持手动安装 Pip 库, 在 `控制台` 页中可找到
+17
View File
@@ -18,6 +18,7 @@
"date-fns": "2.30.0",
"js-md5": "^0.8.3",
"lodash": "4.17.21",
"marked": "^15.0.6",
"pinia": "2.1.6",
"remixicon": "3.5.0",
"vee-validate": "4.11.3",
@@ -3702,6 +3703,17 @@
"markdown-it": "bin/markdown-it.js"
}
},
"node_modules/marked": {
"version": "15.0.6",
"resolved": "https://registry.npmjs.org/marked/-/marked-15.0.6.tgz",
"integrity": "sha512-Y07CUOE+HQXbVDCGl3LXggqJDbXDP2pArc2C1N1RRMN0ONiShoSsIInMd5Gsxupe7fKLpgimTV+HOJ9r7bA+pg==",
"bin": {
"marked": "bin/marked.js"
},
"engines": {
"node": ">= 18"
}
},
"node_modules/mdurl": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/mdurl/-/mdurl-1.0.1.tgz",
@@ -8460,6 +8472,11 @@
"uc.micro": "^1.0.5"
}
},
"marked": {
"version": "15.0.6",
"resolved": "https://registry.npmjs.org/marked/-/marked-15.0.6.tgz",
"integrity": "sha512-Y07CUOE+HQXbVDCGl3LXggqJDbXDP2pArc2C1N1RRMN0ONiShoSsIInMd5Gsxupe7fKLpgimTV+HOJ9r7bA+pg=="
},
"mdurl": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/mdurl/-/mdurl-1.0.1.tgz",
+1
View File
@@ -23,6 +23,7 @@
"date-fns": "2.30.0",
"js-md5": "^0.8.3",
"lodash": "4.17.21",
"marked": "^15.0.6",
"pinia": "2.1.6",
"remixicon": "3.5.0",
"vee-validate": "4.11.3",
@@ -30,6 +30,11 @@ const sidebarItem: menu[] = [
icon: 'mdi-puzzle',
to: '/extension'
},
{
title: '聊天',
icon: 'mdi-chat',
to: '/chat'
},
{
title: '控制台',
icon: 'mdi-console',
+5
View File
@@ -36,6 +36,11 @@ const MainRoutes = {
name: 'Project ATRI',
path: '/project-atri',
component: () => import('@/views/ATRIProject.vue')
},
{
name: 'Chat',
path: '/chat',
component: () => import('@/views/ChatPage.vue')
}
]
};
+372
View File
@@ -0,0 +1,372 @@
<script setup>
import axios from 'axios';
import { ref } from 'vue';
import { marked } from 'marked';
marked.setOptions({
breaks: true
});
</script>
<template>
<v-card style="margin-bottom: 16px; width: 100%; background-color: #fff; height: 100%;">
<v-card-text style="width: 100%; height: calc(100vh - 120px);">
<div style="height: 100%; display: flex; gap: 16px;">
<div style="max-width: 200px;">
<!-- conversation -->
<v-btn variant="tonal" rounded="xl" style="margin-bottom: 16px; min-width: 200px;" @click="newC"
:disabled="!currCid">+ 创建对话</v-btn>
<v-card class="mx-auto" min-width="200">
<v-list dense nav rounded="xl" v-if="conversations.length > 0"
@update:selected="getConversationMessages">
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
color="primary" rounded="xl">
<v-list-item-title>新对话</v-list-item-title>
<v-list-item-subtitle>{{ formatDate(item.updated_at) }}</v-list-item-subtitle>
</v-list-item>
</v-list>
</v-card>
<v-btn variant="tonal" rounded="xl"
style="position: fixed; bottom: 48px; margin-bottom: 16px; min-width: 200px;" v-if="currCid"
@click="deleteConversation(currCid)" color="error">删除此对话</v-btn>
</div>
<div style="height: 100%; width: 100%;">
<div style="height: calc(100% - 130px); overflow-y: auto; padding: 16px; " ref="messageContainer">
<div class="fade-in" v-if="messages.length == 0"
style="height: 100%; display: flex; justify-content: center; align-items: center; flex-direction: column;">
<div>
<span style="font-size: 28px;">Hello, I'm</span>
<span style="font-weight: 1000; font-size: 28px; margin-left: 8px;">AstrBot ⭐</span>
</div>
<div style="margin-top: 8px; color: #aaa;">
<span>输入</span>
<span
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">/help</span>
<span>获取帮助 😊</span>
</div>
</div>
<div v-else style="max-height: 100%; padding: 16px; max-width: 700px; margin: 0 auto;">
<div class="fade-in" v-for="(msg, index) in messages" :key="index"
style="margin-bottom: 16px;">
<div v-if="msg.type == 'user'" style="display: flex; justify-content: flex-end;">
<div
style="padding: 12px; border-radius: 8px; background-color: rgba(94, 53, 177, 0.15)">
<span>{{ msg.message }}</span>
<div style="display: flex; gap: 8px; margin-top: 8px;" v-if="msg.image_url && msg.image_url.length > 0">
<div v-for="(img, index) in msg.image_url" :key="index"
style="position: relative; display: inline-block;">
<img :src="img"
style="width: 100px; height: 100px; border-radius: 8px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);" />
</div>
</div>
</div>
</div>
<div v-else style="display: flex; justify-content: flex-start; gap: 16px;">
<span style="font-size: 32px;">✨</span>
<div v-html="marked(msg.message)" class="mc" style="font-family: inherit;"></div>
</div>
</div>
</div>
</div>
<div class="fade-in" style="bottom: 16px; width: 100%; padding: 8px; ">
<div
style="width: 100%; justify-content: center; align-items: center; display: flex; flex-direction: column; margin-top: 8px;">
<v-text-field id="input-field" variant="outlined" v-model="prompt" label="聊天吧!"
placeholder="Start typing..." loading clear-icon="mdi-close-circle" clearable
@click:clear="clearMessage" @keyup.enter="sendMessage"
style="width: 100%; max-width: 930px;">
<template v-slot:loader>
<v-progress-linear
:active="loadingChat"
:color="color"
height="6"
indeterminate
></v-progress-linear>
</template>
<template v-slot:append>
<v-icon @click="sendMessage" size="35" icon="mdi-arrow-up-circle" />
</template>
</v-text-field>
<div>
<div v-for="(img, index) in stagedImagesUrl" :key="index"
style="position: relative; display: inline-block;">
<img :src="img"
style="width: 50px; height: 50px; border-radius: 8px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);" />
<v-icon @click="removeImage(index)" size="20" color="red"
style="position: absolute; top: 0; right: 0; cursor: pointer;">mdi-close-circle</v-icon>
</div>
</div>
</div>
</div>
</div>
</div>
</v-card-text>
</v-card>
</template>
<script>
export default {
name: 'ChatPage',
components: {
},
data() {
return {
prompt: '',
messages: [],
conversations: [],
currCid: '',
stagedImagesUrl: [],
loadingChat: false
}
},
mounted() {
this.getConversations();
let inputField = document.getElementById('input-field');
inputField.addEventListener('paste', this.handlePaste);
},
methods: {
async handlePaste(event) {
console.log('Pasting image...');
const items = event.clipboardData.items;
for (let i = 0; i < items.length; i++) {
if (items[i].type.indexOf('image') !== -1) {
const file = items[i].getAsFile();
const formData = new FormData();
formData.append('file', file);
try {
const response = await axios.post('/api/chat/post_image', formData, {
headers: {
'Content-Type': 'multipart/form-data',
'Authorization': 'Bearer ' + localStorage.getItem('token')
}
});
const img = response.data.data.filename;
this.stagedImagesUrl.push(`/api/chat/get_file?filename=${img}`);
scrollToBottom();
} catch (err) {
console.error('Error uploading image:', err);
}
}
}
},
removeImage(index) {
this.stagedImagesUrl.splice(index, 1);
},
clearMessage() {
this.prompt = '';
},
getConversations() {
axios.get('/api/chat/conversations').then(response => {
this.conversations = response.data.data;
}).catch(err => {
console.error(err);
});
},
getConversationMessages(cid) {
if (!cid[0])
return;
axios.get('/api/chat/get_conversation?conversation_id=' + cid[0]).then(response => {
this.currCid = cid[0];
let message = JSON.parse(response.data.data.history);
for (let i = 0; i < message.length; i++) {
if (message[i].message.startsWith('[IMAGE]')) {
let img = message[i].message.replace('[IMAGE]', '');
message[i].message = `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
if (message[i].image_url && message[i].image_url.length > 0) {
for (let j = 0; j < message[i].image_url.length; j++) {
message[i].image_url[j] = `/api/chat/get_file?filename=${message[i].image_url[j]}`;
}
}
}
this.messages = message;
}).catch(err => {
console.error(err);
});
},
async newConversation() {
await axios.get('/api/chat/new_conversation').then(response => {
this.currCid = response.data.data.conversation_id;
this.getConversations();
}).catch(err => {
console.error(err);
});
},
newC() {
this.currCid = '';
this.messages = [];
},
formatDate(timestamp) {
const date = new Date(timestamp * 1000); // 假设时间戳是以秒为单位
const options = {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
};
return date.toLocaleString('zh-CN', options).replace(/\//g, '-').replace(/, /g, ' ');
},
deleteConversation(cid) {
axios.get('/api/chat/delete_conversation?conversation_id=' + cid).then(response => {
this.getConversations();
this.currCid = '';
this.messages = [];
}).catch(err => {
console.error(err);
});
},
async sendMessage() {
if (this.currCid == '') {
await this.newConversation();
}
this.messages.push({
type: 'user',
message: this.prompt,
image_url: this.stagedImagesUrl
});
// let bot_resp = {
// type: 'bot',
// message: ref('')
// }
// this.messages.push(bot_resp);
this.scrollToBottom();
let image_filenames = [];
for (let i = 0; i < this.stagedImagesUrl.length; i++) {
let img = this.stagedImagesUrl[i].replace('/api/chat/get_file?filename=', '');
image_filenames.push(img);
}
this.loadingChat = true;
fetch('/api/chat/send', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + localStorage.getItem('token')
},
body: JSON.stringify({ message: this.prompt, conversation_id: this.currCid, image_url: image_filenames }) // 发送请求体
})
.then(response => {
this.prompt = '';
this.stagedImagesUrl = [];
this.loadingChat = false;
const reader = response.body.getReader(); // 获取流的 Reader
const decoder = new TextDecoder();
const readStream = async () => {
const { done, value } = await reader.read(); // 读取流中的数据
if (done) {
console.log("Stream finished.");
return;
}
const chunk = decoder.decode(value, { stream: true });
// bot_resp.message.value += chunk;
console.log("!!!!", chunk);
if (chunk.startsWith('[IMAGE]')) {
let img = chunk.replace('[IMAGE]', '');
let bot_resp = {
type: 'bot',
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
this.messages.push(bot_resp);
} else {
let bot_resp = {
type: 'bot',
message: chunk
}
this.messages.push(bot_resp);
}
this.scrollToBottom();
readStream(); // 递归读取流
};
readStream();
})
.catch(err => {
console.error(err);
});
},
scrollToBottom() {
this.$nextTick(() => {
const container = this.$refs.messageContainer;
container.scrollTop = container.scrollHeight;
});
}
}
}
</script>
<style>
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
.mc h1,
.mc h2,
.mc h3,
.mc h4,
.mc h5,
.mc h6 {
margin-bottom: 10px;
}
.mc li {
margin-left: 16px;
}
.mc p {
margin-top: 10px;
margin-bottom: 10px;
}
</style>
+61 -1
View File
@@ -1,5 +1,7 @@
<script setup>
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import axios from 'axios';
</script>
<template>
@@ -7,8 +9,34 @@ import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
<div
style="background-color: white; padding: 8px; padding-left: 16px; border-radius: 8px; margin-bottom: 16px; display: flex; flex-direction: row; align-items: center; justify-content: space-between;">
<h4>控制台</h4>
<v-dialog v-model="pipDialog" width="400">
<template v-slot:activator="{ props }">
<v-btn variant="plain" v-bind="props">安装 pip </v-btn>
</template>
<v-card>
<v-card-title>
<span class="text-h5">安装 Pip </span>
</v-card-title>
<v-card-text>
<v-text-field v-model="pipInstallPayload.package" label="*库名,如 llmtuner" variant="outlined"></v-text-field>
<v-text-field v-model="pipInstallPayload.mirror" label="镜像站链接(可选)" variant="outlined"></v-text-field>
<small>如果不填镜像站链接默认使用阿里云镜像https://mirrors.aliyun.com/pypi/simple/</small>
<div>
<small>{{ status }}</small>
</div>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="pipInstall" :loading="loading">
安装
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</div>
<ConsoleDisplayer style="height: calc(100vh - 160px); "/>
<ConsoleDisplayer style="height: calc(100vh - 160px); " />
</div>
</template>
<script>
@@ -17,6 +45,36 @@ export default {
components: {
ConsoleDisplayer
},
data() {
return {
pipDialog: false,
pipInstallPayload: {
package: '',
mirror: ''
},
loading: false,
status: ''
}
},
methods: {
pipInstall() {
this.loading = true;
axios.post('/api/update/pip-install', this.pipInstallPayload)
.then(res => {
this.status = res.data.message;
setTimeout(() => {
this.status = '';
this.pipDialog = false;
}, 2000);
})
.catch(err => {
this.status = err.response.data.message;
}).finally(() => {
this.loading = false;
});
}
}
}
</script>
@@ -26,10 +84,12 @@ export default {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
+3 -3
View File
@@ -9,8 +9,8 @@ import axios from 'axios';
<template>
<v-row>
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以前往 配置->其他配置->插件仓库镜像 修改安装镜像源。2. 如需插件帮助请点击 `仓库` 查看 README"
title="💡提示" type="info" variant="tonal">
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README"
title="💡提示" type="info" variant="tonal">
</v-alert>
<v-col cols="12" md="12">
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
@@ -80,7 +80,7 @@ import axios from 'axios';
</v-card>
</v-dialog>
<v-dialog v-model="dialog" persistent width="700">
<v-dialog v-model="dialog" width="700">
<template v-slot:activator="{ props }">
<v-btn v-bind="props" icon="mdi-plus" size="x-large" style="position: fixed; right: 52px; bottom: 52px;"
color="darkprimary">
-2
View File
@@ -2,8 +2,6 @@ import os
import asyncio
import sys
import mimetypes
import aiohttp
import zipfile
from astrbot.dashboard import AstrBotDashBoardLifecycle
from astrbot.core import db_helper
from astrbot.core import logger, LogManager, LogBroker
+33 -6
View File
@@ -56,6 +56,10 @@ class Main(star.Star):
/persona: 情境人格设置
/tool ls: 查看、激活、停用当前注册的函数工具
[其他]
/set <变量名> <值>: 为当前会话定义一个变量。适用于 Dify 工作流输入。
/unset <变量名>: 删除当前会话的变量。
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
{notice}"""
@@ -345,7 +349,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard update")
@filter.command("dashboard_update")
async def update_dashboard(self, event: AstrMessageEvent):
yield event.plain_result("正在尝试更新管理面板...")
await download_dashboard()
@@ -365,12 +369,35 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
if provider.curr_personality['prompt']:
req.system_prompt += f"\n{provider.curr_personality['prompt']}"
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
async def other_message(self, event: AstrMessageEvent):
print("triggered")
event.stop_event()
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
session_id = event.get_session_id()
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
session_var[key] = value
session_vars[session_id] = session_var
sp.put("session_variables", session_vars)
yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。")
@filter.command("unset")
async def unset_variable(self, event: AstrMessageEvent, key: str):
session_id = event.get_session_id()
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
if key not in session_var:
yield event.plain_result("没有那个变量名。")
else:
del session_var[key]
sp.put("session_variables", session_vars)
yield event.plain_result(f"会话 {session_id} 变量 {key} 移除成功。")
@filter.command_group("kdb")
def kdb(self):
pass
+33
View File
@@ -0,0 +1,33 @@
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
chunk_length_s=30,
batch_size=16, # batch size for inference - set based on your device
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])