Compare commits

...

10 Commits

Author SHA1 Message Date
Soulter 29374f8d8a fix: 修复 /dashbord_update 指令 2025-01-11 00:25:02 +08:00
Soulter 359b971103 Merge pull request #235 from Soulter/feat-webchat
WebChat 支持
2025-01-11 00:17:18 +08:00
Soulter fbdb1ae208 chore: bump to v3.4.4 2025-01-11 00:14:08 +08:00
Soulter 22c13c1eff perf: webchat支持传图 2025-01-11 00:06:19 +08:00
Soulter 5fc63aeaf1 perf: ui 2025-01-10 22:45:14 +08:00
Soulter d4f32673ab fix: 修复持久化问题 2025-01-10 22:08:43 +08:00
Soulter 480dffb51b feat: 初步实现 webchat 页面 2025-01-10 21:48:15 +08:00
Soulter 966df00124 feat: 支持从管理面板(控制台页)手动安装 pip 库 2025-01-10 15:35:57 +08:00
Soulter 3e2b4bc727 feat: 支持动态设置会话变量以适用 Dify 输入变量 2025-01-10 12:32:20 +08:00
Soulter 5929a8d42b Update README.md 2025-01-09 23:11:11 +08:00
34 changed files with 1074 additions and 52 deletions
+2 -1
View File
@@ -20,4 +20,5 @@ chroma
node_modules/ node_modules/
.DS_Store .DS_Store
package-lock.json package-lock.json
package.json package.json
venv/*
+21 -1
View File
@@ -45,7 +45,27 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
## ✨ 支持 Dify ## ✨ 支持 Dify
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流! 1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流![接入 Dify - AstrBot 文档](https://astrbot.lwl.lol/others/dify.html)
## ✨ 代码执行器(Beta)
基于 Docker 的沙箱化代码执行器(Beta 测试中)
> [!NOTE]
> 文件输入/输出目前仅支持 Napcat(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/700a545e-7450-4f23-90ff-af6d0d60e501" height=300>
<img src="https://github.com/user-attachments/assets/0b0c5344-e98b-4902-92ad-fe9f0bb10c2a" height=300>
<img src="https://github.com/user-attachments/assets/b9b98ff4-8630-46fb-9a39-ecbad9d601ae" height=300>
<img src="https://github.com/user-attachments/assets/9fe6e44c-e4f6-4347-9d5f-281677d47feb" height=300>
</div>
## ✨ Demo ## ✨ Demo
+7
View File
@@ -1,12 +1,16 @@
import os import os
import asyncio
from .log import LogManager, LogBroker from .log import LogManager, LogBroker
from astrbot.core.utils.t2i.renderer import HtmlRenderer from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
os.makedirs("data", exist_ok=True) os.makedirs("data", exist_ok=True)
astrbot_config = AstrBotConfig()
html_renderer = HtmlRenderer() html_renderer = HtmlRenderer()
logger = LogManager.GetLogger(log_name='astrbot') logger = LogManager.GetLogger(log_name='astrbot')
@@ -15,4 +19,7 @@ if os.environ.get('TESTING', ""):
db_helper = SQLiteDatabase(DB_PATH) db_helper = SQLiteDatabase(DB_PATH)
sp = SharedPreferences() # 简单的偏好设置存储 sp = SharedPreferences() # 简单的偏好设置存储
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
web_chat_queue = asyncio.Queue()
web_chat_back_queue = asyncio.Queue()
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool" WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
+1 -1
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
""" """
VERSION = "3.4.3" VERSION = "3.4.4"
DB_PATH = "data/data_v3.db" DB_PATH = "data/data_v3.db"
# 默认配置 # 默认配置
+2 -1
View File
@@ -3,6 +3,7 @@ import time
import threading import threading
import os import os
from .event_bus import EventBus from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue from asyncio import Queue
from typing import List from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -21,7 +22,7 @@ from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class AstrBotCoreLifecycle: class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase): def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker self.log_broker = log_broker
self.astrbot_config = AstrBotConfig() self.astrbot_config = astrbot_config
self.db = db self.db = db
if self.astrbot_config['http_proxy']: if self.astrbot_config['http_proxy']:
+25 -1
View File
@@ -1,7 +1,7 @@
import abc import abc
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
@dataclass @dataclass
class BaseDatabase(abc.ABC): class BaseDatabase(abc.ABC):
@@ -76,4 +76,28 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision: def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
'''通过 url 或 path 获取 ATRI 视觉数据''' '''通过 url 或 path 获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
'''通过 user_id 和 cid 获取 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def webchat_new_conversation(self, user_id: str, cid: str):
'''新建 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
raise NotImplementedError
@abc.abstractmethod
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
'''更新 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def delete_webchat_conversation(self, user_id: str, cid: str):
'''删除 WebChatConversation'''
raise NotImplementedError raise NotImplementedError
+12 -1
View File
@@ -51,4 +51,15 @@ class ATRIVision():
platform_name: str platform_name: str
session_id: str session_id: str
sender_nickname: str sender_nickname: str
timestamp: int = -1 timestamp: int = -1
@dataclass
class WebChatConversation():
user_id: str
cid: str
history: str = ""
created_at: int = 0
updated_at: int = 0
+65 -1
View File
@@ -5,7 +5,8 @@ from astrbot.core.db.po import (
Platform, Platform,
Stats, Stats,
LLMHistory, LLMHistory,
ATRIVision ATRIVision,
WebChatConversation
) )
from . import BaseDatabase from . import BaseDatabase
from typing import Tuple from typing import Tuple
@@ -199,6 +200,69 @@ class SQLiteDatabase(BaseDatabase):
c.close() c.close()
return Stats(platform, [], []) return Stats(platform, [], [])
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
res = c.fetchone()
c.close()
return WebChatConversation(*res)
def webchat_new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
'''
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
''', (user_id, cid, history, updated_at, created_at)
)
def get_webchat_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
''', (user_id,)
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
return conversations
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
''', (history, user_id, cid)
)
def delete_webchat_conversation(self, user_id: str, cid: str):
self._exec_sql(
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
def insert_atri_vision_data(self, vision: ATRIVision): def insert_atri_vision_data(self, vision: ATRIVision):
+8
View File
@@ -35,4 +35,12 @@ CREATE TABLE IF NOT EXISTS atri_vision(
session_id VARCHAR(32), session_id VARCHAR(32),
sender_nickname VARCHAR(32), sender_nickname VARCHAR(32),
timestamp INTEGER timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT,
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER
); );
@@ -20,6 +20,10 @@ class WhitelistCheckStage(Stage):
if not self.enable_whitelist_check: if not self.enable_whitelist_check:
return return
if event.get_platform_name() == 'webchat':
# WebChat 豁免
return
# 检查是否在白名单 # 检查是否在白名单
if self.wl_ignore_admin_on_group: if self.wl_ignore_admin_on_group:
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE: if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
+4 -1
View File
@@ -4,7 +4,7 @@ from typing import List
from asyncio import Queue from asyncio import Queue
from .register import platform_cls_map from .register import platform_cls_map
from astrbot.core import logger from astrbot.core import logger
from .sources.webchat.webchat_adapter import WebChatAdapter
class PlatformManager(): class PlatformManager():
def __init__(self, config: AstrBotConfig, event_queue: Queue): def __init__(self, config: AstrBotConfig, event_queue: Queue):
@@ -25,6 +25,7 @@ class PlatformManager():
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401 from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "vchat": case "vchat":
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401 from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
async def initialize(self): async def initialize(self):
for platform in self.platforms_config: for platform in self.platforms_config:
@@ -37,6 +38,8 @@ class PlatformManager():
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...") logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
inst = cls_type(platform, self.settings, self.event_queue) inst = cls_type(platform, self.settings, self.event_queue)
self.platform_insts.append(inst) self.platform_insts.append(inst)
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
def get_insts(self): def get_insts(self):
return self.platform_insts return self.platform_insts
@@ -0,0 +1,102 @@
import time
import asyncio
import uuid
import os
from typing import Awaitable, Any
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image # noqa: F403
from astrbot.api import logger
from astrbot.core import web_chat_queue, web_chat_back_queue
from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
class QueueListener:
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
self.queue = queue
self.callback = callback
async def run(self):
while True:
data = await self.queue.get()
await self.callback(data)
@register_platform_adapter("webchat", "webchat")
class WebChatAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings['unique_session']
self.imgs_dir = "data/webchat/imgs"
self.metadata = PlatformMetadata(
"webchat",
"webchat",
)
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
plain = ""
for comp in message_chain.chain:
if isinstance(comp, Plain):
plain += comp.text
web_chat_back_queue.put_nowait(plain)
await super().send_by_session(session, message_chain)
async def convert_message(self, data: tuple) -> AstrBotMessage:
username, cid, payload = data
abm = AstrBotMessage()
abm.self_id = "webchat"
abm.tag = "webchat"
abm.sender = MessageMember(username, username)
abm.type = MessageType.FRIEND_MESSAGE
abm.session_id = f"webchat!{username}!{cid}"
abm.message_id = str(uuid.uuid4())
abm.message = []
if payload['message']:
abm.message.append(Plain(payload['message']))
if payload['image_url']:
if isinstance(payload['image_url'], list):
for img in payload['image_url']:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
else:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
logger.debug(f"WebChatAdapter: {abm.message}")
message_str = payload['message']
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.raw_message = data
return abm
def run(self) -> Awaitable[Any]:
async def callback(data: tuple):
abm = await self.convert_message(data)
await self.handle_msg(abm)
bot = QueueListener(web_chat_queue, callback)
return bot.run()
def meta(self) -> PlatformMetadata:
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
message_event = WebChatMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id
)
self.commit_event(message_event)
@@ -0,0 +1,31 @@
import os
import uuid
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.core import web_chat_back_queue
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.imgs_dir = "data/webchat/imgs"
os.makedirs(self.imgs_dir, exist_ok=True)
async def send(self, message: MessageChain):
for comp in message.chain:
if isinstance(comp, Plain):
await web_chat_back_queue.put(comp.text)
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
await web_chat_back_queue.put(f"[IMAGE]{filename}")
await web_chat_back_queue.put(None)
await super().send(message)
+10 -4
View File
@@ -6,8 +6,7 @@ from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
from astrbot.core import logger from astrbot.core import logger, sp
@register_provider_adapter("dify", "Dify APP 适配器。") @register_provider_adapter("dify", "Dify APP 适配器。")
class ProviderDify(Provider): class ProviderDify(Provider):
@@ -67,10 +66,16 @@ class ProviderDify(Provider):
logger.debug(files_payload) logger.debug(files_payload)
# 获得会话变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
match self.api_type: match self.api_type:
case "chat" | "agent": case "chat" | "agent":
async for chunk in self.api_client.chat_messages( async for chunk in self.api_client.chat_messages(
inputs={}, inputs={
**session_var
},
query=prompt, query=prompt,
user=session_id, user=session_id,
conversation_id=conversation_id, conversation_id=conversation_id,
@@ -88,7 +93,8 @@ class ProviderDify(Provider):
async for chunk in self.api_client.workflow_run( async for chunk in self.api_client.workflow_run(
inputs={ inputs={
"astrbot_text_query": prompt, "astrbot_text_query": prompt,
"astrbot_session_id": session_id "astrbot_session_id": session_id,
**session_var
}, },
user=session_id, user=session_id,
files=files_payload files=files_payload
@@ -209,6 +209,8 @@ class ProviderOpenAIOfficial(Provider):
image_path = await download_image_by_url(image_url) image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path) image_data = await self.encode_image_bs64(image_path)
else: else:
if image_url.startswith("file:///"):
image_url = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_url) image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}}) user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content return user_content
+3 -12
View File
@@ -9,7 +9,7 @@ from types import ModuleType
from typing import List from typing import List
from pip import main as pip_main from pip import main as pip_main
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger, sp from astrbot.core import logger, sp, pip_installer
from .context import Context from .context import Context
from . import StarMetadata from . import StarMetadata
from .updator import PluginUpdator from .updator import PluginUpdator
@@ -92,21 +92,12 @@ class PluginManager:
plugin_path = os.path.join(plugin_dir, p) plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")): if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt") pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在检查插件 {p} 的依赖: {pth}") logger.info(f"正在安装插件 {p} 所需的依赖: {pth}")
try: try:
self._update_plugin_dept(os.path.join(plugin_path, "requirements.txt")) pip_installer.install(requirements_path=pth)
except Exception as e: except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}") logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
def _update_plugin_dept(self, path):
'''更新插件的依赖'''
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
if self.config.pip_install_arg:
args.extend([self.config.pip_install_arg])
result_code = pip_main(args)
if result_code != 0:
raise Exception(str(result_code))
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata: def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata:
'''v3.4.0 以前的方式载入插件元数据 '''v3.4.0 以前的方式载入插件元数据
+21 -9
View File
@@ -1,4 +1,5 @@
import json import json
from astrbot.core import logger
from aiohttp import ClientSession from aiohttp import ClientSession
from typing import Dict, List, Any, AsyncGenerator from typing import Dict, List, Any, AsyncGenerator
@@ -29,11 +30,18 @@ class DifyAPIClient:
async with self.session.post( async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout url, json=payload, headers=self.headers, timeout=timeout
) as resp: ) as resp:
async for data in resp.content: while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip(): if not data.strip():
continue continue
if data.startswith(b"data:"): elif data.startswith(b"data:"):
yield json.loads(data[5:]) try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def workflow_run( async def workflow_run(
self, self,
@@ -50,11 +58,18 @@ class DifyAPIClient:
async with self.session.post( async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout url, json=payload, headers=self.headers, timeout=timeout
) as resp: ) as resp:
async for data in resp.content: while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip(): if not data.strip():
continue continue
if data.startswith(b"data:"): elif data.startswith(b"data:"):
yield json.loads(data[5:]) try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def file_upload( async def file_upload(
self, self,
@@ -70,9 +85,6 @@ class DifyAPIClient:
url, data=payload, headers=self.headers url, data=payload, headers=self.headers
) as resp: ) as resp:
return await resp.json() # {"id": "xxx", ...} return await resp.json() # {"id": "xxx", ...}
async def close(self): async def close(self):
await self.session.close() await self.session.close()
+13 -3
View File
@@ -65,7 +65,7 @@ def save_temp_img(img: Image) -> str:
f.write(img) f.write(img)
return p return p
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str: async def download_image_by_url(url: str, post: bool = False, post_data: dict = None, path = None) -> str:
''' '''
下载图片, 返回 path 下载图片, 返回 path
''' '''
@@ -73,10 +73,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
if post: if post:
async with session.post(url, json=post_data) as resp: async with session.post(url, json=post_data) as resp:
return save_temp_img(await resp.read()) if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
else: else:
async with session.get(url) as resp: async with session.get(url) as resp:
return save_temp_img(await resp.read()) if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
except aiohttp.client_exceptions.ClientConnectorSSLError: except aiohttp.client_exceptions.ClientConnectorSSLError:
# 关闭SSL验证 # 关闭SSL验证
ssl_context = ssl.create_default_context() ssl_context = ssl.create_default_context()
+33
View File
@@ -0,0 +1,33 @@
import logging
from pip import main as pip_main
class PipInstaller():
def __init__(self, pip_install_arg: str):
self.pip_install_arg = pip_install_arg
def install(self, package_name: str = None, requirements_path: str = None, mirror: str = None):
args = ['install']
if package_name:
args.append(package_name)
elif requirements_path:
args.extend(['-r', requirements_path])
if not mirror:
mirror = 'https://mirrors.aliyun.com/pypi/simple/'
args.extend(['--trusted-host', 'mirrors.aliyun.com', '-i', mirror])
if self.pip_install_arg:
args.extend(self.pip_install_arg.split())
print(f"Pip 包管理器: {' '.join(args)}")
result_code = pip_main(args)
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if result_code != 0:
raise Exception(f"安装失败,错误码:{result_code}")
+3 -1
View File
@@ -5,6 +5,7 @@ from .update import UpdateRoute
from .stat import StatRoute from .stat import StatRoute
from .log import LogRoute from .log import LogRoute
from .static_file import StaticFileRoute from .static_file import StaticFileRoute
from .chat import ChatRoute
__all__ = [ __all__ = [
@@ -14,6 +15,7 @@ __all__ = [
"UpdateRoute", "UpdateRoute",
"StatRoute", "StatRoute",
"LogRoute", "LogRoute",
"StaticFileRoute" "StaticFileRoute",
"ChatRoute",
] ]
+149
View File
@@ -0,0 +1,149 @@
import uuid
import json
import os
from .route import Route, Response, RouteContext
from astrbot.core import web_chat_queue, web_chat_back_queue
from quart import request, Response as QuartResponse, g
from astrbot.core.db import BaseDatabase
import asyncio
class ChatRoute(Route):
def __init__(self, context: RouteContext, db: BaseDatabase) -> None:
super().__init__(context)
self.routes = {
'/chat/send': ('POST', self.chat),
'/chat/new_conversation': ('GET', self.new_conversation),
'/chat/conversations': ('GET', self.get_conversations),
'/chat/get_conversation': ('GET', self.get_conversation),
'/chat/delete_conversation': ('GET', self.delete_conversation),
'/chat/get_file': ('GET', self.get_file),
'/chat/post_image': ('POST', self.post_image)
}
self.db = db
self.register_routes()
self.imgs_dir = "data/webchat/imgs"
async def get_file(self):
filename = request.args.get('filename')
if not filename:
return Response().error("Missing key: filename").__dict__
try:
with open(os.path.join(self.imgs_dir, filename), "rb") as f:
return QuartResponse(f.read(), mimetype="image/jpeg")
except FileNotFoundError:
return Response().error("File not found").__dict__
async def post_image(self):
post_data = await request.files
if 'file' not in post_data:
return Response().error("Missing key: file").__dict__
file = post_data['file']
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
await file.save(path)
return Response().ok(data={
'filename': filename
}).__dict__
async def chat(self):
username = g.get('username', 'guest')
post_data = await request.json
if 'message' not in post_data and 'image_url' not in post_data:
return Response().error("Missing key: message or image_url").__dict__
if 'conversation_id' not in post_data:
return Response().error("Missing key: conversation_id").__dict__
message = post_data['message']
conversation_id = post_data['conversation_id']
image_url = post_data.get('image_url')
if not message and not image_url:
return Response().error("Message and image_url are empty").__dict__
if not conversation_id:
return Response().error("conversation_id is empty").__dict__
await web_chat_queue.put((username, conversation_id, {
'message': message,
'image_url': image_url # list
}))
async def stream():
ret = []
while True:
result = await web_chat_back_queue.get()
if result is None:
break
ret.append(result)
yield result + '\n'
await asyncio.sleep(0.5)
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
new_his = {
'type': 'user',
'message': message
}
if image_url:
new_his['image_url'] = image_url
history.append(new_his)
for r in ret:
history.append({
'type': 'bot',
'message': r
})
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
return QuartResponse(
stream(),
mimetype="text/event-stream",
headers={
"Content-Type": "text/event-stream",
"Transfer-Encoding": "chunked",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*" # 如果是跨域请求
}
)
async def delete_conversation(self):
username = g.get('username', 'guest')
conversation_id = request.args.get('conversation_id')
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
self.db.delete_webchat_conversation(username, conversation_id)
return Response().ok().__dict__
async def new_conversation(self):
username = g.get('username', 'guest')
conversation_id = str(uuid.uuid4())
self.db.webchat_new_conversation(username, conversation_id)
return Response().ok(data={
'conversation_id': conversation_id
}).__dict__
async def get_conversations(self):
username = g.get('username', 'guest')
conversations = self.db.get_webchat_conversations(username)
return Response().ok(data=conversations).__dict__
async def get_conversation(self):
username = g.get('username', 'guest')
conversation_id = request.args.get('conversation_id')
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
return Response().ok(data=conversation).__dict__
+14 -1
View File
@@ -3,7 +3,7 @@ import traceback
from .route import Route, Response, RouteContext from .route import Route, Response, RouteContext
from quart import request from quart import request
from astrbot.core.updator import AstrBotUpdator from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger from astrbot.core import logger, pip_installer
class UpdateRoute(Route): class UpdateRoute(Route):
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None: def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
@@ -11,6 +11,7 @@ class UpdateRoute(Route):
self.routes = { self.routes = {
'/update/check': ('GET', self.check_update), '/update/check': ('GET', self.check_update),
'/update/do': ('POST', self.update_project), '/update/do': ('POST', self.update_project),
'/update/pip-install': ('POST', self.install_pip_package)
} }
self.astrbot_updator = astrbot_updator self.astrbot_updator = astrbot_updator
self.register_routes() self.register_routes()
@@ -47,4 +48,16 @@ class UpdateRoute(Route):
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__ return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
except Exception as e: except Exception as e:
logger.error(f"/api/update_project: {traceback.format_exc()}") logger.error(f"/api/update_project: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
async def install_pip_package(self):
data = await request.json
package = data.get('package', '')
if not package:
return Response().error("缺少参数 package 或不合法。").__dict__
try:
pip_installer.install(package)
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(f"/api/update_pip: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__ return Response().error(e.__str__()).__dict__
+6 -2
View File
@@ -2,7 +2,7 @@ import logging
import jwt import jwt
import asyncio import asyncio
import os import os
from quart import Quart, request, jsonify from quart import Quart, request, jsonify, g
from quart.logging import default_handler from quart.logging import default_handler
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .routes import * from .routes import *
@@ -31,12 +31,15 @@ class AstrBotDashboard():
self.lr = LogRoute(self.context, core_lifecycle.log_broker) self.lr = LogRoute(self.context, core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.context) self.sfr = StaticFileRoute(self.context)
self.ar = AuthRoute(self.context) self.ar = AuthRoute(self.context)
self.chat_route = ChatRoute(self.context, db)
async def auth_middleware(self): async def auth_middleware(self):
if not request.path.startswith("/api"): if not request.path.startswith("/api"):
return return
if request.path == "/api/auth/login": if request.path == "/api/auth/login":
return return
if request.path == "/api/chat/get_file":
return
# claim jwt # claim jwt
token = request.headers.get("Authorization") token = request.headers.get("Authorization")
if not token: if not token:
@@ -46,7 +49,8 @@ class AstrBotDashboard():
if token.startswith("Bearer "): if token.startswith("Bearer "):
token = token[7:] token = token[7:]
try: try:
jwt.decode(token, WEBUI_SK, algorithms=["HS256"]) payload = jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
g.username = payload["username"]
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
r = jsonify(Response().error("Token 过期").__dict__) r = jsonify(Response().error("Token 过期").__dict__)
r.status_code = 401 r.status_code = 401
+6
View File
@@ -0,0 +1,6 @@
# What's Changed
1. 支持通过 /set <k> <v> 设置持久化的会话变量, 方便 Dify App 输入变量
2. 管理面板支持 Web Chat
3. 管理面板支持手动安装 Pip 库, 在 `控制台` 页中可找到
+17
View File
@@ -18,6 +18,7 @@
"date-fns": "2.30.0", "date-fns": "2.30.0",
"js-md5": "^0.8.3", "js-md5": "^0.8.3",
"lodash": "4.17.21", "lodash": "4.17.21",
"marked": "^15.0.6",
"pinia": "2.1.6", "pinia": "2.1.6",
"remixicon": "3.5.0", "remixicon": "3.5.0",
"vee-validate": "4.11.3", "vee-validate": "4.11.3",
@@ -3702,6 +3703,17 @@
"markdown-it": "bin/markdown-it.js" "markdown-it": "bin/markdown-it.js"
} }
}, },
"node_modules/marked": {
"version": "15.0.6",
"resolved": "https://registry.npmjs.org/marked/-/marked-15.0.6.tgz",
"integrity": "sha512-Y07CUOE+HQXbVDCGl3LXggqJDbXDP2pArc2C1N1RRMN0ONiShoSsIInMd5Gsxupe7fKLpgimTV+HOJ9r7bA+pg==",
"bin": {
"marked": "bin/marked.js"
},
"engines": {
"node": ">= 18"
}
},
"node_modules/mdurl": { "node_modules/mdurl": {
"version": "1.0.1", "version": "1.0.1",
"resolved": "https://registry.npmjs.org/mdurl/-/mdurl-1.0.1.tgz", "resolved": "https://registry.npmjs.org/mdurl/-/mdurl-1.0.1.tgz",
@@ -8460,6 +8472,11 @@
"uc.micro": "^1.0.5" "uc.micro": "^1.0.5"
} }
}, },
"marked": {
"version": "15.0.6",
"resolved": "https://registry.npmjs.org/marked/-/marked-15.0.6.tgz",
"integrity": "sha512-Y07CUOE+HQXbVDCGl3LXggqJDbXDP2pArc2C1N1RRMN0ONiShoSsIInMd5Gsxupe7fKLpgimTV+HOJ9r7bA+pg=="
},
"mdurl": { "mdurl": {
"version": "1.0.1", "version": "1.0.1",
"resolved": "https://registry.npmjs.org/mdurl/-/mdurl-1.0.1.tgz", "resolved": "https://registry.npmjs.org/mdurl/-/mdurl-1.0.1.tgz",
+1
View File
@@ -23,6 +23,7 @@
"date-fns": "2.30.0", "date-fns": "2.30.0",
"js-md5": "^0.8.3", "js-md5": "^0.8.3",
"lodash": "4.17.21", "lodash": "4.17.21",
"marked": "^15.0.6",
"pinia": "2.1.6", "pinia": "2.1.6",
"remixicon": "3.5.0", "remixicon": "3.5.0",
"vee-validate": "4.11.3", "vee-validate": "4.11.3",
@@ -30,6 +30,11 @@ const sidebarItem: menu[] = [
icon: 'mdi-puzzle', icon: 'mdi-puzzle',
to: '/extension' to: '/extension'
}, },
{
title: '聊天',
icon: 'mdi-chat',
to: '/chat'
},
{ {
title: '控制台', title: '控制台',
icon: 'mdi-console', icon: 'mdi-console',
+5
View File
@@ -36,6 +36,11 @@ const MainRoutes = {
name: 'Project ATRI', name: 'Project ATRI',
path: '/project-atri', path: '/project-atri',
component: () => import('@/views/ATRIProject.vue') component: () => import('@/views/ATRIProject.vue')
},
{
name: 'Chat',
path: '/chat',
component: () => import('@/views/ChatPage.vue')
} }
] ]
}; };
+372
View File
@@ -0,0 +1,372 @@
<script setup>
import axios from 'axios';
import { ref } from 'vue';
import { marked } from 'marked';
marked.setOptions({
breaks: true
});
</script>
<template>
<v-card style="margin-bottom: 16px; width: 100%; background-color: #fff; height: 100%;">
<v-card-text style="width: 100%; height: calc(100vh - 120px);">
<div style="height: 100%; display: flex; gap: 16px;">
<div style="max-width: 200px;">
<!-- conversation -->
<v-btn variant="tonal" rounded="xl" style="margin-bottom: 16px; min-width: 200px;" @click="newC"
:disabled="!currCid">+ 创建对话</v-btn>
<v-card class="mx-auto" min-width="200">
<v-list dense nav rounded="xl" v-if="conversations.length > 0"
@update:selected="getConversationMessages">
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
color="primary" rounded="xl">
<v-list-item-title>新对话</v-list-item-title>
<v-list-item-subtitle>{{ formatDate(item.updated_at) }}</v-list-item-subtitle>
</v-list-item>
</v-list>
</v-card>
<v-btn variant="tonal" rounded="xl"
style="position: fixed; bottom: 48px; margin-bottom: 16px; min-width: 200px;" v-if="currCid"
@click="deleteConversation(currCid)" color="error">删除此对话</v-btn>
</div>
<div style="height: 100%; width: 100%;">
<div style="height: calc(100% - 130px); overflow-y: auto; padding: 16px; " ref="messageContainer">
<div class="fade-in" v-if="messages.length == 0"
style="height: 100%; display: flex; justify-content: center; align-items: center; flex-direction: column;">
<div>
<span style="font-size: 28px;">Hello, I'm</span>
<span style="font-weight: 1000; font-size: 28px; margin-left: 8px;">AstrBot ⭐</span>
</div>
<div style="margin-top: 8px; color: #aaa;">
<span>输入</span>
<span
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">/help</span>
<span>获取帮助 😊</span>
</div>
</div>
<div v-else style="max-height: 100%; padding: 16px; max-width: 700px; margin: 0 auto;">
<div class="fade-in" v-for="(msg, index) in messages" :key="index"
style="margin-bottom: 16px;">
<div v-if="msg.type == 'user'" style="display: flex; justify-content: flex-end;">
<div
style="padding: 12px; border-radius: 8px; background-color: rgba(94, 53, 177, 0.15)">
<span>{{ msg.message }}</span>
<div style="display: flex; gap: 8px; margin-top: 8px;" v-if="msg.image_url && msg.image_url.length > 0">
<div v-for="(img, index) in msg.image_url" :key="index"
style="position: relative; display: inline-block;">
<img :src="img"
style="width: 100px; height: 100px; border-radius: 8px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);" />
</div>
</div>
</div>
</div>
<div v-else style="display: flex; justify-content: flex-start; gap: 16px;">
<span style="font-size: 32px;">✨</span>
<div v-html="marked(msg.message)" class="mc" style="font-family: inherit;"></div>
</div>
</div>
</div>
</div>
<div class="fade-in" style="bottom: 16px; width: 100%; padding: 8px; ">
<div
style="width: 100%; justify-content: center; align-items: center; display: flex; flex-direction: column; margin-top: 8px;">
<v-text-field id="input-field" variant="outlined" v-model="prompt" label="聊天吧!"
placeholder="Start typing..." loading clear-icon="mdi-close-circle" clearable
@click:clear="clearMessage" @keyup.enter="sendMessage"
style="width: 100%; max-width: 930px;">
<template v-slot:loader>
<v-progress-linear
:active="loadingChat"
:color="color"
height="6"
indeterminate
></v-progress-linear>
</template>
<template v-slot:append>
<v-icon @click="sendMessage" size="35" icon="mdi-arrow-up-circle" />
</template>
</v-text-field>
<div>
<div v-for="(img, index) in stagedImagesUrl" :key="index"
style="position: relative; display: inline-block;">
<img :src="img"
style="width: 50px; height: 50px; border-radius: 8px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);" />
<v-icon @click="removeImage(index)" size="20" color="red"
style="position: absolute; top: 0; right: 0; cursor: pointer;">mdi-close-circle</v-icon>
</div>
</div>
</div>
</div>
</div>
</div>
</v-card-text>
</v-card>
</template>
<script>
export default {
name: 'ChatPage',
components: {
},
data() {
return {
prompt: '',
messages: [],
conversations: [],
currCid: '',
stagedImagesUrl: [],
loadingChat: false
}
},
mounted() {
this.getConversations();
let inputField = document.getElementById('input-field');
inputField.addEventListener('paste', this.handlePaste);
},
methods: {
async handlePaste(event) {
console.log('Pasting image...');
const items = event.clipboardData.items;
for (let i = 0; i < items.length; i++) {
if (items[i].type.indexOf('image') !== -1) {
const file = items[i].getAsFile();
const formData = new FormData();
formData.append('file', file);
try {
const response = await axios.post('/api/chat/post_image', formData, {
headers: {
'Content-Type': 'multipart/form-data',
'Authorization': 'Bearer ' + localStorage.getItem('token')
}
});
const img = response.data.data.filename;
this.stagedImagesUrl.push(`/api/chat/get_file?filename=${img}`);
scrollToBottom();
} catch (err) {
console.error('Error uploading image:', err);
}
}
}
},
removeImage(index) {
this.stagedImagesUrl.splice(index, 1);
},
clearMessage() {
this.prompt = '';
},
getConversations() {
axios.get('/api/chat/conversations').then(response => {
this.conversations = response.data.data;
}).catch(err => {
console.error(err);
});
},
getConversationMessages(cid) {
if (!cid[0])
return;
axios.get('/api/chat/get_conversation?conversation_id=' + cid[0]).then(response => {
this.currCid = cid[0];
let message = JSON.parse(response.data.data.history);
for (let i = 0; i < message.length; i++) {
if (message[i].message.startsWith('[IMAGE]')) {
let img = message[i].message.replace('[IMAGE]', '');
message[i].message = `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
if (message[i].image_url && message[i].image_url.length > 0) {
for (let j = 0; j < message[i].image_url.length; j++) {
message[i].image_url[j] = `/api/chat/get_file?filename=${message[i].image_url[j]}`;
}
}
}
this.messages = message;
}).catch(err => {
console.error(err);
});
},
async newConversation() {
await axios.get('/api/chat/new_conversation').then(response => {
this.currCid = response.data.data.conversation_id;
this.getConversations();
}).catch(err => {
console.error(err);
});
},
newC() {
this.currCid = '';
this.messages = [];
},
formatDate(timestamp) {
const date = new Date(timestamp * 1000); // 假设时间戳是以秒为单位
const options = {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
};
return date.toLocaleString('zh-CN', options).replace(/\//g, '-').replace(/, /g, ' ');
},
deleteConversation(cid) {
axios.get('/api/chat/delete_conversation?conversation_id=' + cid).then(response => {
this.getConversations();
this.currCid = '';
this.messages = [];
}).catch(err => {
console.error(err);
});
},
async sendMessage() {
if (this.currCid == '') {
await this.newConversation();
}
this.messages.push({
type: 'user',
message: this.prompt,
image_url: this.stagedImagesUrl
});
// let bot_resp = {
// type: 'bot',
// message: ref('')
// }
// this.messages.push(bot_resp);
this.scrollToBottom();
let image_filenames = [];
for (let i = 0; i < this.stagedImagesUrl.length; i++) {
let img = this.stagedImagesUrl[i].replace('/api/chat/get_file?filename=', '');
image_filenames.push(img);
}
this.loadingChat = true;
fetch('/api/chat/send', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + localStorage.getItem('token')
},
body: JSON.stringify({ message: this.prompt, conversation_id: this.currCid, image_url: image_filenames }) // 发送请求体
})
.then(response => {
this.prompt = '';
this.stagedImagesUrl = [];
this.loadingChat = false;
const reader = response.body.getReader(); // 获取流的 Reader
const decoder = new TextDecoder();
const readStream = async () => {
const { done, value } = await reader.read(); // 读取流中的数据
if (done) {
console.log("Stream finished.");
return;
}
const chunk = decoder.decode(value, { stream: true });
// bot_resp.message.value += chunk;
console.log("!!!!", chunk);
if (chunk.startsWith('[IMAGE]')) {
let img = chunk.replace('[IMAGE]', '');
let bot_resp = {
type: 'bot',
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
this.messages.push(bot_resp);
} else {
let bot_resp = {
type: 'bot',
message: chunk
}
this.messages.push(bot_resp);
}
this.scrollToBottom();
readStream(); // 递归读取流
};
readStream();
})
.catch(err => {
console.error(err);
});
},
scrollToBottom() {
this.$nextTick(() => {
const container = this.$refs.messageContainer;
container.scrollTop = container.scrollHeight;
});
}
}
}
</script>
<style>
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
.mc h1,
.mc h2,
.mc h3,
.mc h4,
.mc h5,
.mc h6 {
margin-bottom: 10px;
}
.mc li {
margin-left: 16px;
}
.mc p {
margin-top: 10px;
margin-bottom: 10px;
}
</style>
+61 -1
View File
@@ -1,5 +1,7 @@
<script setup> <script setup>
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue'; import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import axios from 'axios';
</script> </script>
<template> <template>
@@ -7,8 +9,34 @@ import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
<div <div
style="background-color: white; padding: 8px; padding-left: 16px; border-radius: 8px; margin-bottom: 16px; display: flex; flex-direction: row; align-items: center; justify-content: space-between;"> style="background-color: white; padding: 8px; padding-left: 16px; border-radius: 8px; margin-bottom: 16px; display: flex; flex-direction: row; align-items: center; justify-content: space-between;">
<h4>控制台</h4> <h4>控制台</h4>
<v-dialog v-model="pipDialog" width="400">
<template v-slot:activator="{ props }">
<v-btn variant="plain" v-bind="props">安装 pip </v-btn>
</template>
<v-card>
<v-card-title>
<span class="text-h5">安装 Pip </span>
</v-card-title>
<v-card-text>
<v-text-field v-model="pipInstallPayload.package" label="*库名,如 llmtuner" variant="outlined"></v-text-field>
<v-text-field v-model="pipInstallPayload.mirror" label="镜像站链接(可选)" variant="outlined"></v-text-field>
<small>如果不填镜像站链接默认使用阿里云镜像https://mirrors.aliyun.com/pypi/simple/</small>
<div>
<small>{{ status }}</small>
</div>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="pipInstall" :loading="loading">
安装
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</div> </div>
<ConsoleDisplayer style="height: calc(100vh - 160px); "/> <ConsoleDisplayer style="height: calc(100vh - 160px); " />
</div> </div>
</template> </template>
<script> <script>
@@ -17,6 +45,36 @@ export default {
components: { components: {
ConsoleDisplayer ConsoleDisplayer
}, },
data() {
return {
pipDialog: false,
pipInstallPayload: {
package: '',
mirror: ''
},
loading: false,
status: ''
}
},
methods: {
pipInstall() {
this.loading = true;
axios.post('/api/update/pip-install', this.pipInstallPayload)
.then(res => {
this.status = res.data.message;
setTimeout(() => {
this.status = '';
this.pipDialog = false;
}, 2000);
})
.catch(err => {
this.status = err.response.data.message;
}).finally(() => {
this.loading = false;
});
}
}
} }
</script> </script>
@@ -26,10 +84,12 @@ export default {
from { from {
opacity: 0; opacity: 0;
} }
to { to {
opacity: 1; opacity: 1;
} }
} }
.fade-in { .fade-in {
animation: fadeIn 0.2s ease-in-out; animation: fadeIn 0.2s ease-in-out;
} }
+3 -3
View File
@@ -9,8 +9,8 @@ import axios from 'axios';
<template> <template>
<v-row> <v-row>
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以前往 配置->其他配置->插件仓库镜像 修改安装镜像源。2. 如需插件帮助请点击 `仓库` 查看 README" <v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README"
title="💡提示" type="info" variant="tonal"> title="💡提示" type="info" variant="tonal">
</v-alert> </v-alert>
<v-col cols="12" md="12"> <v-col cols="12" md="12">
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;"> <div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
@@ -80,7 +80,7 @@ import axios from 'axios';
</v-card> </v-card>
</v-dialog> </v-dialog>
<v-dialog v-model="dialog" persistent width="700"> <v-dialog v-model="dialog" width="700">
<template v-slot:activator="{ props }"> <template v-slot:activator="{ props }">
<v-btn v-bind="props" icon="mdi-plus" size="x-large" style="position: fixed; right: 52px; bottom: 52px;" <v-btn v-bind="props" icon="mdi-plus" size="x-large" style="position: fixed; right: 52px; bottom: 52px;"
color="darkprimary"> color="darkprimary">
-2
View File
@@ -2,8 +2,6 @@ import os
import asyncio import asyncio
import sys import sys
import mimetypes import mimetypes
import aiohttp
import zipfile
from astrbot.dashboard import AstrBotDashBoardLifecycle from astrbot.dashboard import AstrBotDashBoardLifecycle
from astrbot.core import db_helper from astrbot.core import db_helper
from astrbot.core import logger, LogManager, LogBroker from astrbot.core import logger, LogManager, LogBroker
+33 -6
View File
@@ -56,6 +56,10 @@ class Main(star.Star):
/persona: 情境人格设置 /persona: 情境人格设置
/tool ls: 查看、激活、停用当前注册的函数工具 /tool ls: 查看、激活、停用当前注册的函数工具
[其他]
/set <变量名> <值>: 为当前会话定义一个变量。适用于 Dify 工作流输入。
/unset <变量名>: 删除当前会话的变量。
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。 提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
{notice}""" {notice}"""
@@ -345,7 +349,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}")) MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
@filter.permission_type(filter.PermissionType.ADMIN) @filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard update") @filter.command("dashboard_update")
async def update_dashboard(self, event: AstrMessageEvent): async def update_dashboard(self, event: AstrMessageEvent):
yield event.plain_result("正在尝试更新管理面板...") yield event.plain_result("正在尝试更新管理面板...")
await download_dashboard() await download_dashboard()
@@ -365,12 +369,35 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}" req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
if provider.curr_personality['prompt']: if provider.curr_personality['prompt']:
req.system_prompt += f"\n{provider.curr_personality['prompt']}" req.system_prompt += f"\n{provider.curr_personality['prompt']}"
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
async def other_message(self, event: AstrMessageEvent):
print("triggered")
event.stop_event()
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
session_id = event.get_session_id()
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
session_var[key] = value
session_vars[session_id] = session_var
sp.put("session_variables", session_vars)
yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。")
@filter.command("unset")
async def unset_variable(self, event: AstrMessageEvent, key: str):
session_id = event.get_session_id()
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
if key not in session_var:
yield event.plain_result("没有那个变量名。")
else:
del session_var[key]
sp.put("session_variables", session_vars)
yield event.plain_result(f"会话 {session_id} 变量 {key} 移除成功。")
@filter.command_group("kdb") @filter.command_group("kdb")
def kdb(self): def kdb(self):
pass pass
+33
View File
@@ -0,0 +1,33 @@
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
chunk_length_s=30,
batch_size=16, # batch size for inference - set based on your device
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])