Compare commits
22 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 94618e8feb | |||
| 55de7d4494 | |||
| 7ed639f741 | |||
| 41f2870c29 | |||
| ba198490fa | |||
| 0f9ab082ab | |||
| 97b58965f2 | |||
| f2566c68e3 | |||
| a456bf5449 | |||
| a09998f910 | |||
| be662b913c | |||
| e7ddc8448d | |||
| 29374f8d8a | |||
| 359b971103 | |||
| fbdb1ae208 | |||
| 22c13c1eff | |||
| 5fc63aeaf1 | |||
| d4f32673ab | |||
| 480dffb51b | |||
| 966df00124 | |||
| 3e2b4bc727 | |||
| 5929a8d42b |
+2
-1
@@ -20,4 +20,5 @@ chroma
|
||||
node_modules/
|
||||
.DS_Store
|
||||
package-lock.json
|
||||
package.json
|
||||
package.json
|
||||
venv/*
|
||||
@@ -35,6 +35,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
3. 支持 LLMTuner 载入微调模型。
|
||||
4. 支持 Ollama 载入自部署模型。
|
||||
4. 支持网页搜索(Web Search)、自然语言待办提醒。
|
||||
5. 支持 Whisper 语音转文字
|
||||
|
||||
## ✨ 管理面板
|
||||
|
||||
@@ -45,45 +46,21 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
|
||||
## ✨ 支持 Dify
|
||||
|
||||
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流!
|
||||
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流
|
||||
|
||||
## ✨ Demo
|
||||
## ✨ 代码执行器(Beta)
|
||||
|
||||
基于 Docker 的沙箱化代码执行器(Beta 测试中)
|
||||
|
||||
> [!NOTE]
|
||||
> 文件输入/输出目前仅支持 Napcat(QQ)
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
|
||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
||||
|
||||
_✨ 自然语言待办事项 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||
|
||||
_✨ 插件系统——部分插件展示 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/caadf2bd-a0ee-43d0-a95e-566d63e3e34d" height=330>
|
||||
<img src="https://github.com/user-attachments/assets/b418f281-e920-49db-9fe1-d6a13ce28a84" height=350>
|
||||
|
||||
_✨ 管理面板 ✨_
|
||||
|
||||
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<!-- ## ✨ ATRI [Beta 测试]
|
||||
|
||||
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
||||
|
||||
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
|
||||
2. 长期记忆
|
||||
3. 表情包理解与回复
|
||||
4. TTS
|
||||
-->
|
||||
## ✨ 云部署
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
@@ -104,3 +81,44 @@ _✨ 管理面板 ✨_
|
||||
- Star 这个项目!
|
||||
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
||||
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
|
||||
|
||||
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
|
||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
||||
|
||||
_✨ 自然语言待办事项 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||
|
||||
_✨ 插件系统——部分插件展示 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
|
||||
|
||||
_✨ 管理面板 ✨_
|
||||
|
||||

|
||||
|
||||
_✨ 内置 Web Chat,在线与机器人交互 ✨_
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
<!-- ## ✨ ATRI [Beta 测试]
|
||||
|
||||
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
|
||||
|
||||
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
|
||||
2. 长期记忆
|
||||
3. 表情包理解与回复
|
||||
4. TTS
|
||||
-->
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
from astrbot.core.provider import Provider, Personality, ProviderMetaData
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.provider import Provider, STTProvider, Personality
|
||||
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData
|
||||
@@ -1,12 +1,16 @@
|
||||
import os
|
||||
import asyncio
|
||||
from .log import LogManager, LogBroker
|
||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
from astrbot.core.utils.pip_installer import PipInstaller
|
||||
from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
os.makedirs("data", exist_ok=True)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
html_renderer = HtmlRenderer()
|
||||
logger = LogManager.GetLogger(log_name='astrbot')
|
||||
|
||||
@@ -15,4 +19,7 @@ if os.environ.get('TESTING', ""):
|
||||
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
sp = SharedPreferences() # 简单的偏好设置存储
|
||||
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
|
||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.3"
|
||||
VERSION = "3.4.5"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -33,6 +33,10 @@ DEFAULT_CONFIG = {
|
||||
"default_personality": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
|
||||
"prompt_prefix": "",
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
|
||||
@@ -315,9 +319,30 @@ CONFIG_METADATA_2 = {
|
||||
"dify_api_key": "",
|
||||
"dify_api_base": "https://api.dify.ai/v1",
|
||||
"dify_workflow_output_key": "",
|
||||
},
|
||||
"whisper(API)": {
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_api",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"model": "whisper-1",
|
||||
},
|
||||
"whisper(本地加载)": {
|
||||
"whisper_hint": "(不用修改我)",
|
||||
"enable": False,
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"model": "tiny",
|
||||
}
|
||||
},
|
||||
"items": {
|
||||
"whisper_hint": {
|
||||
"description": "本地部署 Whisper 模型须知",
|
||||
"type": "string",
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"obvious_hint": True
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
@@ -416,7 +441,8 @@ CONFIG_METADATA_2 = {
|
||||
"enable": {
|
||||
"description": "启用大语言模型聊天",
|
||||
"type": "bool",
|
||||
"hint": "是否启用大语言模型聊天。默认启用",
|
||||
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
|
||||
"obvious_hint": True
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
@@ -450,6 +476,23 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"description": "语音转文本(STT)",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"description": "启用语音转文本(STT)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
|
||||
"obvious_hint": True
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个STT提供商",
|
||||
"type": "string",
|
||||
"hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"misc_config_group": {
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
import threading
|
||||
import os
|
||||
from .event_bus import EventBus
|
||||
from . import astrbot_config
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -21,7 +22,7 @@ from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
class AstrBotCoreLifecycle:
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
self.log_broker = log_broker
|
||||
self.astrbot_config = AstrBotConfig()
|
||||
self.astrbot_config = astrbot_config
|
||||
self.db = db
|
||||
|
||||
if self.astrbot_config['http_proxy']:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision
|
||||
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
|
||||
|
||||
@dataclass
|
||||
class BaseDatabase(abc.ABC):
|
||||
@@ -76,4 +76,28 @@ class BaseDatabase(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
|
||||
'''通过 url 或 path 获取 ATRI 视觉数据'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
|
||||
'''通过 user_id 和 cid 获取 WebChatConversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def webchat_new_conversation(self, user_id: str, cid: str):
|
||||
'''新建 WebChatConversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
|
||||
'''更新 WebChatConversation'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete_webchat_conversation(self, user_id: str, cid: str):
|
||||
'''删除 WebChatConversation'''
|
||||
raise NotImplementedError
|
||||
+12
-1
@@ -51,4 +51,15 @@ class ATRIVision():
|
||||
platform_name: str
|
||||
session_id: str
|
||||
sender_nickname: str
|
||||
timestamp: int = -1
|
||||
timestamp: int = -1
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebChatConversation():
|
||||
user_id: str
|
||||
cid: str
|
||||
history: str = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
|
||||
@@ -5,7 +5,8 @@ from astrbot.core.db.po import (
|
||||
Platform,
|
||||
Stats,
|
||||
LLMHistory,
|
||||
ATRIVision
|
||||
ATRIVision,
|
||||
WebChatConversation
|
||||
)
|
||||
from . import BaseDatabase
|
||||
from typing import Tuple
|
||||
@@ -199,6 +200,69 @@ class SQLiteDatabase(BaseDatabase):
|
||||
c.close()
|
||||
|
||||
return Stats(platform, [], [])
|
||||
|
||||
|
||||
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
''', (user_id, cid)
|
||||
)
|
||||
|
||||
res = c.fetchone()
|
||||
c.close()
|
||||
return WebChatConversation(*res)
|
||||
|
||||
def webchat_new_conversation(self, user_id: str, cid: str):
|
||||
history = "[]"
|
||||
updated_at = int(time.time())
|
||||
created_at = updated_at
|
||||
self._exec_sql(
|
||||
'''
|
||||
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
|
||||
''', (user_id, cid, history, updated_at, created_at)
|
||||
)
|
||||
|
||||
def get_webchat_conversations(self, user_id: str) -> Tuple:
|
||||
try:
|
||||
c = self.conn.cursor()
|
||||
except sqlite3.ProgrammingError:
|
||||
c = self._get_conn(self.db_path).cursor()
|
||||
|
||||
c.execute(
|
||||
'''
|
||||
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
|
||||
''', (user_id,)
|
||||
)
|
||||
|
||||
res = c.fetchall()
|
||||
c.close()
|
||||
conversations = []
|
||||
for row in res:
|
||||
cid = row[0]
|
||||
created_at = row[1]
|
||||
updated_at = row[2]
|
||||
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
|
||||
return conversations
|
||||
|
||||
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
|
||||
''', (history, user_id, cid)
|
||||
)
|
||||
|
||||
def delete_webchat_conversation(self, user_id: str, cid: str):
|
||||
self._exec_sql(
|
||||
'''
|
||||
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
|
||||
''', (user_id, cid)
|
||||
)
|
||||
|
||||
|
||||
def insert_atri_vision_data(self, vision: ATRIVision):
|
||||
|
||||
@@ -35,4 +35,12 @@ CREATE TABLE IF NOT EXISTS atri_vision(
|
||||
session_id VARCHAR(32),
|
||||
sender_nickname VARCHAR(32),
|
||||
timestamp INTEGER
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS webchat_conversation(
|
||||
user_id TEXT,
|
||||
cid TEXT,
|
||||
history TEXT,
|
||||
created_at INTEGER,
|
||||
updated_at INTEGER
|
||||
);
|
||||
@@ -123,7 +123,7 @@ class Record(BaseMessageComponent):
|
||||
proxy: T.Optional[bool] = True
|
||||
timeout: T.Optional[int] = 0
|
||||
# 额外
|
||||
path: T.Optional[str]
|
||||
path: T.Optional[str] # 用这个
|
||||
|
||||
def __init__(self, file: T.Optional[str], **_):
|
||||
for k in _.keys():
|
||||
|
||||
@@ -3,6 +3,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR
|
||||
from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .respond.stage import RespondStage
|
||||
@@ -12,6 +13,7 @@ STAGES_ORDER = [
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"RateLimitCheckStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PreProcessStage", # 预处理
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
"RespondStage" # 发送消息
|
||||
@@ -21,6 +23,7 @@ __all__ = [
|
||||
"WakingCheckStage",
|
||||
"WhitelistCheckStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PreProcessStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
"RespondStage",
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
import traceback
|
||||
import asyncio
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Record
|
||||
|
||||
@register_stage
|
||||
class PreProcessStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.plugin_manager = ctx.plugin_manager
|
||||
|
||||
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
|
||||
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
'''在处理事件之前的预处理'''
|
||||
|
||||
if self.stt_settings.get('enable', False):
|
||||
# STT 处理
|
||||
# TODO: 独立
|
||||
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||
if stt_provider:
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.path:
|
||||
|
||||
path = component.path
|
||||
|
||||
retry = 5
|
||||
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
if result:
|
||||
logger.info("语音转文本结果: " + result)
|
||||
message_chain[idx] = Plain(result)
|
||||
event.message_str += result
|
||||
event.message_obj.message_str += result
|
||||
break
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"语音转文本失败: {e}")
|
||||
break
|
||||
@@ -41,4 +41,8 @@ class PipelineScheduler():
|
||||
async def execute(self, event: AstrMessageEvent):
|
||||
'''执行 pipeline'''
|
||||
await self._process_stages(event)
|
||||
|
||||
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
||||
await event.send(None)
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
@@ -20,6 +20,10 @@ class WhitelistCheckStage(Stage):
|
||||
if not self.enable_whitelist_check:
|
||||
return
|
||||
|
||||
if event.get_platform_name() == 'webchat':
|
||||
# WebChat 豁免
|
||||
return
|
||||
|
||||
# 检查是否在白名单
|
||||
if self.wl_ignore_admin_on_group:
|
||||
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List
|
||||
from asyncio import Queue
|
||||
from .register import platform_cls_map
|
||||
from astrbot.core import logger
|
||||
|
||||
from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||
|
||||
class PlatformManager():
|
||||
def __init__(self, config: AstrBotConfig, event_queue: Queue):
|
||||
@@ -25,6 +25,7 @@ class PlatformManager():
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
||||
case "vchat":
|
||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
for platform in self.platforms_config:
|
||||
@@ -37,6 +38,8 @@ class PlatformManager():
|
||||
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
|
||||
inst = cls_type(platform, self.settings, self.event_queue)
|
||||
self.platform_insts.append(inst)
|
||||
|
||||
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
|
||||
|
||||
def get_insts(self):
|
||||
return self.platform_insts
|
||||
@@ -0,0 +1,110 @@
|
||||
import time
|
||||
import asyncio
|
||||
import uuid
|
||||
import os
|
||||
from typing import Awaitable, Any
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.message_components import Plain, Image, Record # noqa: F403
|
||||
from astrbot.api import logger
|
||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
||||
from .webchat_event import WebChatMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
|
||||
class QueueListener:
|
||||
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
|
||||
self.queue = queue
|
||||
self.callback = callback
|
||||
|
||||
async def run(self):
|
||||
while True:
|
||||
data = await self.queue.get()
|
||||
await self.callback(data)
|
||||
|
||||
@register_platform_adapter("webchat", "webchat")
|
||||
class WebChatAdapter(Platform):
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings['unique_session']
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
"webchat",
|
||||
"webchat",
|
||||
)
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
plain = ""
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
plain += comp.text
|
||||
web_chat_back_queue.put_nowait(plain)
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
async def convert_message(self, data: tuple) -> AstrBotMessage:
|
||||
username, cid, payload = data
|
||||
|
||||
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = "webchat"
|
||||
abm.tag = "webchat"
|
||||
abm.sender = MessageMember(username, username)
|
||||
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
|
||||
abm.session_id = f"webchat!{username}!{cid}"
|
||||
|
||||
abm.message_id = str(uuid.uuid4())
|
||||
abm.message = []
|
||||
|
||||
if payload['message']:
|
||||
abm.message.append(Plain(payload['message']))
|
||||
if payload['image_url']:
|
||||
if isinstance(payload['image_url'], list):
|
||||
for img in payload['image_url']:
|
||||
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
|
||||
else:
|
||||
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
|
||||
if payload['audio_url']:
|
||||
if isinstance(payload['audio_url'], list):
|
||||
for audio in payload['audio_url']:
|
||||
path = os.path.join(self.imgs_dir, audio)
|
||||
abm.message.append(Record(file=path, path=path))
|
||||
else:
|
||||
path = os.path.join(self.imgs_dir, payload['audio_url'])
|
||||
abm.message.append(Record(file=path, path=path))
|
||||
|
||||
logger.debug(f"WebChatAdapter: {abm.message}")
|
||||
|
||||
message_str = payload['message']
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
abm.raw_message = data
|
||||
return abm
|
||||
|
||||
def run(self) -> Awaitable[Any]:
|
||||
async def callback(data: tuple):
|
||||
abm = await self.convert_message(data)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
bot = QueueListener(web_chat_queue, callback)
|
||||
return bot.run()
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return self.metadata
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
|
||||
message_event = WebChatMessageEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
@@ -0,0 +1,35 @@
|
||||
import os
|
||||
import uuid
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.core import web_chat_back_queue
|
||||
|
||||
class WebChatMessageEvent(AstrMessageEvent):
|
||||
def __init__(self, message_str, message_obj, platform_meta, session_id):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
if not message:
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
return
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
web_chat_back_queue.put_nowait(comp.text)
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
path = os.path.join(self.imgs_dir, filename)
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
ph = comp.file[8:]
|
||||
with open(path, "wb") as f:
|
||||
with open(ph, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
web_chat_back_queue.put_nowait(f"[IMAGE]{filename}")
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
await super().send(message)
|
||||
@@ -1,4 +1,4 @@
|
||||
from .provider import Provider, Personality
|
||||
from .provider import Provider, Personality, STTProvider
|
||||
|
||||
from .entites import ProviderMetaData
|
||||
|
||||
@@ -6,4 +6,5 @@ __all__ = [
|
||||
"Provider",
|
||||
"Personality",
|
||||
"ProviderMetaData",
|
||||
"STTProvider"
|
||||
]
|
||||
@@ -1,13 +1,22 @@
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
|
||||
|
||||
class ProviderType(enum.Enum):
|
||||
CHAT_COMPLETION = "chat_completion"
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData():
|
||||
type: str # 提供商适配器名称,如 openai, ollama
|
||||
desc: str = "" # 提供商适配器描述.
|
||||
|
||||
type: str
|
||||
'''提供商适配器名称,如 openai, ollama'''
|
||||
desc: str = ""
|
||||
'''提供商适配器描述.'''
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
cls_type: Type = None
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest():
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import traceback
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .provider import Provider
|
||||
from .provider import Provider, STTProvider
|
||||
from .entites import ProviderType
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from collections import defaultdict
|
||||
@@ -11,10 +12,17 @@ class ProviderManager():
|
||||
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
|
||||
self.providers_config: List = config['provider']
|
||||
self.provider_settings: dict = config['provider_settings']
|
||||
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
'''加载的 Provider 的实例'''
|
||||
self.stt_provider_insts: List[STTProvider] = []
|
||||
'''加载的 Speech To Text Provider 的实例'''
|
||||
self.llm_tools = llm_tools
|
||||
self.curr_provider_inst: Provider = None
|
||||
'''当前使用的 Provider 实例'''
|
||||
self.curr_stt_provider_inst: STTProvider = None
|
||||
'''当前使用的 Speech To Text Provider 实例'''
|
||||
self.loaded_ids = defaultdict(bool)
|
||||
self.db_helper = db_helper
|
||||
|
||||
@@ -31,19 +39,29 @@ class ProviderManager():
|
||||
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}。")
|
||||
self.loaded_ids[provider_cfg['id']] = True
|
||||
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu # noqa: F401
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify # noqa: F401
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
|
||||
|
||||
try:
|
||||
match provider_cfg['type']:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu # noqa: F401
|
||||
case "llm_tuner":
|
||||
logger.info("加载 LLM Tuner 工具 ...")
|
||||
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
|
||||
case "dify":
|
||||
from .sources.dify_source import ProviderDify # noqa: F401
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
|
||||
continue
|
||||
|
||||
async def initialize(self):
|
||||
for provider_config in self.providers_config:
|
||||
@@ -53,23 +71,54 @@ class ProviderManager():
|
||||
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
|
||||
continue
|
||||
selected_provider_id = sp.get("curr_provider")
|
||||
cls_type = provider_cls_map[provider_config['type']]
|
||||
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
|
||||
provider_enabled = self.provider_settings.get("enable", False)
|
||||
stt_enabled = self.provider_stt_settings.get("enable", False)
|
||||
|
||||
provider_metadata = provider_cls_map[provider_config['type']]
|
||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
|
||||
try:
|
||||
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
|
||||
self.provider_insts.append(inst)
|
||||
if selected_provider_id == provider_config['id']:
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
|
||||
# 按任务实例化提供商
|
||||
|
||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||
# STT 任务
|
||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.stt_provider_insts.append(inst)
|
||||
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
|
||||
self.curr_stt_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
|
||||
|
||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||
# 文本生成任务
|
||||
inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
|
||||
|
||||
if getattr(inst, "initialize", None):
|
||||
await inst.initialize()
|
||||
|
||||
self.provider_insts.append(inst)
|
||||
if selected_provider_id == provider_config['id'] and provider_enabled:
|
||||
self.curr_provider_inst = inst
|
||||
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
|
||||
|
||||
if len(self.provider_insts) > 0 and not self.curr_provider_inst:
|
||||
if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
|
||||
if not self.curr_provider_inst:
|
||||
logger.warning("未启用任何提供商适配器。")
|
||||
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
|
||||
if self.provider_stt_settings.get("enable"):
|
||||
if not self.curr_stt_provider_inst:
|
||||
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
|
||||
@@ -125,6 +125,33 @@ class Provider(abc.ABC):
|
||||
'''重置某一个 session_id 的上下文'''
|
||||
raise NotImplementedError()
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
id=self.provider_config['id'],
|
||||
model=self.get_model(),
|
||||
type=self.provider_config['type']
|
||||
)
|
||||
|
||||
|
||||
class STTProvider():
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
'''获取音频的文本'''
|
||||
raise NotImplementedError()
|
||||
|
||||
def set_model(self, model_name: str):
|
||||
'''设置当前使用的模型名称'''
|
||||
self.model_name = model_name
|
||||
|
||||
def get_model(self) -> str:
|
||||
'''获取当前使用的模型'''
|
||||
return self.provider_config.get("model", "")
|
||||
|
||||
def meta(self) -> ProviderMeta:
|
||||
'''获取 Provider 的元数据'''
|
||||
return ProviderMeta(
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
from typing import List, Dict, Type
|
||||
from .entites import ProviderMetaData
|
||||
from .entites import ProviderMetaData, ProviderType
|
||||
from astrbot.core import logger
|
||||
from .func_tool_manager import FuncCall
|
||||
|
||||
provider_registry: List[ProviderMetaData] = []
|
||||
'''维护了通过装饰器注册的 Provider'''
|
||||
provider_cls_map: Dict[str, Type] = {}
|
||||
'''维护了 Provider 类型名称和 Provider 类的映射'''
|
||||
provider_cls_map: Dict[str, ProviderMetaData] = {}
|
||||
'''维护了 Provider 类型名称和 ProviderMetadata 的映射'''
|
||||
|
||||
llm_tools = FuncCall()
|
||||
|
||||
def register_provider_adapter(provider_type_name: str, desc: str):
|
||||
def register_provider_adapter(
|
||||
provider_type_name: str,
|
||||
desc: str,
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
):
|
||||
'''用于注册平台适配器的带参装饰器'''
|
||||
def decorator(cls):
|
||||
if provider_type_name in provider_cls_map:
|
||||
@@ -19,9 +23,11 @@ def register_provider_adapter(provider_type_name: str, desc: str):
|
||||
pm = ProviderMetaData(
|
||||
type=provider_type_name,
|
||||
desc=desc,
|
||||
provider_type=provider_type,
|
||||
cls_type=cls
|
||||
)
|
||||
provider_registry.append(pm)
|
||||
provider_cls_map[provider_type_name] = cls
|
||||
provider_cls_map[provider_type_name] = pm
|
||||
logger.debug(f"Provider {provider_type_name} 已注册")
|
||||
return cls
|
||||
|
||||
|
||||
@@ -6,8 +6,7 @@ from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import logger
|
||||
|
||||
from astrbot.core import logger, sp
|
||||
|
||||
@register_provider_adapter("dify", "Dify APP 适配器。")
|
||||
class ProviderDify(Provider):
|
||||
@@ -67,10 +66,16 @@ class ProviderDify(Provider):
|
||||
|
||||
logger.debug(files_payload)
|
||||
|
||||
# 获得会话变量
|
||||
session_vars = sp.get("session_variables", {})
|
||||
session_var = session_vars.get(session_id, {})
|
||||
|
||||
match self.api_type:
|
||||
case "chat" | "agent":
|
||||
async for chunk in self.api_client.chat_messages(
|
||||
inputs={},
|
||||
inputs={
|
||||
**session_var
|
||||
},
|
||||
query=prompt,
|
||||
user=session_id,
|
||||
conversation_id=conversation_id,
|
||||
@@ -88,7 +93,8 @@ class ProviderDify(Provider):
|
||||
async for chunk in self.api_client.workflow_run(
|
||||
inputs={
|
||||
"astrbot_text_query": prompt,
|
||||
"astrbot_session_id": session_id
|
||||
"astrbot_session_id": session_id,
|
||||
**session_var
|
||||
},
|
||||
user=session_id,
|
||||
files=files_payload
|
||||
|
||||
@@ -209,6 +209,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
if image_url.startswith("file:///"):
|
||||
image_url = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
|
||||
return user_content
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
import uuid
|
||||
import os
|
||||
import io
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import STTProvider
|
||||
from ..entites import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
|
||||
@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT)
|
||||
class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key = provider_config.get("api_key", "")
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base", None),
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def _convert_audio(self, path: str) -> str:
|
||||
from pyffmpeg import FFmpeg
|
||||
filename = str(uuid.uuid4()) + '.mp3'
|
||||
ff = FFmpeg()
|
||||
output_path = ff.convert(path, os.path.join('data/temp', filename))
|
||||
return output_path
|
||||
|
||||
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
|
||||
import wave
|
||||
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(input_io.read())
|
||||
|
||||
return output_path
|
||||
|
||||
async def _convert_silk(self, path: str) -> str:
|
||||
import pysilk
|
||||
filename = str(uuid.uuid4()) + '.wav'
|
||||
output_path = os.path.join('data/temp', filename)
|
||||
with open(path, "rb") as f:
|
||||
input_data = f.read()
|
||||
if input_data.startswith(b'\x02'):
|
||||
# tencent 我爱你
|
||||
input_data = input_data[1:]
|
||||
input_io = io.BytesIO(input_data)
|
||||
output_io = io.BytesIO()
|
||||
pysilk.decode(input_io, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
await self._pcm_to_wav(output_io, output_path)
|
||||
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
|
||||
if silk_header in file_header:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
'''only supports mp3, mp4, mpeg, m4a, wav, webm'''
|
||||
if audio_url.startswith("http"):
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
audio_url = await download_file(audio_url, path)
|
||||
|
||||
if not os.path.exists(audio_url):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk"):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
audio_url = await self._convert_silk(audio_url)
|
||||
|
||||
|
||||
result = await self.client.audio.transcriptions.create(
|
||||
model=self.model_name,
|
||||
file=open(audio_url, "rb"),
|
||||
)
|
||||
return result.text
|
||||
@@ -0,0 +1,92 @@
|
||||
import uuid
|
||||
import os
|
||||
import io
|
||||
import asyncio
|
||||
import whisper
|
||||
from ..provider import STTProvider
|
||||
from ..entites import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT)
|
||||
class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.set_model(provider_config.get("model", None))
|
||||
self.model = None
|
||||
|
||||
async def initialize(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
|
||||
self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name)
|
||||
logger.info("Whisper 模型加载完成。")
|
||||
|
||||
async def _convert_audio(self, path: str) -> str:
|
||||
from pyffmpeg import FFmpeg
|
||||
filename = str(uuid.uuid4()) + '.mp3'
|
||||
ff = FFmpeg()
|
||||
output_path = ff.convert(path, os.path.join('data/temp', filename))
|
||||
return output_path
|
||||
|
||||
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
|
||||
import wave
|
||||
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(input_io.read())
|
||||
|
||||
return output_path
|
||||
|
||||
async def _convert_silk(self, path: str) -> str:
|
||||
import pysilk
|
||||
filename = str(uuid.uuid4()) + '.wav'
|
||||
output_path = os.path.join('data/temp', filename)
|
||||
with open(path, "rb") as f:
|
||||
input_data = f.read()
|
||||
if input_data.startswith(b'\x02'):
|
||||
# tencent 我爱你
|
||||
input_data = input_data[1:]
|
||||
input_io = io.BytesIO(input_data)
|
||||
output_io = io.BytesIO()
|
||||
pysilk.decode(input_io, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
await self._pcm_to_wav(output_io, output_path)
|
||||
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
file_header = f.read(8)
|
||||
|
||||
if silk_header in file_header:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def get_text(self, audio_url: str) -> str:
|
||||
loop = asyncio.get_event_loop()
|
||||
if audio_url.startswith("http"):
|
||||
name = str(uuid.uuid4())
|
||||
path = os.path.join("data/temp", name)
|
||||
audio_url = await download_file(audio_url, path)
|
||||
|
||||
if not os.path.exists(audio_url):
|
||||
raise FileNotFoundError(f"文件不存在: {audio_url}")
|
||||
|
||||
if audio_url.endswith(".amr") or audio_url.endswith(".silk"):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
audio_url = await self._convert_silk(audio_url)
|
||||
|
||||
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
||||
return result['text']
|
||||
@@ -17,10 +17,6 @@ from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
|
||||
|
||||
class StarCommand(TypedDict):
|
||||
full_command_name: str
|
||||
command_name: str
|
||||
|
||||
class Context:
|
||||
'''
|
||||
暴露给插件的接口上下文。
|
||||
@@ -168,13 +164,13 @@ class Context:
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
'''
|
||||
注册一个 LLM Provider。
|
||||
注册一个 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
self.provider_manager.provider_insts.append(provider)
|
||||
|
||||
def get_provider_by_id(self, provider_id: str) -> Provider:
|
||||
'''
|
||||
通过 ID 获取 LLM Provider。
|
||||
通过 ID 获取 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
for provider in self.provider_manager.provider_insts:
|
||||
if provider.meta().id == provider_id:
|
||||
@@ -183,13 +179,13 @@ class Context:
|
||||
|
||||
def get_all_providers(self) -> List[Provider]:
|
||||
'''
|
||||
获取所有 LLM Provider。
|
||||
获取所有 LLM Provider(Chat_Completion 类型)。
|
||||
'''
|
||||
return self.provider_manager.provider_insts
|
||||
|
||||
def get_using_provider(self) -> Provider:
|
||||
'''
|
||||
获取当前使用的 LLM Provider。
|
||||
获取当前使用的 LLM Provider(Chat_Completion 类型)。
|
||||
|
||||
通过 /provider 指令切换。
|
||||
'''
|
||||
|
||||
@@ -9,7 +9,7 @@ from types import ModuleType
|
||||
from typing import List
|
||||
from pip import main as pip_main
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core import logger, sp, pip_installer
|
||||
from .context import Context
|
||||
from . import StarMetadata
|
||||
from .updator import PluginUpdator
|
||||
@@ -92,21 +92,12 @@ class PluginManager:
|
||||
plugin_path = os.path.join(plugin_dir, p)
|
||||
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
|
||||
pth = os.path.join(plugin_path, "requirements.txt")
|
||||
logger.info(f"正在检查插件 {p} 的依赖: {pth}")
|
||||
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
|
||||
try:
|
||||
self._update_plugin_dept(os.path.join(plugin_path, "requirements.txt"))
|
||||
pip_installer.install(requirements_path=pth)
|
||||
except Exception as e:
|
||||
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
|
||||
|
||||
def _update_plugin_dept(self, path):
|
||||
'''更新插件的依赖'''
|
||||
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
|
||||
if self.config.pip_install_arg:
|
||||
args.extend([self.config.pip_install_arg])
|
||||
result_code = pip_main(args)
|
||||
if result_code != 0:
|
||||
raise Exception(str(result_code))
|
||||
|
||||
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata:
|
||||
'''v3.4.0 以前的方式载入插件元数据
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from astrbot.core import logger
|
||||
from aiohttp import ClientSession
|
||||
from typing import Dict, List, Any, AsyncGenerator
|
||||
|
||||
@@ -29,11 +30,18 @@ class DifyAPIClient:
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
async for data in resp.content:
|
||||
while True:
|
||||
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
|
||||
if not data:
|
||||
break
|
||||
if not data.strip():
|
||||
continue
|
||||
if data.startswith(b"data:"):
|
||||
yield json.loads(data[5:])
|
||||
elif data.startswith(b"data:"):
|
||||
try:
|
||||
json_ = json.loads(data[5:])
|
||||
yield json_
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
async def workflow_run(
|
||||
self,
|
||||
@@ -50,11 +58,18 @@ class DifyAPIClient:
|
||||
async with self.session.post(
|
||||
url, json=payload, headers=self.headers, timeout=timeout
|
||||
) as resp:
|
||||
async for data in resp.content:
|
||||
while True:
|
||||
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
|
||||
if not data:
|
||||
break
|
||||
if not data.strip():
|
||||
continue
|
||||
if data.startswith(b"data:"):
|
||||
yield json.loads(data[5:])
|
||||
elif data.startswith(b"data:"):
|
||||
try:
|
||||
json_ = json.loads(data[5:])
|
||||
yield json_
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
async def file_upload(
|
||||
self,
|
||||
@@ -70,9 +85,6 @@ class DifyAPIClient:
|
||||
url, data=payload, headers=self.headers
|
||||
) as resp:
|
||||
return await resp.json() # {"id": "xxx", ...}
|
||||
|
||||
|
||||
|
||||
|
||||
async def close(self):
|
||||
await self.session.close()
|
||||
@@ -65,7 +65,7 @@ def save_temp_img(img: Image) -> str:
|
||||
f.write(img)
|
||||
return p
|
||||
|
||||
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str:
|
||||
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None, path = None) -> str:
|
||||
'''
|
||||
下载图片, 返回 path
|
||||
'''
|
||||
@@ -73,10 +73,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if post:
|
||||
async with session.post(url, json=post_data) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
else:
|
||||
async with session.get(url) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
if not path:
|
||||
return save_temp_img(await resp.read())
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
except aiohttp.client_exceptions.ClientConnectorSSLError:
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
import logging
|
||||
from pip import main as pip_main
|
||||
|
||||
class PipInstaller():
|
||||
def __init__(self, pip_install_arg: str):
|
||||
self.pip_install_arg = pip_install_arg
|
||||
|
||||
def install(self, package_name: str = None, requirements_path: str = None, mirror: str = None):
|
||||
args = ['install']
|
||||
if package_name:
|
||||
args.append(package_name)
|
||||
elif requirements_path:
|
||||
args.extend(['-r', requirements_path])
|
||||
|
||||
if not mirror:
|
||||
mirror = 'https://mirrors.aliyun.com/pypi/simple/'
|
||||
|
||||
args.extend(['--trusted-host', 'mirrors.aliyun.com', '-i', mirror])
|
||||
|
||||
if self.pip_install_arg:
|
||||
args.extend(self.pip_install_arg.split())
|
||||
|
||||
print(f"Pip 包管理器: {' '.join(args)}")
|
||||
|
||||
result_code = pip_main(args)
|
||||
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
|
||||
if result_code != 0:
|
||||
raise Exception(f"安装失败,错误码:{result_code}")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .server import AstrBotDashboard
|
||||
@@ -13,8 +14,16 @@ class AstrBotDashBoardLifecycle:
|
||||
|
||||
async def start(self):
|
||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||
await core_lifecycle.initialize()
|
||||
core_task = core_lifecycle.start()
|
||||
|
||||
core_task = []
|
||||
try:
|
||||
await core_lifecycle.initialize()
|
||||
core_task = core_lifecycle.start()
|
||||
except Exception as e:
|
||||
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
|
||||
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
|
||||
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
|
||||
|
||||
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db)
|
||||
task = asyncio.gather(core_task, self.dashboard_server.run())
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from .update import UpdateRoute
|
||||
from .stat import StatRoute
|
||||
from .log import LogRoute
|
||||
from .static_file import StaticFileRoute
|
||||
from .chat import ChatRoute
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -14,6 +15,7 @@ __all__ = [
|
||||
"UpdateRoute",
|
||||
"StatRoute",
|
||||
"LogRoute",
|
||||
"StaticFileRoute"
|
||||
"StaticFileRoute",
|
||||
"ChatRoute",
|
||||
]
|
||||
|
||||
|
||||
@@ -0,0 +1,197 @@
|
||||
import uuid
|
||||
import json
|
||||
import os
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import web_chat_queue, web_chat_back_queue
|
||||
from quart import request, Response as QuartResponse, g
|
||||
from astrbot.core.db import BaseDatabase
|
||||
import asyncio
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
|
||||
class ChatRoute(Route):
|
||||
def __init__(self, context: RouteContext, db: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
'/chat/send': ('POST', self.chat),
|
||||
'/chat/new_conversation': ('GET', self.new_conversation),
|
||||
'/chat/conversations': ('GET', self.get_conversations),
|
||||
'/chat/get_conversation': ('GET', self.get_conversation),
|
||||
'/chat/delete_conversation': ('GET', self.delete_conversation),
|
||||
'/chat/get_file': ('GET', self.get_file),
|
||||
'/chat/post_image': ('POST', self.post_image),
|
||||
'/chat/post_file': ('POST', self.post_file),
|
||||
'/chat/status': ('GET', self.status),
|
||||
}
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
|
||||
self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp']
|
||||
|
||||
async def status(self):
|
||||
has_llm_enabled = self.core_lifecycle.provider_manager.curr_provider_inst is not None
|
||||
has_stt_enabled = self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None
|
||||
return Response().ok(data={
|
||||
'llm_enabled': has_llm_enabled,
|
||||
'stt_enabled': has_stt_enabled
|
||||
}).__dict__
|
||||
|
||||
async def get_file(self):
|
||||
filename = request.args.get('filename')
|
||||
if not filename:
|
||||
return Response().error("Missing key: filename").__dict__
|
||||
|
||||
try:
|
||||
with open(os.path.join(self.imgs_dir, filename), "rb") as f:
|
||||
if filename.endswith(".wav"):
|
||||
return QuartResponse(f.read(), mimetype="audio/wav")
|
||||
elif filename.split('.')[-1] in self.supported_imgs:
|
||||
return QuartResponse(f.read(), mimetype="image/jpeg")
|
||||
else:
|
||||
return QuartResponse(f.read())
|
||||
|
||||
except FileNotFoundError:
|
||||
return Response().error("File not found").__dict__
|
||||
|
||||
async def post_image(self):
|
||||
post_data = await request.files
|
||||
if 'file' not in post_data:
|
||||
return Response().error("Missing key: file").__dict__
|
||||
|
||||
file = post_data['file']
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
path = os.path.join(self.imgs_dir, filename)
|
||||
await file.save(path)
|
||||
|
||||
return Response().ok(data={
|
||||
'filename': filename
|
||||
}).__dict__
|
||||
|
||||
async def post_file(self):
|
||||
post_data = await request.files
|
||||
if 'file' not in post_data:
|
||||
return Response().error("Missing key: file").__dict__
|
||||
|
||||
file = post_data['file']
|
||||
filename = f"{str(uuid.uuid4())}"
|
||||
print(file)
|
||||
# 通过文件格式判断文件类型
|
||||
if file.content_type.startswith('audio'):
|
||||
filename += ".wav"
|
||||
|
||||
path = os.path.join(self.imgs_dir, filename)
|
||||
await file.save(path)
|
||||
|
||||
return Response().ok(data={
|
||||
'filename': filename
|
||||
}).__dict__
|
||||
|
||||
async def chat(self):
|
||||
username = g.get('username', 'guest')
|
||||
|
||||
post_data = await request.json
|
||||
if 'message' not in post_data and 'image_url' not in post_data:
|
||||
return Response().error("Missing key: message or image_url").__dict__
|
||||
|
||||
if 'conversation_id' not in post_data:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
message = post_data['message']
|
||||
conversation_id = post_data['conversation_id']
|
||||
image_url = post_data.get('image_url')
|
||||
audio_url = post_data.get('audio_url')
|
||||
if not message and not image_url and not audio_url:
|
||||
return Response().error("Message and image_url and audio_url are empty").__dict__
|
||||
if not conversation_id:
|
||||
return Response().error("conversation_id is empty").__dict__
|
||||
|
||||
await web_chat_queue.put((username, conversation_id, {
|
||||
'message': message,
|
||||
'image_url': image_url, # list
|
||||
'audio_url': audio_url
|
||||
}))
|
||||
|
||||
async def stream():
|
||||
ret = []
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=30) # 设置超时时间为5秒
|
||||
except asyncio.TimeoutError:
|
||||
yield '[Error] 30 秒内没有返回数据,已放弃。\n'
|
||||
return
|
||||
|
||||
if result is None:
|
||||
break
|
||||
|
||||
ret.append(result)
|
||||
|
||||
yield result + '\n'
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
|
||||
new_his = {
|
||||
'type': 'user',
|
||||
'message': message
|
||||
}
|
||||
if image_url:
|
||||
new_his['image_url'] = image_url
|
||||
if audio_url:
|
||||
new_his['audio_url'] = audio_url
|
||||
history.append(new_his)
|
||||
for r in ret:
|
||||
history.append({
|
||||
'type': 'bot',
|
||||
'message': r
|
||||
})
|
||||
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
|
||||
|
||||
return QuartResponse(
|
||||
stream(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Content-Type": "text/event-stream",
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Connection": "keep-alive",
|
||||
"Access-Control-Allow-Origin": "*" # 如果是跨域请求
|
||||
}
|
||||
)
|
||||
|
||||
async def delete_conversation(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversation_id = request.args.get('conversation_id')
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
self.db.delete_webchat_conversation(username, conversation_id)
|
||||
return Response().ok().__dict__
|
||||
|
||||
async def new_conversation(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.webchat_new_conversation(username, conversation_id)
|
||||
return Response().ok(data={
|
||||
'conversation_id': conversation_id
|
||||
}).__dict__
|
||||
|
||||
async def get_conversations(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversations = self.db.get_webchat_conversations(username)
|
||||
return Response().ok(data=conversations).__dict__
|
||||
|
||||
async def get_conversation(self):
|
||||
username = g.get('username', 'guest')
|
||||
conversation_id = request.args.get('conversation_id')
|
||||
if not conversation_id:
|
||||
return Response().error("Missing key: conversation_id").__dict__
|
||||
|
||||
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
|
||||
return Response().ok(data=conversation).__dict__
|
||||
@@ -3,7 +3,7 @@ import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger
|
||||
from astrbot.core import logger, pip_installer
|
||||
|
||||
class UpdateRoute(Route):
|
||||
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
|
||||
@@ -11,6 +11,7 @@ class UpdateRoute(Route):
|
||||
self.routes = {
|
||||
'/update/check': ('GET', self.check_update),
|
||||
'/update/do': ('POST', self.update_project),
|
||||
'/update/pip-install': ('POST', self.install_pip_package)
|
||||
}
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.register_routes()
|
||||
@@ -47,4 +48,16 @@ class UpdateRoute(Route):
|
||||
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def install_pip_package(self):
|
||||
data = await request.json
|
||||
package = data.get('package', '')
|
||||
if not package:
|
||||
return Response().error("缺少参数 package 或不合法。").__dict__
|
||||
try:
|
||||
pip_installer.install(package)
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_pip: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import jwt
|
||||
import asyncio
|
||||
import os
|
||||
from quart import Quart, request, jsonify
|
||||
from quart import Quart, request, jsonify, g
|
||||
from quart.logging import default_handler
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from .routes import *
|
||||
@@ -31,12 +31,15 @@ class AstrBotDashboard():
|
||||
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
|
||||
self.sfr = StaticFileRoute(self.context)
|
||||
self.ar = AuthRoute(self.context)
|
||||
self.chat_route = ChatRoute(self.context, db, core_lifecycle)
|
||||
|
||||
async def auth_middleware(self):
|
||||
if not request.path.startswith("/api"):
|
||||
return
|
||||
if request.path == "/api/auth/login":
|
||||
return
|
||||
if request.path == "/api/chat/get_file":
|
||||
return
|
||||
# claim jwt
|
||||
token = request.headers.get("Authorization")
|
||||
if not token:
|
||||
@@ -46,7 +49,8 @@ class AstrBotDashboard():
|
||||
if token.startswith("Bearer "):
|
||||
token = token[7:]
|
||||
try:
|
||||
jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
|
||||
payload = jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
|
||||
g.username = payload["username"]
|
||||
except jwt.ExpiredSignatureError:
|
||||
r = jsonify(Response().error("Token 过期").__dict__)
|
||||
r.status_code = 401
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
# What's Changed
|
||||
|
||||
1. 支持通过 /set <k> <v> 设置持久化的会话变量, 方便 Dify App 输入变量
|
||||
2. 管理面板支持 Web Chat
|
||||
3. 管理面板支持手动安装 Pip 库, 在 `控制台` 页中可找到
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
# What's Changed
|
||||
|
||||
- 支持接入 STT(语音转文字)Provider
|
||||
- 内置支持 OpenAI Whisper API/本地运行模型。[看这里](https://astrbot.lwl.lol/use/whisper.html)
|
||||
- WebChat 支持语音输入
|
||||
- WebChat 支持显示当前 Provider 状态
|
||||
- 优化了 WebChat 在没有消息返回时的处理方式
|
||||
- 修复了 reminder 在初始化历史待办时没有正常传入 session_id 的问题
|
||||
- 代码执行器在成功回复后清空文件 buffer。
|
||||
Generated
+17
@@ -18,6 +18,7 @@
|
||||
"date-fns": "2.30.0",
|
||||
"js-md5": "^0.8.3",
|
||||
"lodash": "4.17.21",
|
||||
"marked": "^15.0.6",
|
||||
"pinia": "2.1.6",
|
||||
"remixicon": "3.5.0",
|
||||
"vee-validate": "4.11.3",
|
||||
@@ -3702,6 +3703,17 @@
|
||||
"markdown-it": "bin/markdown-it.js"
|
||||
}
|
||||
},
|
||||
"node_modules/marked": {
|
||||
"version": "15.0.6",
|
||||
"resolved": "https://registry.npmjs.org/marked/-/marked-15.0.6.tgz",
|
||||
"integrity": "sha512-Y07CUOE+HQXbVDCGl3LXggqJDbXDP2pArc2C1N1RRMN0ONiShoSsIInMd5Gsxupe7fKLpgimTV+HOJ9r7bA+pg==",
|
||||
"bin": {
|
||||
"marked": "bin/marked.js"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 18"
|
||||
}
|
||||
},
|
||||
"node_modules/mdurl": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/mdurl/-/mdurl-1.0.1.tgz",
|
||||
@@ -8460,6 +8472,11 @@
|
||||
"uc.micro": "^1.0.5"
|
||||
}
|
||||
},
|
||||
"marked": {
|
||||
"version": "15.0.6",
|
||||
"resolved": "https://registry.npmjs.org/marked/-/marked-15.0.6.tgz",
|
||||
"integrity": "sha512-Y07CUOE+HQXbVDCGl3LXggqJDbXDP2pArc2C1N1RRMN0ONiShoSsIInMd5Gsxupe7fKLpgimTV+HOJ9r7bA+pg=="
|
||||
},
|
||||
"mdurl": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/mdurl/-/mdurl-1.0.1.tgz",
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
"date-fns": "2.30.0",
|
||||
"js-md5": "^0.8.3",
|
||||
"lodash": "4.17.21",
|
||||
"marked": "^15.0.6",
|
||||
"pinia": "2.1.6",
|
||||
"remixicon": "3.5.0",
|
||||
"vee-validate": "4.11.3",
|
||||
@@ -32,8 +33,6 @@
|
||||
"vue3-apexcharts": "1.4.4",
|
||||
"vue3-print-nb": "0.1.4",
|
||||
"vuetify": "3.3.14",
|
||||
"xterm": "^5.3.0",
|
||||
"xterm-addon-fit": "^0.8.0",
|
||||
"yup": "1.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -30,6 +30,11 @@ const sidebarItem: menu[] = [
|
||||
icon: 'mdi-puzzle',
|
||||
to: '/extension'
|
||||
},
|
||||
{
|
||||
title: '聊天',
|
||||
icon: 'mdi-chat',
|
||||
to: '/chat'
|
||||
},
|
||||
{
|
||||
title: '控制台',
|
||||
icon: 'mdi-console',
|
||||
|
||||
@@ -36,6 +36,11 @@ const MainRoutes = {
|
||||
name: 'Project ATRI',
|
||||
path: '/project-atri',
|
||||
component: () => import('@/views/ATRIProject.vue')
|
||||
},
|
||||
{
|
||||
name: 'Chat',
|
||||
path: '/chat',
|
||||
component: () => import('@/views/ChatPage.vue')
|
||||
}
|
||||
]
|
||||
};
|
||||
|
||||
@@ -0,0 +1,504 @@
|
||||
<script setup>
|
||||
import axios from 'axios';
|
||||
import { ref } from 'vue';
|
||||
import { marked } from 'marked';
|
||||
|
||||
|
||||
marked.setOptions({
|
||||
breaks: true
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
|
||||
<v-card style="margin-bottom: 16px; width: 100%; background-color: #fff; height: 100%;">
|
||||
<v-card-text style="width: 100%; height: calc(100vh - 120px);">
|
||||
<div style="height: 100%; display: flex; gap: 16px;">
|
||||
<div style="max-width: 200px;">
|
||||
<!-- conversation -->
|
||||
<v-btn variant="tonal" rounded="xl" style="margin-bottom: 16px; min-width: 200px;" @click="newC"
|
||||
:disabled="!currCid">+ 创建对话</v-btn>
|
||||
|
||||
<v-card class="mx-auto" min-width="200">
|
||||
<v-list dense nav v-if="conversations.length > 0" style="max-height: 500px; overflow-y: auto;"
|
||||
@update:selected="getConversationMessages">
|
||||
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
|
||||
color="primary" rounded="xl">
|
||||
<v-list-item-title>新对话</v-list-item-title>
|
||||
<v-list-item-subtitle>{{ formatDate(item.updated_at) }}</v-list-item-subtitle>
|
||||
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-card>
|
||||
|
||||
<div>
|
||||
|
||||
<v-chip class="mt-4" color="primary" :append-icon="status?.llm_enabled ? 'mdi-check' : 'mdi-close'">
|
||||
LLM
|
||||
</v-chip>
|
||||
|
||||
<v-chip class="mt-4 ml-2" color="success" :append-icon="status?.stt_enabled ? 'mdi-check' : 'mdi-close'">
|
||||
语音转文本
|
||||
</v-chip>
|
||||
</div>
|
||||
|
||||
<v-btn variant="tonal" rounded="xl"
|
||||
style="position: fixed; bottom: 48px; margin-bottom: 16px; min-width: 200px;" v-if="currCid"
|
||||
@click="deleteConversation(currCid)" color="error">删除此对话</v-btn>
|
||||
</div>
|
||||
|
||||
<div style="height: 100%; width: 100%;">
|
||||
<div style="height: calc(100% - 120px); overflow-y: auto; padding: 16px; " ref="messageContainer">
|
||||
<div class="fade-in" v-if="messages.length == 0"
|
||||
style="height: 100%; display: flex; justify-content: center; align-items: center; flex-direction: column;">
|
||||
<div>
|
||||
<span style="font-size: 28px;">Hello, I'm</span>
|
||||
<span style="font-weight: 1000; font-size: 28px; margin-left: 8px;">AstrBot ⭐</span>
|
||||
</div>
|
||||
<div style="margin-top: 8px; color: #aaa;">
|
||||
<span>输入</span>
|
||||
<span
|
||||
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">/help</span>
|
||||
<span>获取帮助 😊</span>
|
||||
</div>
|
||||
<div style="margin-top: 8px; color: #aaa;">
|
||||
<span>按</span>
|
||||
<span
|
||||
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">K</span>
|
||||
<span>开始语音 🎤</span>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
<div v-else style="max-height: 100%; padding: 16px; max-width: 700px; margin: 0 auto;">
|
||||
<div class="fade-in" v-for="(msg, index) in messages" :key="index"
|
||||
style="margin-bottom: 16px;">
|
||||
<div v-if="msg.type == 'user'" style="display: flex; justify-content: flex-end;">
|
||||
<div
|
||||
style="padding: 12px; border-radius: 8px; background-color: rgba(94, 53, 177, 0.15)">
|
||||
<span>{{ msg.message }}</span>
|
||||
<div style="display: flex; gap: 8px; margin-top: 8px;"
|
||||
v-if="msg.image_url && msg.image_url.length > 0">
|
||||
<div v-for="(img, index) in msg.image_url" :key="index"
|
||||
style="position: relative; display: inline-block;">
|
||||
<img :src="img"
|
||||
style="width: 100px; height: 100px; border-radius: 8px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);" />
|
||||
</div>
|
||||
</div>
|
||||
<!-- audio -->
|
||||
<div>
|
||||
<audio controls v-if="msg.audio_url && msg.audio_url.length > 0">
|
||||
<source :src="msg.audio_url" type="audio/wav">
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else style="display: flex; justify-content: flex-start; gap: 16px;">
|
||||
<span style="font-size: 32px;">✨</span>
|
||||
<div v-html="marked(msg.message)" class="mc" style="font-family: inherit;"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="fade-in" style="bottom: 16px; width: 100%; padding: 8px; ">
|
||||
|
||||
<div
|
||||
style="width: 100%; justify-content: center; align-items: center; display: flex; flex-direction: column; margin-top: 8px;">
|
||||
|
||||
<v-text-field id="input-field" variant="outlined" v-model="prompt" :label="inputFieldLabel"
|
||||
placeholder="Start typing..." loading clear-icon="mdi-close-circle" clearable
|
||||
@click:clear="clearMessage" style="width: 100%; max-width: 850px;">
|
||||
<template v-slot:loader>
|
||||
<v-progress-linear :active="loadingChat" height="6"
|
||||
indeterminate></v-progress-linear>
|
||||
</template>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-tooltip text="发送">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-icon v-bind="props" @click="sendMessage" size="35"
|
||||
icon="mdi-arrow-up-circle" />
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
|
||||
<v-tooltip text="语音输入">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-icon :color="isRecording ? 'error' : ''" v-bind="props"
|
||||
@click="isRecording ? stopRecording() : startRecording()" size="35"
|
||||
icon="mdi-record-circle" />
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
</template>
|
||||
</v-text-field>
|
||||
|
||||
<div style="display: flex; gap: 8px; margin-top: -8px;">
|
||||
<div v-for="(img, index) in stagedImagesUrl" :key="index"
|
||||
style="position: relative; display: inline-block;">
|
||||
<img :src="img"
|
||||
style="width: 50px; height: 50px; border-radius: 8px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);" />
|
||||
<v-icon @click="removeImage(index)" size="20" color="red"
|
||||
style="position: absolute; top: 0; right: 0; cursor: pointer;">mdi-close-circle</v-icon>
|
||||
</div>
|
||||
<div style="display: inline-block; width: 50px; height: 50px;">
|
||||
<div v-if="stagedAudioUrl"
|
||||
style="position: relative; padding: 6px; border-radius: 8px; background-color: rgba(94, 53, 177, 0.15); display: inline-block;">
|
||||
新录音
|
||||
<v-icon @click="removeAudio" size="20" color="red"
|
||||
style="position: absolute; top: 0; right: 0; cursor: pointer;">mdi-close-circle</v-icon>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</template>
|
||||
<script>
|
||||
export default {
|
||||
name: 'ChatPage',
|
||||
components: {
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
prompt: '',
|
||||
messages: [],
|
||||
conversations: [],
|
||||
currCid: '',
|
||||
stagedImagesUrl: [],
|
||||
loadingChat: false,
|
||||
|
||||
inputFieldLabel: '聊天吧!',
|
||||
|
||||
isRecording: false,
|
||||
audioChunks: [],
|
||||
stagedAudioUrl: "",
|
||||
mediaRecorder: null,
|
||||
|
||||
status: {},
|
||||
statusText: ''
|
||||
}
|
||||
},
|
||||
|
||||
mounted() {
|
||||
this.checkStatus();
|
||||
this.getConversations();
|
||||
let inputField = document.getElementById('input-field');
|
||||
inputField.addEventListener('paste', this.handlePaste);
|
||||
inputField.addEventListener('keydown', function (e) {
|
||||
if (e.keyCode == 13 && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
this.sendMessage();
|
||||
}
|
||||
}.bind(this));
|
||||
document.addEventListener('keydown', function (e) {
|
||||
if (e.keyCode == 75) {
|
||||
this.isRecording ? this.stopRecording() : this.startRecording();
|
||||
}
|
||||
}.bind(this));
|
||||
},
|
||||
|
||||
methods: {
|
||||
|
||||
removeAudio() {
|
||||
this.stagedAudioUrl = null;
|
||||
},
|
||||
|
||||
checkStatus() {
|
||||
axios.get('/api/chat/status').then(response => {
|
||||
console.log(response.data);
|
||||
this.status = response.data.data;
|
||||
}).catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
|
||||
async startRecording() {
|
||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
|
||||
this.mediaRecorder = new MediaRecorder(stream);
|
||||
this.mediaRecorder.ondataavailable = (event) => {
|
||||
this.audioChunks.push(event.data);
|
||||
};
|
||||
this.mediaRecorder.start();
|
||||
this.isRecording = true;
|
||||
this.inputFieldLabel = "录音中,请说话...";
|
||||
},
|
||||
|
||||
async stopRecording() {
|
||||
this.isRecording = false;
|
||||
this.inputFieldLabel = "聊天吧!";
|
||||
this.mediaRecorder.stop();
|
||||
this.mediaRecorder.onstop = async () => {
|
||||
const audioBlob = new Blob(this.audioChunks, { type: 'audio/wav' });
|
||||
this.audioChunks = [];
|
||||
|
||||
this.mediaRecorder.stream.getTracks().forEach(track => track.stop());
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('file', audioBlob);
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/chat/post_file', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||
}
|
||||
});
|
||||
|
||||
const audio = response.data.data.filename;
|
||||
console.log('Audio uploaded:', audio);
|
||||
|
||||
this.stagedAudioUrl = `/api/chat/get_file?filename=${audio}`;
|
||||
} catch (err) {
|
||||
console.error('Error uploading audio:', err);
|
||||
}
|
||||
};
|
||||
},
|
||||
|
||||
async handlePaste(event) {
|
||||
console.log('Pasting image...');
|
||||
const items = event.clipboardData.items;
|
||||
for (let i = 0; i < items.length; i++) {
|
||||
if (items[i].type.indexOf('image') !== -1) {
|
||||
const file = items[i].getAsFile();
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/chat/post_image', formData, {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||
}
|
||||
});
|
||||
|
||||
const img = response.data.data.filename;
|
||||
this.stagedImagesUrl.push(`/api/chat/get_file?filename=${img}`);
|
||||
|
||||
} catch (err) {
|
||||
console.error('Error uploading image:', err);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
removeImage(index) {
|
||||
this.stagedImagesUrl.splice(index, 1);
|
||||
},
|
||||
|
||||
clearMessage() {
|
||||
this.prompt = '';
|
||||
},
|
||||
getConversations() {
|
||||
axios.get('/api/chat/conversations').then(response => {
|
||||
this.conversations = response.data.data;
|
||||
}).catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
getConversationMessages(cid) {
|
||||
if (!cid[0])
|
||||
return;
|
||||
axios.get('/api/chat/get_conversation?conversation_id=' + cid[0]).then(response => {
|
||||
this.currCid = cid[0];
|
||||
let message = JSON.parse(response.data.data.history);
|
||||
for (let i = 0; i < message.length; i++) {
|
||||
if (message[i].message.startsWith('[IMAGE]')) {
|
||||
let img = message[i].message.replace('[IMAGE]', '');
|
||||
message[i].message = `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
if (message[i].image_url && message[i].image_url.length > 0) {
|
||||
for (let j = 0; j < message[i].image_url.length; j++) {
|
||||
message[i].image_url[j] = `/api/chat/get_file?filename=${message[i].image_url[j]}`;
|
||||
}
|
||||
}
|
||||
if (message[i].audio_url) {
|
||||
message[i].audio_url = `/api/chat/get_file?filename=${message[i].audio_url}`;
|
||||
}
|
||||
}
|
||||
this.messages = message;
|
||||
}).catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
async newConversation() {
|
||||
await axios.get('/api/chat/new_conversation').then(response => {
|
||||
this.currCid = response.data.data.conversation_id;
|
||||
this.getConversations();
|
||||
}).catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
|
||||
newC() {
|
||||
this.currCid = '';
|
||||
this.messages = [];
|
||||
},
|
||||
|
||||
formatDate(timestamp) {
|
||||
const date = new Date(timestamp * 1000); // 假设时间戳是以秒为单位
|
||||
const options = {
|
||||
year: 'numeric',
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
second: '2-digit',
|
||||
hour12: false
|
||||
};
|
||||
return date.toLocaleString('zh-CN', options).replace(/\//g, '-').replace(/, /g, ' ');
|
||||
},
|
||||
|
||||
deleteConversation(cid) {
|
||||
axios.get('/api/chat/delete_conversation?conversation_id=' + cid).then(response => {
|
||||
this.getConversations();
|
||||
this.currCid = '';
|
||||
this.messages = [];
|
||||
}).catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
|
||||
async sendMessage() {
|
||||
if (this.currCid == '') {
|
||||
await this.newConversation();
|
||||
}
|
||||
|
||||
this.messages.push({
|
||||
type: 'user',
|
||||
message: this.prompt,
|
||||
image_url: this.stagedImagesUrl,
|
||||
audio_url: this.stagedAudioUrl
|
||||
});
|
||||
|
||||
this.scrollToBottom();
|
||||
|
||||
// images
|
||||
let image_filenames = [];
|
||||
for (let i = 0; i < this.stagedImagesUrl.length; i++) {
|
||||
let img = this.stagedImagesUrl[i].replace('/api/chat/get_file?filename=', '');
|
||||
image_filenames.push(img);
|
||||
}
|
||||
|
||||
// audio
|
||||
let audio_filenames = [];
|
||||
if (this.stagedAudioUrl) {
|
||||
let audio = this.stagedAudioUrl.replace('/api/chat/get_file?filename=', '');
|
||||
audio_filenames.push(audio);
|
||||
}
|
||||
|
||||
this.loadingChat = true;
|
||||
|
||||
|
||||
fetch('/api/chat/send', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||
},
|
||||
body: JSON.stringify({
|
||||
message: this.prompt,
|
||||
conversation_id: this.currCid,
|
||||
image_url: image_filenames,
|
||||
audio_url: audio_filenames
|
||||
}) // 发送请求体
|
||||
})
|
||||
.then(response => {
|
||||
this.prompt = '';
|
||||
this.stagedImagesUrl = [];
|
||||
this.stagedAudioUrl = "";
|
||||
|
||||
this.loadingChat = false;
|
||||
|
||||
const reader = response.body.getReader(); // 获取流的 Reader
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
const readStream = async () => {
|
||||
const { done, value } = await reader.read(); // 读取流中的数据
|
||||
if (done) {
|
||||
console.log("Stream finished.");
|
||||
return;
|
||||
}
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true });
|
||||
// bot_resp.message.value += chunk;
|
||||
|
||||
console.log("!!!!", chunk);
|
||||
if (chunk.startsWith('[IMAGE]')) {
|
||||
let img = chunk.replace('[IMAGE]', '');
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else {
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: chunk
|
||||
}
|
||||
|
||||
this.messages.push(bot_resp);
|
||||
}
|
||||
|
||||
this.scrollToBottom();
|
||||
readStream(); // 递归读取流
|
||||
};
|
||||
|
||||
readStream();
|
||||
})
|
||||
.catch(err => {
|
||||
console.error(err);
|
||||
});
|
||||
},
|
||||
scrollToBottom() {
|
||||
this.$nextTick(() => {
|
||||
const container = this.$refs.messageContainer;
|
||||
container.scrollTop = container.scrollHeight;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
</script>
|
||||
|
||||
<style>
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
animation: fadeIn 0.2s ease-in-out;
|
||||
}
|
||||
|
||||
.mc h1,
|
||||
.mc h2,
|
||||
.mc h3,
|
||||
.mc h4,
|
||||
.mc h5,
|
||||
.mc h6 {
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.mc li {
|
||||
margin-left: 16px;
|
||||
}
|
||||
|
||||
|
||||
.mc p {
|
||||
margin-top: 10px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
</style>
|
||||
@@ -1,5 +1,7 @@
|
||||
<script setup>
|
||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
import axios from 'axios';
|
||||
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -7,8 +9,34 @@ import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
<div
|
||||
style="background-color: white; padding: 8px; padding-left: 16px; border-radius: 8px; margin-bottom: 16px; display: flex; flex-direction: row; align-items: center; justify-content: space-between;">
|
||||
<h4>控制台</h4>
|
||||
<v-dialog v-model="pipDialog" width="400">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn variant="plain" v-bind="props">安装 pip 库</v-btn>
|
||||
</template>
|
||||
<v-card>
|
||||
<v-card-title>
|
||||
<span class="text-h5">安装 Pip 库</span>
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<v-text-field v-model="pipInstallPayload.package" label="*库名,如 llmtuner" variant="outlined"></v-text-field>
|
||||
<v-text-field v-model="pipInstallPayload.mirror" label="镜像站链接(可选)" variant="outlined"></v-text-field>
|
||||
<small>如果不填镜像站链接,默认使用阿里云镜像:https://mirrors.aliyun.com/pypi/simple/</small>
|
||||
<div>
|
||||
<small>{{ status }}</small>
|
||||
</div>
|
||||
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="blue-darken-1" variant="text" @click="pipInstall" :loading="loading">
|
||||
安装
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
</div>
|
||||
<ConsoleDisplayer style="height: calc(100vh - 160px); "/>
|
||||
<ConsoleDisplayer style="height: calc(100vh - 160px); " />
|
||||
</div>
|
||||
</template>
|
||||
<script>
|
||||
@@ -17,6 +45,36 @@ export default {
|
||||
components: {
|
||||
ConsoleDisplayer
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
pipDialog: false,
|
||||
pipInstallPayload: {
|
||||
package: '',
|
||||
mirror: ''
|
||||
},
|
||||
loading: false,
|
||||
status: ''
|
||||
}
|
||||
},
|
||||
|
||||
methods: {
|
||||
pipInstall() {
|
||||
this.loading = true;
|
||||
axios.post('/api/update/pip-install', this.pipInstallPayload)
|
||||
.then(res => {
|
||||
this.status = res.data.message;
|
||||
setTimeout(() => {
|
||||
this.status = '';
|
||||
this.pipDialog = false;
|
||||
}, 2000);
|
||||
})
|
||||
.catch(err => {
|
||||
this.status = err.response.data.message;
|
||||
}).finally(() => {
|
||||
this.loading = false;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
</script>
|
||||
@@ -26,10 +84,12 @@ export default {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
animation: fadeIn 0.2s ease-in-out;
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ import axios from 'axios';
|
||||
|
||||
<template>
|
||||
<v-row>
|
||||
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以前往 配置->其他配置->插件仓库镜像 修改安装镜像源。2. 如需插件帮助请点击 `仓库` 查看 README"
|
||||
title="💡小提示" type="info" variant="tonal">
|
||||
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README"
|
||||
title="💡提示" type="info" variant="tonal">
|
||||
</v-alert>
|
||||
<v-col cols="12" md="12">
|
||||
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
|
||||
@@ -80,7 +80,7 @@ import axios from 'axios';
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<v-dialog v-model="dialog" persistent width="700">
|
||||
<v-dialog v-model="dialog" width="700">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn v-bind="props" icon="mdi-plus" size="x-large" style="position: fixed; right: 52px; bottom: 52px;"
|
||||
color="darkprimary">
|
||||
|
||||
@@ -2,8 +2,6 @@ import os
|
||||
import asyncio
|
||||
import sys
|
||||
import mimetypes
|
||||
import aiohttp
|
||||
import zipfile
|
||||
from astrbot.dashboard import AstrBotDashBoardLifecycle
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core import logger, LogManager, LogBroker
|
||||
|
||||
@@ -56,6 +56,10 @@ class Main(star.Star):
|
||||
/persona: 情境人格设置
|
||||
/tool ls: 查看、激活、停用当前注册的函数工具
|
||||
|
||||
[其他]
|
||||
/set <变量名> <值>: 为当前会话定义一个变量。适用于 Dify 工作流输入。
|
||||
/unset <变量名>: 删除当前会话的变量。
|
||||
|
||||
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
|
||||
{notice}"""
|
||||
|
||||
@@ -345,7 +349,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
|
||||
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dashboard update")
|
||||
@filter.command("dashboard_update")
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
yield event.plain_result("正在尝试更新管理面板...")
|
||||
await download_dashboard()
|
||||
@@ -365,12 +369,35 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
|
||||
if provider.curr_personality['prompt']:
|
||||
req.system_prompt += f"\n{provider.curr_personality['prompt']}"
|
||||
|
||||
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
|
||||
async def other_message(self, event: AstrMessageEvent):
|
||||
print("triggered")
|
||||
event.stop_event()
|
||||
|
||||
@filter.command("set")
|
||||
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
|
||||
session_id = event.get_session_id()
|
||||
session_vars = sp.get("session_variables", {})
|
||||
|
||||
session_var = session_vars.get(session_id, {})
|
||||
session_var[key] = value
|
||||
|
||||
session_vars[session_id] = session_var
|
||||
|
||||
sp.put("session_variables", session_vars)
|
||||
|
||||
yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。")
|
||||
|
||||
@filter.command("unset")
|
||||
async def unset_variable(self, event: AstrMessageEvent, key: str):
|
||||
session_id = event.get_session_id()
|
||||
session_vars = sp.get("session_variables", {})
|
||||
|
||||
session_var = session_vars.get(session_id, {})
|
||||
|
||||
if key not in session_var:
|
||||
yield event.plain_result("没有那个变量名。")
|
||||
else:
|
||||
del session_var[key]
|
||||
sp.put("session_variables", session_vars)
|
||||
yield event.plain_result(f"会话 {session_id} 变量 {key} 移除成功。")
|
||||
|
||||
@filter.command_group("kdb")
|
||||
def kdb(self):
|
||||
pass
|
||||
|
||||
@@ -363,10 +363,24 @@ class Main(star.Star):
|
||||
logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}")
|
||||
break
|
||||
else:
|
||||
# 成功了
|
||||
self.user_file_msg_buffer.pop(event.get_session_id())
|
||||
return
|
||||
|
||||
yield event.plain_result("经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。")
|
||||
|
||||
|
||||
@pi.command("cleanfile")
|
||||
async def pi_cleanfile(self, event: AstrMessageEvent):
|
||||
'''清理用户上传的文件'''
|
||||
for file in self.user_file_msg_buffer[event.get_session_id()]:
|
||||
try:
|
||||
os.remove(file)
|
||||
except BaseException as e:
|
||||
logger.error(f"删除文件 {file} 失败: {e}")
|
||||
|
||||
self.user_file_msg_buffer.pop(event.get_session_id())
|
||||
yield event.plain_result(f"用户 {event.get_session_id()} 上传的文件已清理。")
|
||||
|
||||
|
||||
async def run_container(self, container: aiodocker.docker.DockerContainer, timeout: int = 20) -> list[str]:
|
||||
'''Run the container and get the output'''
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
aiodocker
|
||||
@@ -34,7 +34,7 @@ class Main(star.Star):
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
trigger='date',
|
||||
args=[reminder["text"], reminder],
|
||||
args=[group, reminder],
|
||||
run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"),
|
||||
misfire_grace_time=60
|
||||
)
|
||||
@@ -42,7 +42,7 @@ class Main(star.Star):
|
||||
self.scheduler.add_job(
|
||||
self._reminder_callback,
|
||||
trigger='cron',
|
||||
args=[reminder["text"], reminder],
|
||||
args=[group, reminder],
|
||||
misfire_grace_time=60,
|
||||
**self._parse_cron_expr(reminder["cron"])
|
||||
)
|
||||
|
||||
+2
-1
@@ -16,4 +16,5 @@ aiocqhttp
|
||||
pyjwt
|
||||
apscheduler
|
||||
docstring_parser
|
||||
aiodocker
|
||||
aiodocker
|
||||
silk-python
|
||||
@@ -0,0 +1,33 @@
|
||||
import torch
|
||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
|
||||
model_id = "openai/whisper-large-v3"
|
||||
|
||||
model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
|
||||
)
|
||||
model.to(device)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
|
||||
pipe = pipeline(
|
||||
"automatic-speech-recognition",
|
||||
model=model,
|
||||
tokenizer=processor.tokenizer,
|
||||
feature_extractor=processor.feature_extractor,
|
||||
chunk_length_s=30,
|
||||
batch_size=16, # batch size for inference - set based on your device
|
||||
torch_dtype=torch_dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
|
||||
sample = dataset[0]["audio"]
|
||||
|
||||
result = pipe(sample)
|
||||
print(result["text"])
|
||||
Reference in New Issue
Block a user