Merge pull request #114 from lxfight/lwl-dev/knowledge-base

refactor: 知识库优化
This commit is contained in:
Soulter
2025-10-25 00:42:16 +08:00
committed by GitHub
100 changed files with 7630 additions and 6106 deletions
+9 -2
View File
@@ -13,11 +13,18 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v5
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 'latest'
- name: npm install, build
run: |
cd dashboard
npm install
npm run build
npm install pnpm -g
pnpm install
pnpm i --save-dev @types/markdown-it
pnpm run build
- name: Inject Commit SHA
id: get_sha
+17 -5
View File
@@ -6,8 +6,20 @@ ci:
autoupdate_schedule: weekly
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.11.2
hooks:
- id: ruff
- id: ruff-format
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.14.1
hooks:
# Run the linter.
- id: ruff-check
types_or: [ python, pyi ]
args: [ --fix ]
# Run the formatter.
- id: ruff-format
types_or: [ python, pyi ]
- repo: https://github.com/asottile/pyupgrade
rev: v3.21.0
hooks:
- id: pyupgrade
args: [--py310-plus]
+17 -6
View File
@@ -4,14 +4,25 @@
<div align="center">
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<br>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
<br>
<div>
<a href="https://github.com/Soulter/AstrBot/releases/latest"><img src="https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9" alt="GitHub release"></a>
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600)
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
@@ -118,7 +129,7 @@ uv run main.py
| QQ(OneBot) | ✔ |
| Telegram | ✔ |
| 企微应用 | ✔ |
| 企机器人 | ✔ |
| 企微智能机器人 | ✔ |
| 微信客服 | ✔ |
| 微信公众号 | ✔ |
| 飞书 | ✔ |
@@ -209,9 +209,38 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
continue
valid_params = {} # 参数过滤:只传递函数实际需要的参数
# 获取实际的 handler 函数
if func_tool.handler:
logger.debug(
f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}"
)
if func_tool.parameters and func_tool.parameters.get("properties"):
expected_params = set(func_tool.parameters["properties"].keys())
valid_params = {
k: v
for k, v in func_tool_args.items()
if k in expected_params
}
# 记录被忽略的参数
ignored_params = set(func_tool_args.keys()) - set(
valid_params.keys()
)
if ignored_params:
logger.warning(
f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}"
)
else:
# 如果没有 handler(如 MCP 工具),使用所有参数
valid_params = func_tool_args
logger.warning(f"工具 {func_tool_name} 没有 handler,使用所有参数")
try:
await self.agent_hooks.on_tool_start(
self.run_context, func_tool, func_tool_args
self.run_context, func_tool, valid_params
)
except Exception as e:
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
@@ -219,7 +248,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
executor = self.tool_executor.execute(
tool=func_tool,
run_context=self.run_context,
**func_tool_args,
**valid_params, # 只传递有效的参数
)
_final_resp: CallToolResult | None = None
+23 -51
View File
@@ -5,6 +5,7 @@ from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
from astrbot.core.config.default import DEFAULT_CONFIG
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
from typing import TypeVar, TypedDict
@@ -15,14 +16,12 @@ class ConfInfo(TypedDict):
"""Configuration information for a specific session or platform."""
id: str # UUID of the configuration or "default"
umop: list[str] # Unified Message Origin Pattern
name: str
path: str # File name to the configuration file
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
id="default",
umop=["::"],
name="default",
path=ASTRBOT_CONFIG_PATH,
)
@@ -31,8 +30,14 @@ DEFAULT_CONFIG_CONF_INFO = ConfInfo(
class AstrBotConfigManager:
"""A class to manage the system configuration of AstrBot, aka ACM"""
def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
def __init__(
self,
default_config: AstrBotConfig,
ucr: UmopConfigRouter,
sp: SharedPreferences,
):
self.sp = sp
self.ucr = ucr
self.confs: dict[str, AstrBotConfig] = {}
"""uuid / "default" -> AstrBotConfig"""
self.confs["default"] = default_config
@@ -63,24 +68,15 @@ class AstrBotConfigManager:
)
continue
def _is_umo_match(self, p1: str, p2: str) -> bool:
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
p1_ls = p1.split(":")
p2_ls = p2.split(":")
if len(p1_ls) != 3 or len(p2_ls) != 3:
return False # 非法格式
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
Returns:
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
"""
# uuid -> { "umop": list, "path": str, "name": str }
# uuid -> { "path": str, "name": str }
abconf_data = self._get_abconf_data()
if isinstance(umo, MessageSession):
umo = str(umo)
else:
@@ -89,10 +85,13 @@ class AstrBotConfigManager:
except Exception:
return DEFAULT_CONFIG_CONF_INFO
for uuid_, meta in abconf_data.items():
for pattern in meta["umop"]:
if self._is_umo_match(pattern, umo):
return ConfInfo(**meta, id=uuid_)
conf_id = self.ucr.get_conf_id_for_umop(umo)
if conf_id:
meta = abconf_data.get(conf_id)
if meta and isinstance(meta, dict):
# the bind relation between umo and conf is defined in ucr now, so we remove "umop" here
meta.pop("umop", None)
return ConfInfo(**meta, id=conf_id)
return DEFAULT_CONFIG_CONF_INFO
@@ -100,23 +99,14 @@ class AstrBotConfigManager:
self,
abconf_path: str,
abconf_id: str,
umo_parts: list[str] | list[MessageSession],
abconf_name: str | None = None,
) -> None:
"""保存配置文件的映射关系"""
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data = self.sp.get(
"abconf_mapping", {}, scope="global", scope_id="global"
)
random_word = abconf_name or uuid.uuid4().hex[:8]
abconf_data[abconf_id] = {
"umop": umo_parts,
"path": abconf_path,
"name": random_word,
}
@@ -153,29 +143,26 @@ class AstrBotConfigManager:
def get_conf_list(self) -> list[ConfInfo]:
"""获取所有配置文件的元数据列表"""
conf_list = []
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
abconf_mapping = self._get_abconf_data()
for uuid_, meta in abconf_mapping.items():
if not isinstance(meta, dict):
continue
meta.pop("umop", None)
conf_list.append(ConfInfo(**meta, id=uuid_))
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
return conf_list
def create_conf(
self,
umo_parts: list[str] | list[MessageSession],
config: dict = DEFAULT_CONFIG,
name: str | None = None,
) -> str:
"""
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
"""
conf_uuid = str(uuid.uuid4())
conf_file_name = f"abconf_{conf_uuid}.json"
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
conf = AstrBotConfig(config_path=conf_path, default_config=config)
conf.save_config()
self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
self._save_conf_mapping(conf_file_name, conf_uuid, abconf_name=name)
self.confs[conf_uuid] = conf
return conf_uuid
@@ -228,15 +215,12 @@ class AstrBotConfigManager:
logger.info(f"成功删除配置文件 {conf_id}")
return True
def update_conf_info(
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
) -> bool:
def update_conf_info(self, conf_id: str, name: str | None = None) -> bool:
"""更新配置文件信息
Args:
conf_id: 配置文件的 UUID
name: 新的配置文件名称 (可选)
umo_parts: 新的 UMO 部分列表 (可选)
Returns:
bool: 更新是否成功
@@ -255,18 +239,6 @@ class AstrBotConfigManager:
if name is not None:
abconf_data[conf_id]["name"] = name
# 更新 UMO 部分
if umo_parts is not None:
# 验证 UMO 部分格式
for part in umo_parts:
if isinstance(part, MessageSession):
part = str(part)
elif not isinstance(part, str):
raise ValueError(
"umo_parts must be a list of strings or MessageSession instances"
)
abconf_data[conf_id]["umop"] = umo_parts
# 保存更新
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
self.abconf_data = abconf_data
+90 -30
View File
@@ -134,27 +134,11 @@ DEFAULT_CONFIG = {
"persona": [], # deprecated
"timezone": "Asia/Shanghai",
"callback_api_base": "",
"default_kb_collection": "", # 默认知识库名称
"default_kb_collection": "", # 默认知识库名称, 已经过时
"plugin_set": ["*"], # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件
"knowledge_base": {
"enabled": False, # 默认禁用,用户需要主动启用
"embedding_provider_id": "", # 嵌入模型提供商 ID (为空时自动选择第一个)
"rerank_provider_id": "", # 重排序模型提供商 ID (为空时自动选择第一个)
"storage": {
"files_path": "data/knowledge_base", # 文件存储路径
"vector_db_path": "data/knowledge_base/vectors", # 向量数据库路径
},
"chunking": {
"chunk_size": 512, # 文档块大小(字符数)
"chunk_overlap": 50, # 文档块重叠大小(字符数)
},
"retrieval": {
"top_k_dense": 50, # 密集检索返回结果数
"top_k_sparse": 50, # 稀疏检索返回结果数
"top_m_final": 5, # 最终融合后返回的结果数
"enable_rerank": True, # 是否启用重排序
},
},
"kb_names": [], # 默认知识库名称列表
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
}
@@ -181,10 +165,11 @@ CONFIG_METADATA_2 = {
"enable": False,
"appid": "",
"secret": "",
"is_sandbox": False,
"callback_server_host": "0.0.0.0",
"port": 6196,
},
"QQ 个人号(aiocqhttp)": {
"QQ 个人号(OneBot v11)": {
"id": "default",
"type": "aiocqhttp",
"enable": False,
@@ -192,7 +177,7 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": 6199,
"ws_reverse_token": "",
},
"微信个人号(WeChatPadPro)": {
"WeChatPadPro": {
"id": "wechatpadpro",
"type": "wechatpadpro",
"enable": False,
@@ -287,6 +272,14 @@ CONFIG_METADATA_2 = {
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
# download / security options
"misskey_allow_insecure_downloads": False,
"misskey_download_timeout": 15,
"misskey_download_chunk_size": 65536,
"misskey_max_download_bytes": None,
"misskey_enable_file_upload": True,
"misskey_upload_concurrency": 3,
"misskey_upload_folder": "",
},
"Slack": {
"id": "slack",
@@ -311,8 +304,26 @@ CONFIG_METADATA_2 = {
"satori_heartbeat_interval": 10,
"satori_reconnect_delay": 5,
},
# "WebChat": {
# "id": "webchat",
# "type": "webchat",
# "enable": False,
# "webchat_link_path": "",
# "webchat_present_type": "fullscreen",
# },
},
"items": {
# "webchat_link_path": {
# "description": "链接路径",
# "_special": "webchat_link_path",
# "type": "string",
# },
# "webchat_present_type": {
# "_special": "webchat_present_type",
# "description": "展现形式",
# "type": "string",
# "options": ["fullscreen", "embedded"],
# },
"satori_api_base_url": {
"description": "Satori API 终结点",
"type": "string",
@@ -415,6 +426,41 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "启用后,机器人将会监听和响应私信聊天消息",
},
"misskey_enable_file_upload": {
"description": "启用文件上传到 Misskey",
"type": "bool",
"hint": "启用后,适配器会尝试将消息链中的文件上传到 Misskey。URL 文件会先尝试服务器端上传,异步上传失败时会回退到下载后本地上传。",
},
"misskey_allow_insecure_downloads": {
"description": "允许不安全下载(禁用 SSL 验证)",
"type": "bool",
"hint": "当远端服务器存在证书问题导致无法正常下载时,自动禁用 SSL 验证作为回退方案。适用于某些图床的证书配置问题。启用有安全风险,仅在必要时使用。",
},
"misskey_download_timeout": {
"description": "远端下载超时时间(秒)",
"type": "int",
"hint": "下载远程文件时的超时时间(秒),用于异步上传回退到本地上传的场景。",
},
"misskey_download_chunk_size": {
"description": "流式下载分块大小(字节)",
"type": "int",
"hint": "流式下载和计算 MD5 时使用的每次读取字节数,过小会增加开销,过大会占用内存。",
},
"misskey_max_download_bytes": {
"description": "最大允许下载字节数(超出则中止)",
"type": "int",
"hint": "如果希望限制下载文件的最大大小以防止 OOM,请填写最大字节数;留空或 null 表示不限制。",
},
"misskey_upload_concurrency": {
"description": "并发上传限制",
"type": "int",
"hint": "同时进行的文件上传任务上限(整数,默认 3)。",
},
"misskey_upload_folder": {
"description": "上传到网盘的目标文件夹 ID",
"type": "string",
"hint": "可选:填写 Misskey 网盘中目标文件夹的 ID,上传的文件将放置到该文件夹内。留空则使用账号网盘根目录。",
},
"telegram_command_register": {
"description": "Telegram 命令注册",
"type": "bool",
@@ -466,19 +512,18 @@ CONFIG_METADATA_2 = {
"hint": "启用后,机器人可以接收到频道的私聊消息。",
},
"ws_reverse_host": {
"description": "反向 Websocket 主机地址(AstrBot 为服务器端)",
"description": "反向 Websocket 主机",
"type": "string",
"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号",
"hint": "AstrBot 将作为服务器端",
},
"ws_reverse_port": {
"description": "反向 Websocket 端口",
"type": "int",
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
},
"ws_reverse_token": {
"description": "反向 Websocket Token",
"type": "string",
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
"hint": "反向 Websocket Token。未设置则不启用 Token 验证。",
},
"wecom_ai_bot_name": {
"description": "企业微信智能机器人的名字",
@@ -2019,6 +2064,9 @@ CONFIG_METADATA_2 = {
"default_kb_collection": {
"type": "string",
},
"kb_names": {"type": "list", "items": {"type": "string"}},
"kb_fusion_top_k": {"type": "int", "default": 20},
"kb_final_top_k": {"type": "int", "default": 5},
},
},
}
@@ -2097,10 +2145,22 @@ CONFIG_METADATA_3 = {
"description": "知识库",
"type": "object",
"items": {
"default_kb_collection": {
"description": "默认使用的知识库",
"type": "string",
"kb_names": {
"description": "知识库列表",
"type": "list",
"items": {"type": "string"},
"_special": "select_knowledgebase",
"hint": "支持多选",
},
"kb_fusion_top_k": {
"description": "融合检索结果数",
"type": "int",
"hint": "多个知识库检索结果融合后的返回结果数量",
},
"kb_final_top_k": {
"description": "最终返回结果数",
"type": "int",
"hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
},
},
},
@@ -2194,7 +2254,7 @@ CONFIG_METADATA_3 = {
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
"hint": "例子: 如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
},
"provider_settings.prompt_prefix": {
"description": "用户提示词",
+17 -12
View File
@@ -17,7 +17,6 @@ import os
from .event_bus import EventBus
from . import astrbot_config, html_renderer
from asyncio import Queue
from typing import List
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
@@ -26,15 +25,17 @@ from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger, sp
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star_handler import star_map
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
class AstrBotCoreLifecycle:
@@ -85,11 +86,21 @@ class AstrBotCoreLifecycle:
await html_renderer.initialize()
# 初始化 UMOP 配置路由器
self.umop_config_router = UmopConfigRouter(sp=sp)
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
default_config=self.astrbot_config, sp=sp
default_config=self.astrbot_config, ucr=self.umop_config_router, sp=sp
)
# 4.5 to 4.6 migration for umop_config_router
try:
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
except Exception as e:
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
logger.error(traceback.format_exc())
# 初始化事件队列
self.event_queue = Queue()
@@ -112,9 +123,7 @@ class AstrBotCoreLifecycle:
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
# 初始化知识库管理器
self.kb_manager = KnowledgeBaseManager(
self.astrbot_config, self.db, self.provider_manager
)
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
# 初始化提供给插件的上下文
self.star_context = Context(
@@ -141,10 +150,6 @@ class AstrBotCoreLifecycle:
await self.kb_manager.initialize()
# 注册知识库会话生命周期钩子(零侵入级联清理)
if self.kb_manager.is_initialized:
self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager)
# 初始化消息事件流水线调度器
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
@@ -160,7 +165,7 @@ class AstrBotCoreLifecycle:
self.start_time = int(time.time())
# 初始化当前任务列表
self.curr_tasks: List[asyncio.Task] = []
self.curr_tasks: list[asyncio.Task] = []
# 根据配置实例化各个平台适配器
await self.platform_manager.initialize()
@@ -267,7 +272,7 @@ class AstrBotCoreLifecycle:
target=self.astrbot_updator._reboot, name="restart", daemon=True
).start()
def load_platform(self) -> List[asyncio.Task]:
def load_platform(self) -> list[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表"""
tasks = []
platform_insts = self.platform_manager.get_insts()
@@ -0,0 +1,44 @@
from astrbot.api import logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.umop_config_router import UmopConfigRouter
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
    """Move legacy ``umop`` bindings out of ``abconf_mapping`` into the UMOP router.

    Entries of ``acm.abconf_data`` that still carry a ``"umop"`` list are
    treated as 4.5-style data. Each UMO string found there is mapped to the
    id of the first configuration that declares it, the resulting mapping is
    handed to ``ucr.update_routing_data`` and the ``"umop"`` keys are then
    removed from the stored mapping.

    Args:
        acm: Configuration manager whose ``abconf_data`` is inspected and
            updated in place.
        ucr: Router that receives the extracted ``umo -> conf_id`` mapping.

    Returns:
        None. The function returns early when ``abconf_data`` is not a dict
        or when no entry contains ``"umop"``.
    """
    abconf_data = acm.abconf_data

    if not isinstance(abconf_data, dict):
        # should be unreachable
        logger.warning(
            f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}"
        )
        return

    # Any entry that still carries "umop" means the data needs migrating.
    legacy_entries = {
        conf_id: conf_info
        for conf_id, conf_info in abconf_data.items()
        if isinstance(conf_info, dict) and "umop" in conf_info
    }
    if not legacy_entries:
        return

    logger.info("Starting migration from version 4.5 to 4.6")

    # Extract the umo -> conf_id mapping; the first configuration that
    # declares a given UMO wins, later duplicates are ignored.
    umo_to_conf_id = {}
    for conf_id, conf_info in legacy_entries.items():
        umop_ls = conf_info["umop"]
        if not isinstance(umop_ls, list):
            continue
        for umo in umop_ls:
            if isinstance(umo, str) and umo not in umo_to_conf_id:
                umo_to_conf_id[umo] = conf_id

    # Update the router BEFORE stripping "umop" from the stored mapping. If
    # this call fails, the legacy keys are still present and the migration is
    # retried on the next start instead of the bindings being lost.
    await ucr.update_routing_data(umo_to_conf_id)

    # The bindings now live in the router, so drop them from the mapping.
    for conf_info in legacy_entries.values():
        conf_info.pop("umop", None)
    await sp.global_put("abconf_mapping", abconf_data)

    logger.info("Migration from version 4.5 to 4.6 completed successfully")
+33 -2
View File
@@ -16,14 +16,42 @@ class BaseVecDB:
pass
@abc.abstractmethod
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
async def insert(
self, content: str, metadata: dict | None = None, id: str | None = None
) -> int:
"""
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
"""
...
@abc.abstractmethod
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
async def insert_batch(
self,
contents: list[str],
metadatas: list[dict] | None = None,
ids: list[str] | None = None,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> int:
"""
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
Args:
progress_callback: 进度回调函数,接收参数 (current, total)
"""
...
@abc.abstractmethod
async def retrieve(
self,
query: str,
top_k: int = 5,
fetch_k: int = 20,
rerank: bool = False,
metadata_filters: dict | None = None,
) -> list[Result]:
"""
搜索最相似的文档。
Args:
@@ -44,3 +72,6 @@ class BaseVecDB:
bool: 删除是否成功
"""
...
@abc.abstractmethod
async def close(self): ...
@@ -1,59 +1,219 @@
import aiosqlite
import os
import json
from datetime import datetime
from contextlib import asynccontextmanager
from sqlalchemy import Text, Column
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
class BaseDocModel(SQLModel, table=False):
    """Non-table base class for the document-store models.

    It owns its own ``MetaData`` object; ``DocumentStorage.initialize`` calls
    ``BaseDocModel.metadata.create_all`` on it to create the tables.
    """

    # NOTE(review): presumably a dedicated MetaData keeps these tables apart
    # from other SQLModel tables in the project - confirm.
    metadata = MetaData()
class Document(BaseDocModel, table=True):
    """SQLModel mapping for one row of the ``documents`` table."""

    __tablename__ = "documents"  # type: ignore

    # Auto-incrementing integer primary key; it is None on a new object and
    # is read back after a flush by the insert helpers.
    id: int | None = Field(
        default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
    )
    # External document identifier (the insert helpers describe it as a UUID string).
    doc_id: str = Field(nullable=False)
    # The document text itself.
    text: str = Field(nullable=False)
    # Stored in the SQL column named "metadata"; the insert helpers write a
    # JSON string here via json.dumps.
    # NOTE(review): the trailing underscore presumably avoids a clash with the
    # class-level `metadata` attribute - confirm.
    metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
    # Timestamps supplied by the caller; there is no database-side default.
    created_at: datetime | None = Field(default=None)
    updated_at: datetime | None = Field(default=None)
class DocumentStorage:
def __init__(self, db_path: str):
self.db_path = db_path
self.connection = None
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.engine: AsyncEngine | None = None
self.async_session_maker: sessionmaker | None = None
self.sqlite_init_path = os.path.join(
os.path.dirname(__file__), "sqlite_init.sql"
)
async def initialize(self):
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
if not os.path.exists(self.db_path):
await self.connect()
async with self.connection.cursor() as cursor:
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
sql_script = f.read()
await cursor.executescript(sql_script)
await self.connection.commit()
else:
await self.connect()
await self.connect()
async with self.engine.begin() as conn: # type: ignore
# Create tables using SQLModel
await conn.run_sync(BaseDocModel.metadata.create_all)
try:
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
)
)
await conn.execute(
text(
"ALTER TABLE documents ADD COLUMN user_id TEXT "
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
)
)
# Create indexes
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
)
)
await conn.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
)
)
except BaseException:
pass
await conn.commit()
async def connect(self):
"""Connect to the SQLite database."""
self.connection = await aiosqlite.connect(self.db_path)
if self.engine is None:
self.engine = create_async_engine(
self.DATABASE_URL,
echo=False,
future=True,
)
self.async_session_maker = sessionmaker(
self.engine, # type: ignore
class_=AsyncSession,
expire_on_commit=False,
) # type: ignore
async def get_documents(self, metadata_filters: dict, ids: list = None):
@asynccontextmanager
async def get_session(self):
    """Yield a session created by ``self.async_session_maker``.

    The session is produced inside its own ``async with`` block, so that
    block is exited once the caller's ``async with`` body finishes.
    """
    session_cm = self.async_session_maker()  # type: ignore
    async with session_cm as db_session:
        yield db_session
async def get_documents(
self,
metadata_filters: dict,
ids: list | None = None,
offset: int | None = 0,
limit: int | None = 100,
) -> list[dict]:
"""Retrieve documents by metadata filters and ids.
Args:
metadata_filters (dict): The metadata filters to apply.
ids (list | None): Optional list of document IDs to filter.
offset (int | None): Offset for pagination.
limit (int | None): Limit for pagination.
Returns:
list: The list of document IDs(primary key, not doc_id) that match the filters.
list: The list of documents that match the filters.
"""
# metadata filter -> SQL WHERE clause
where_clauses = []
values = []
for key, val in metadata_filters.items():
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
values.append(val)
if ids is not None and len(ids) > 0:
ids = [str(i) for i in ids if i != -1]
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
values.extend(ids)
where_sql = " AND ".join(where_clauses) or "1=1"
assert self.engine is not None, "Database connection is not initialized."
result = []
async with self.connection.cursor() as cursor:
sql = "SELECT * FROM documents WHERE " + where_sql
await cursor.execute(sql, values)
for row in await cursor.fetchall():
result.append(await self.tuple_to_dict(row))
return result
async with self.get_session() as session:
query = select(Document)
for key, val in metadata_filters.items():
query = query.where(
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
).params(**{f"filter_{key}": val})
if ids is not None and len(ids) > 0:
valid_ids = [int(i) for i in ids if i != -1]
if valid_ids:
query = query.where(col(Document.id).in_(valid_ids))
if limit is not None:
query = query.limit(limit)
if offset is not None:
query = query.offset(offset)
result = await session.execute(query)
documents = result.scalars().all()
return [self._document_to_dict(doc) for doc in documents]
async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
    """Insert a single document and return its integer ID.

    Args:
        doc_id (str): The document ID (UUID string).
        text (str): The document text.
        metadata (dict): The document metadata; stored as a JSON string.

    Returns:
        int: The integer primary key of the inserted row.
    """
    assert self.engine is not None, "Database connection is not initialized."

    # Take the timestamp once so created_at and updated_at are identical on a
    # freshly inserted row instead of differing by a few microseconds.
    now = datetime.now()

    async with self.get_session() as session:
        async with session.begin():
            document = Document(
                doc_id=doc_id,
                text=text,
                metadata_=json.dumps(metadata),
                created_at=now,
                updated_at=now,
            )
            session.add(document)
            await session.flush()  # Flush to get the ID
            return document.id  # type: ignore
async def insert_documents_batch(
    self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
) -> list[int]:
    """Batch insert documents and return their integer IDs.

    Args:
        doc_ids (list[str]): List of document IDs (UUID strings).
        texts (list[str]): List of document texts.
        metadatas (list[dict]): List of document metadata; each is stored as
            a JSON string.

    Returns:
        list[int]: Integer primary keys of the inserted rows, in input order.

    Raises:
        ValueError: If the three input lists do not have the same length.
    """
    assert self.engine is not None, "Database connection is not initialized."

    # zip() would silently drop the surplus items of the longer lists, which
    # leaves callers with fewer IDs than documents. Fail loudly instead.
    if not (len(doc_ids) == len(texts) == len(metadatas)):
        raise ValueError(
            "doc_ids, texts and metadatas must have the same length "
            f"(got {len(doc_ids)}, {len(texts)}, {len(metadatas)})"
        )

    # One timestamp for the whole batch keeps created_at == updated_at.
    now = datetime.now()

    async with self.get_session() as session:
        async with session.begin():
            documents = [
                Document(
                    doc_id=doc_id,
                    text=body,
                    metadata_=json.dumps(metadata),
                    created_at=now,
                    updated_at=now,
                )
                for doc_id, body, metadata in zip(doc_ids, texts, metadatas)
            ]
            session.add_all(documents)
            await session.flush()  # Flush to get all IDs
            return [doc.id for doc in documents]  # type: ignore
async def delete_document_by_doc_id(self, doc_id: str):
    """Delete the document whose ``doc_id`` column equals *doc_id*.

    Nothing happens when no row matches.

    Args:
        doc_id (str): The doc_id of the document to delete.
    """
    assert self.engine is not None, "Database connection is not initialized."

    async with self.get_session() as session:
        async with session.begin():
            lookup = select(Document).where(col(Document.doc_id) == doc_id)
            rows = await session.execute(lookup)
            # NOTE(review): scalar_one_or_none() raises when several rows
            # share the same doc_id, and the model declares no unique
            # constraint on that column - confirm duplicates cannot occur.
            match = rows.scalar_one_or_none()
            if not match:
                return
            await session.delete(match)
async def get_document_by_doc_id(self, doc_id: str):
"""Retrieve a document by its doc_id.
@@ -62,28 +222,85 @@ class DocumentStorage:
doc_id (str): The doc_id of the document to retrieve.
Returns:
dict: The document data.
dict: The document data or None if not found.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
row = await cursor.fetchone()
if row:
return await self.tuple_to_dict(row)
else:
return None
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
return self._document_to_dict(document)
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
"""Retrieve a document by its doc_id.
"""Update a document by its doc_id.
Args:
doc_id (str): The doc_id.
new_text (str): The new text to update the document with.
"""
async with self.connection.cursor() as cursor:
await cursor.execute(
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
)
await self.connection.commit()
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
async with session.begin():
query = select(Document).where(col(Document.doc_id) == doc_id)
result = await session.execute(query)
document = result.scalar_one_or_none()
if document:
document.text = new_text
document.updated_at = datetime.now()
session.add(document)
async def delete_documents(self, metadata_filters: dict):
    """Delete every document matching all of the given metadata filters.

    Each key/value pair becomes a ``json_extract(metadata, '$.<key>') = <value>``
    condition; the matching rows are loaded and deleted one by one inside a
    single transaction. An empty dict adds no conditions.

    Args:
        metadata_filters (dict): The metadata filters to apply.
    """
    assert self.engine is not None, "Database connection is not initialized."

    async with self.get_session() as session:
        async with session.begin():
            stmt = select(Document)
            for key, val in metadata_filters.items():
                # The key is spliced into the SQL text and the bind name;
                # only the value travels as a bound parameter.
                condition = text(
                    f"json_extract(metadata, '$.{key}') = :filter_{key}"
                )
                bound = {f"filter_{key}": val}
                stmt = stmt.where(condition).params(**bound)

            matched = (await session.execute(stmt)).scalars().all()
            for doc in matched:
                await session.delete(doc)
async def count_documents(self, metadata_filters: dict | None = None) -> int:
    """Count the documents, optionally restricted by metadata filters.

    Args:
        metadata_filters (dict | None): Metadata filters to apply; ``None``
            or an empty dict counts every row.

    Returns:
        int: The number of matching documents (0 when the query yields no value).
    """
    assert self.engine is not None, "Database connection is not initialized."

    async with self.get_session() as session:
        stmt = select(func.count(col(Document.id)))

        # A falsy filter argument simply contributes no conditions.
        for key, val in (metadata_filters or {}).items():
            condition = text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
            stmt = stmt.where(condition).params(**{f"filter_{key}": val})

        total = (await session.execute(stmt)).scalar_one_or_none()
        if total is None:
            return 0
        return total
async def get_user_ids(self) -> list[str]:
"""Retrieve all user IDs from the documents table.
@@ -91,11 +308,38 @@ class DocumentStorage:
Returns:
list: A list of user IDs.
"""
async with self.connection.cursor() as cursor:
await cursor.execute("SELECT DISTINCT user_id FROM documents")
rows = await cursor.fetchall()
assert self.engine is not None, "Database connection is not initialized."
async with self.get_session() as session:
query = text(
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
)
result = await session.execute(query)
rows = result.fetchall()
return [row[0] for row in rows]
def _document_to_dict(self, document: Document) -> dict:
"""Convert a Document model to a dictionary.
Args:
document (Document): The document to convert.
Returns:
dict: The converted dictionary.
"""
return {
"id": document.id,
"doc_id": document.doc_id,
"text": document.text,
"metadata": document.metadata_,
"created_at": document.created_at.isoformat()
if isinstance(document.created_at, datetime)
else document.created_at,
"updated_at": document.updated_at.isoformat()
if isinstance(document.updated_at, datetime)
else document.updated_at,
}
async def tuple_to_dict(self, row):
"""Convert a tuple to a dictionary.
@@ -104,6 +348,8 @@ class DocumentStorage:
Returns:
dict: The converted dictionary.
Note: This method is kept for backward compatibility but is no longer used internally.
"""
return {
"id": row[0],
@@ -116,6 +362,7 @@ class DocumentStorage:
async def close(self):
"""Close the connection to the SQLite database."""
if self.connection:
await self.connection.close()
self.connection = None
if self.engine:
await self.engine.dispose()
self.engine = None
self.async_session_maker = None
@@ -9,7 +9,7 @@ import numpy as np
class EmbeddingStorage:
def __init__(self, dimension: int, path: str = None):
def __init__(self, dimension: int, path: str | None = None):
self.dimension = dimension
self.path = path
self.index = None
@@ -18,7 +18,6 @@ class EmbeddingStorage:
else:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
self.storage = {}
async def insert(self, vector: np.ndarray, id: int):
"""插入向量
@@ -29,12 +28,29 @@ class EmbeddingStorage:
Raises:
ValueError: 如果向量的维度与存储的维度不匹配
"""
assert self.index is not None, "FAISS index is not initialized."
if vector.shape[0] != self.dimension:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
self.storage[id] = vector
await self.save_index()
async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
    """Add a batch of vectors to the FAISS index and persist the index.

    Args:
        vectors (np.ndarray): 2-D array with one vector per row; the second
            axis must equal ``self.dimension``.
        ids (list[int]): Integer ID for each row of ``vectors``.

    Raises:
        ValueError: If ``vectors`` is not 2-D, its row width differs from the
            configured dimension, or the number of rows and IDs differ.
    """
    assert self.index is not None, "FAISS index is not initialized."
    if vectors.size == 0 and len(ids) == 0:
        # An empty batch arrives as a 1-D array of shape (0,); there is
        # nothing to add, so skip the index write and the save.
        return
    if vectors.ndim != 2:
        raise ValueError(f"向量数组必须是二维的, 实际形状: {vectors.shape}")
    if vectors.shape[1] != self.dimension:
        raise ValueError(
            f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}"
        )
    if vectors.shape[0] != len(ids):
        raise ValueError(
            f"向量数量与ID数量不匹配, 向量: {vectors.shape[0]}, ID: {len(ids)}"
        )
    # Pass float32 rows and explicit int64 IDs, the same ID dtype delete() uses,
    # instead of relying on NumPy's platform-dependent default integer type.
    self.index.add_with_ids(
        np.ascontiguousarray(vectors, dtype=np.float32),
        np.asarray(ids, dtype=np.int64),
    )
    await self.save_index()
async def search(self, vector: np.ndarray, k: int) -> tuple:
@@ -46,10 +62,22 @@ class EmbeddingStorage:
Returns:
tuple: (距离, 索引)
"""
assert self.index is not None, "FAISS index is not initialized."
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
return distances, indices
async def delete(self, ids: list[int]):
    """Remove vectors from the FAISS index by ID, then persist the index.

    Args:
        ids (list[int]): IDs of the vectors to remove.
    """
    assert self.index is not None, "FAISS index is not initialized."
    # The IDs are handed to the index as an explicit int64 array.
    self.index.remove_ids(np.array(ids, dtype=np.int64))
    await self.save_index()
async def save_index(self):
"""保存索引
+81 -23
View File
@@ -1,11 +1,12 @@
import uuid
import json
import time
import numpy as np
from .document_storage import DocumentStorage
from .embedding_storage import EmbeddingStorage
from ..base import Result, BaseVecDB
from astrbot.core.provider.provider import EmbeddingProvider
from astrbot.core.provider.provider import RerankProvider
from astrbot import logger
class FaissVecDB(BaseVecDB):
@@ -44,18 +45,56 @@ class FaissVecDB(BaseVecDB):
vector = await self.embedding_provider.get_embedding(content)
vector = np.array(vector, dtype=np.float32)
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute(
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
(str_id, content, json.dumps(metadata)),
)
await self.document_storage.connection.commit()
result = await self.document_storage.get_document_by_doc_id(str_id)
int_id = result["id"]
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
# 使用 DocumentStorage 的方法插入文档
int_id = await self.document_storage.insert_document(str_id, content, metadata)
# 插入向量到 FAISS
await self.embedding_storage.insert(vector, int_id)
return int_id
async def insert_batch(
    self,
    contents: list[str],
    metadatas: list[dict] | None = None,
    ids: list[str] | None = None,
    batch_size: int = 32,
    tasks_limit: int = 3,
    max_retries: int = 3,
    progress_callback=None,
) -> list[int]:
    """Embed and store a batch of texts in the document store and FAISS index.

    Args:
        contents: Texts to insert.
        metadatas: One metadata dict per text; defaults to empty dicts.
        ids: One string document ID per text; defaults to fresh UUID4 strings.
        batch_size: Forwarded to the embedding provider.
        tasks_limit: Forwarded to the embedding provider.
        max_retries: Forwarded to the embedding provider.
        progress_callback: Forwarded to the embedding provider; receives
            ``(current, total)``.

    Returns:
        list[int]: Integer IDs returned by the document store, which are also
        used as the vector IDs. Empty when ``contents`` is empty.

    Raises:
        ValueError: If ``metadatas``, ``ids`` or the generated embeddings do
            not have exactly one entry per content.
    """
    if not contents:
        # Nothing to embed or store; avoids handing empty arrays downstream.
        return []

    metadatas = metadatas or [{} for _ in contents]
    ids = ids or [str(uuid.uuid4()) for _ in contents]
    if len(metadatas) != len(contents) or len(ids) != len(contents):
        raise ValueError(
            "contents, metadatas and ids must have the same length, got "
            f"{len(contents)}, {len(metadatas)} and {len(ids)}"
        )

    start = time.time()
    logger.debug(f"Generating embeddings for {len(contents)} contents...")
    vectors = await self.embedding_provider.get_embeddings_batch(
        contents,
        batch_size=batch_size,
        tasks_limit=tasks_limit,
        max_retries=max_retries,
        progress_callback=progress_callback,
    )
    end = time.time()
    logger.debug(
        f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds."
    )
    if len(vectors) != len(contents):
        # Check before writing anything so the document store and the vector
        # index cannot end up with different row counts.
        raise ValueError(
            f"expected {len(contents)} embeddings, got {len(vectors)}"
        )

    # Store the documents first; the integer IDs returned here become the
    # vector IDs below.
    int_ids = await self.document_storage.insert_documents_batch(
        ids, contents, metadatas
    )

    vectors_array = np.array(vectors).astype("float32")
    await self.embedding_storage.insert_batch(vectors_array, int_ids)
    return int_ids
async def retrieve(
self,
@@ -119,23 +158,42 @@ class FaissVecDB(BaseVecDB):
return top_k_results
async def delete(self, doc_id: int):
async def delete(self, doc_id: str):
"""
删除一条文档
删除一条文档块(chunk
"""
await self.document_storage.connection.execute(
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
)
await self.document_storage.connection.commit()
# 获得对应的 int id
result = await self.document_storage.get_document_by_doc_id(doc_id)
int_id = result["id"] if result else None
if int_id is None:
return
# 使用 DocumentStorage 的删除方法
await self.document_storage.delete_document_by_doc_id(doc_id)
await self.embedding_storage.delete([int_id])
async def close(self):
    """Close the underlying document storage."""
    await self.document_storage.close()
async def count_documents(self) -> int:
async def count_documents(self, metadata_filter: dict | None = None) -> int:
"""
计算文档数量
Args:
metadata_filter (dict | None): 元数据过滤器
"""
async with self.document_storage.connection.cursor() as cursor:
await cursor.execute("SELECT COUNT(*) FROM documents")
count = await cursor.fetchone()
return count[0] if count else 0
count = await self.document_storage.count_documents(
metadata_filters=metadata_filter or {}
)
return count
async def delete_documents(self, metadata_filters: dict):
    """Delete all documents matching the metadata filters, with their vectors.

    The matching rows are looked up first so their integer IDs can be removed
    from the embedding storage; the rows themselves are deleted afterwards.

    Args:
        metadata_filters (dict): Filters forwarded to the document storage.
    """
    matched = await self.document_storage.get_documents(
        metadata_filters=metadata_filters, offset=None, limit=None
    )
    vector_ids: list[int] = []
    for row in matched:
        vector_ids.append(row["id"])
    await self.embedding_storage.delete(vector_ids)
    await self.document_storage.delete_documents(metadata_filters=metadata_filters)
-34
View File
@@ -1,34 +0,0 @@
"""
知识库管理模块
提供文档上传、解析、分块、向量化、检索等功能
"""
from astrbot.core.knowledge_base.models import (
KBChunk,
KBDocument,
KBMedia,
KBSessionConfig,
KnowledgeBase,
)
# 注意: 以下导入在对应模块实现后取消注释
from .database import KBDatabase
from .manager import KBManager
from .manager_ops import KBManagerOps
from .session_config_db import SessionConfigDB
# from .injector import KnowledgeBaseInjector
__all__ = [
"KnowledgeBase",
"KBDocument",
"KBChunk",
"KBMedia",
"KBSessionConfig",
"KBDatabase",
"SessionConfigDB",
"KBManager",
"KBManagerOps",
# "KnowledgeBaseInjector",
]
+1 -1
View File
@@ -13,7 +13,7 @@ class BaseChunker(ABC):
"""
@abstractmethod
async def chunk(self, text: str) -> list[str]:
async def chunk(self, text: str, **kwargs) -> list[str]:
"""将文本分块
Args:
@@ -22,7 +22,7 @@ class FixedSizeChunker(BaseChunker):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
async def chunk(self, text: str) -> list[str]:
async def chunk(self, text: str, **kwargs) -> list[str]:
"""固定大小分块
Args:
@@ -31,22 +31,25 @@ class FixedSizeChunker(BaseChunker):
Returns:
list[str]: 分块后的文本列表
"""
chunk_size = kwargs.get("chunk_size", self.chunk_size)
chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
chunks = []
start = 0
text_len = len(text)
while start < text_len:
end = start + self.chunk_size
end = start + chunk_size
chunk = text[start:end]
if chunk:
chunks.append(chunk)
# 移动窗口,保留重叠部分
start = end - self.chunk_overlap
start = end - chunk_overlap
# 防止无限循环: 如果重叠过大,直接移到end
if start >= end or self.chunk_overlap >= self.chunk_size:
if start >= end or chunk_overlap >= chunk_size:
start = end
return chunks
-347
View File
@@ -1,347 +0,0 @@
"""知识库数据库操作类
该模块封装知识库、文档、块、多媒体和会话配置相关的数据库查询操作。
注意:
- 该模块操作的是独立的知识库数据库 (data/knowledge_base/kb.db)
- 会话配置也存储在此数据库中,会话ID来源于主数据库
"""
import json
from typing import Optional
from sqlalchemy import func, select
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.models import (
KBChunk,
KBDocument,
KBMedia,
KBSessionConfig,
KnowledgeBase,
)
class KBDatabase:
"""知识库数据库操作类
职责:
- 封装知识库、文档、块、多媒体和会话配置的数据库查询操作
- 统一异常处理
注意:
- 该类操作独立的知识库数据库 (kb.db)
- 会话配置存储会话ID与知识库的绑定关系,会话ID来源于主数据库
"""
def __init__(self, kb_db: KBSQLiteDatabase):
"""初始化知识库数据库操作类
Args:
kb_db: 知识库独立数据库实例,而非主数据库
"""
self.db = kb_db
# ===== 知识库查询 =====
async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
"""根据 ID 获取知识库"""
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
"""根据名称获取知识库"""
async with self.db.get_db() as session:
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_name == kb_name)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
"""列出所有知识库"""
async with self.db.get_db() as session:
stmt = (
select(KnowledgeBase)
.offset(offset)
.limit(limit)
.order_by(KnowledgeBase.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_kbs(self) -> int:
"""统计知识库数量"""
async with self.db.get_db() as session:
stmt = select(func.count(KnowledgeBase.id))
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 文档查询 =====
async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
"""根据 ID 获取文档"""
async with self.db.get_db() as session:
stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def list_documents_by_kb(
self, kb_id: str, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
async with self.db.get_db() as session:
stmt = (
select(KBDocument)
.where(KBDocument.kb_id == kb_id)
.offset(offset)
.limit(limit)
.order_by(KBDocument.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def count_documents_by_kb(self, kb_id: str) -> int:
"""统计知识库的文档数量"""
async with self.db.get_db() as session:
stmt = select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
result = await session.execute(stmt)
return result.scalar() or 0
# ===== 块查询 =====
async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]:
"""根据 ID 获取块"""
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]:
"""根据知识库 ID 列表获取所有块"""
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.kb_id.in_(kb_ids))
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]:
"""根据向量文档 ID 获取块"""
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.vec_doc_id == vec_doc_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
"""获取块及其关联的文档和知识库元数据"""
async with self.db.get_db() as session:
stmt = (
select(KBChunk, KBDocument, KnowledgeBase)
.join(KBDocument, KBChunk.doc_id == KBDocument.doc_id)
.join(KnowledgeBase, KBChunk.kb_id == KnowledgeBase.kb_id)
.where(KBChunk.chunk_id == chunk_id)
)
result = await session.execute(stmt)
row = result.first()
if not row:
return None
chunk, doc, kb = row
return {
"chunk": chunk,
"document": doc,
"knowledge_base": kb,
}
async def list_chunks_by_doc(
self, doc_id: str, offset: int = 0, limit: int = 100
) -> list[KBChunk]:
"""列出文档的所有块"""
async with self.db.get_db() as session:
stmt = (
select(KBChunk)
.where(KBChunk.doc_id == doc_id)
.offset(offset)
.limit(limit)
.order_by(KBChunk.chunk_index)
)
result = await session.execute(stmt)
return list(result.scalars().all())
# ===== 多媒体查询 =====
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
"""列出文档的所有多媒体资源"""
async with self.db.get_db() as session:
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
result = await session.execute(stmt)
return list(result.scalars().all())
async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
"""根据 ID 获取多媒体资源"""
async with self.db.get_db() as session:
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
result = await session.execute(stmt)
return result.scalar_one_or_none()
# ===== 会话配置查询 =====
async def get_session_kb_ids(self, session_id: str) -> list[str]:
"""获取会话关联的知识库 ID 列表
查找顺序:
1. 会话级别配置 (优先)
2. 平台级别配置
3. 返回空列表
Args:
session_id: 会话ID(来自主数据库)
Returns:
知识库ID列表
"""
async with self.db.get_db() as session:
# 1. 查找会话级别配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == "session",
KBSessionConfig.scope_id == session_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
return json.loads(config.kb_ids)
# 2. 提取平台 ID (格式: platform:xxx:session_id)
parts = session_id.split(":")
if len(parts) >= 2:
platform_id = parts[0]
# 查找平台级别配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == "platform",
KBSessionConfig.scope_id == platform_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
return json.loads(config.kb_ids)
# 3. 无配置
return []
async def set_session_kb_ids(
self,
scope: str,
scope_id: str,
kb_ids: list[str],
top_k: Optional[int] = None,
enable_rerank: Optional[bool] = None,
) -> KBSessionConfig:
"""设置会话知识库配置
Args:
scope: 配置范围 (session/platform)
scope_id: 范围标识 (会话 ID 或平台 ID,来自主数据库)
kb_ids: 知识库 ID 列表
top_k: 返回结果数量 (可选)
enable_rerank: 是否启用 Rerank (可选)
Returns:
配置对象
"""
async with self.db.get_db() as session:
# 查找现有配置
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == scope,
KBSessionConfig.scope_id == scope_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if config:
# 更新现有配置
config.kb_ids = json.dumps(kb_ids)
if top_k is not None:
config.top_k = top_k
if enable_rerank is not None:
config.enable_rerank = enable_rerank
else:
# 创建新配置
config = KBSessionConfig(
scope=scope,
scope_id=scope_id,
kb_ids=json.dumps(kb_ids),
top_k=top_k,
enable_rerank=enable_rerank,
)
session.add(config)
await session.commit()
await session.refresh(config)
return config
async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
"""删除会话知识库配置
Args:
scope: 配置范围 (session/platform)
scope_id: 范围标识 (会话 ID 或平台 ID)
Returns:
是否删除成功
"""
async with self.db.get_db() as session:
stmt = select(KBSessionConfig).where(
KBSessionConfig.scope == scope,
KBSessionConfig.scope_id == scope_id,
)
result = await session.execute(stmt)
config = result.scalar_one_or_none()
if not config:
return False
await session.delete(config)
await session.commit()
return True
async def delete_session_kb_config_by_session_id(self, session_id: str) -> bool:
"""根据会话ID删除会话配置(用于主数据库会话删除时的级联清理)
Args:
session_id: 会话ID(来自主数据库)
Returns:
是否删除成功
"""
return await self.delete_session_kb_config("session", session_id)
async def list_all_session_configs(
self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
) -> list[KBSessionConfig]:
"""列出所有会话配置
Args:
offset: 偏移量
limit: 限制数量
scope: 可选的范围过滤 (session/platform)
Returns:
会话配置列表
"""
async with self.db.get_db() as session:
stmt = select(KBSessionConfig)
if scope:
stmt = stmt.where(KBSessionConfig.scope == scope)
stmt = (
stmt.offset(offset)
.limit(limit)
.order_by(KBSessionConfig.created_at.desc())
)
result = await session.execute(stmt)
return list(result.scalars().all())
-115
View File
@@ -1,115 +0,0 @@
"""知识库上下文注入器
负责检索相关知识并格式化为 LLM 可用的上下文文本
"""
from typing import List, Optional
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.retrieval.manager import (
RetrievalManager,
RetrievalResult,
)
class KnowledgeBaseInjector:
"""知识库上下文注入器
职责:
- 检索相关知识
- 格式化为上下文文本
- 注入到 LLM Prompt
"""
def __init__(
self,
kb_db: KBDatabase,
retrieval_manager: RetrievalManager,
):
"""初始化知识库上下文注入器
Args:
kb_db: 知识库数据库实例
retrieval_manager: 检索管理器实例
"""
self.kb_db = kb_db
self.retrieval_manager = retrieval_manager
async def retrieve_and_inject(
self,
unified_msg_origin: str,
query: str,
top_k: int = 5,
) -> Optional[dict]:
"""检索并注入知识库上下文
Args:
unified_msg_origin: 统一消息来源 ID (会话 ID)
query: 用户查询
top_k: 返回结果数量
Returns:
Optional[dict]: 包含检索结果和格式化上下文的字典,如果无结果则返回 None
{
"context_text": str, # 格式化的上下文文本
"results": List[dict], # 原始检索结果列表
}
"""
# 1. 获取会话关联的知识库
kb_ids = await self.kb_db.get_session_kb_ids(unified_msg_origin)
if not kb_ids:
return None
# 2. 检索知识
results = await self.retrieval_manager.retrieve(
query=query,
kb_ids=kb_ids,
top_m_final=top_k,
)
if not results:
return None
# 3. 格式化上下文
context_text = self._format_context(results)
# 4. 转换结果为字典格式
results_dict = [
{
"chunk_id": r.chunk_id,
"doc_id": r.doc_id,
"kb_id": r.kb_id,
"kb_name": r.kb_name,
"doc_name": r.doc_name,
"chunk_index": r.metadata.get("chunk_index", 0),
"content": r.content,
"score": r.score,
}
for r in results
]
return {
"context_text": context_text,
"results": results_dict,
}
def _format_context(self, results: List[RetrievalResult]) -> str:
"""格式化知识上下文
Args:
results: 检索结果列表
Returns:
str: 格式化的上下文文本
"""
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
for i, result in enumerate(results, 1):
lines.append(f"【知识 {i}")
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
lines.append(f"内容: {result.content}")
lines.append(f"相关度: {result.score:.2f}")
lines.append("")
return "\n".join(lines)
+299
View File
@@ -0,0 +1,299 @@
from contextlib import asynccontextmanager
from pathlib import Path
from sqlmodel import col, desc
from sqlalchemy import text, func, select, update, delete
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core import logger
from astrbot.core.knowledge_base.models import (
BaseKBModel,
KBDocument,
KBMedia,
KnowledgeBase,
)
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
class KBSQLiteDatabase:
    """Async SQLite store for knowledge-base metadata.

    Holds the knowledge base, document and media records. Chunk text and
    vectors are kept in a separate vector database that callers pass in
    where needed.
    """

    # PRAGMA statements issued once by initialize(), in this order.
    _INIT_PRAGMAS = (
        "PRAGMA journal_mode=WAL",
        "PRAGMA synchronous=NORMAL",
        "PRAGMA cache_size=20000",
        "PRAGMA temp_store=MEMORY",
        "PRAGMA mmap_size=134217728",
        "PRAGMA optimize",
    )

    # Index DDL applied by migrate_to_v1(); every statement is idempotent.
    _V1_INDEX_STATEMENTS = (
        # knowledge_bases
        "CREATE INDEX IF NOT EXISTS idx_kb_kb_id ON knowledge_bases(kb_id)",
        "CREATE INDEX IF NOT EXISTS idx_kb_name ON knowledge_bases(kb_name)",
        "CREATE INDEX IF NOT EXISTS idx_kb_created_at ON knowledge_bases(created_at)",
        # kb_documents
        "CREATE INDEX IF NOT EXISTS idx_doc_doc_id ON kb_documents(doc_id)",
        "CREATE INDEX IF NOT EXISTS idx_doc_kb_id ON kb_documents(kb_id)",
        "CREATE INDEX IF NOT EXISTS idx_doc_name ON kb_documents(doc_name)",
        "CREATE INDEX IF NOT EXISTS idx_doc_type ON kb_documents(file_type)",
        "CREATE INDEX IF NOT EXISTS idx_doc_created_at ON kb_documents(created_at)",
        # kb_media
        "CREATE INDEX IF NOT EXISTS idx_media_media_id ON kb_media(media_id)",
        "CREATE INDEX IF NOT EXISTS idx_media_doc_id ON kb_media(doc_id)",
        "CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)",
        "CREATE INDEX IF NOT EXISTS idx_media_type ON kb_media(media_type)",
    )

    def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
        """Create the async engine and session factory for the database file.

        The parent directory of ``db_path`` is created if it is missing.
        Tables are not created here; call ``initialize()`` for that.

        Args:
            db_path: Path of the SQLite file.
        """
        self.db_path = db_path
        self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
        self.inited = False

        Path(db_path).parent.mkdir(parents=True, exist_ok=True)

        self.engine = create_async_engine(
            self.DATABASE_URL,
            echo=False,
            pool_pre_ping=True,
            pool_recycle=3600,
        )
        # expire_on_commit=False keeps attributes of returned objects loaded
        # after the session that produced them has committed.
        self.async_session = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

    @asynccontextmanager
    async def get_db(self):
        """Yield a new session from the session factory.

        Usage:
            async with kb_db.get_db() as session:
                result = await session.execute(stmt)
        """
        async with self.async_session() as session:
            yield session

    async def initialize(self) -> None:
        """Create all knowledge-base tables and apply the SQLite PRAGMAs."""
        # engine.begin() commits on a clean exit, so no explicit commit.
        async with self.engine.begin() as conn:
            await conn.run_sync(BaseKBModel.metadata.create_all)
            for pragma in self._INIT_PRAGMAS:
                await conn.execute(text(pragma))
        self.inited = True

    async def migrate_to_v1(self) -> None:
        """Run the v1 migration: create the lookup indexes if they are missing."""
        async with self.get_db() as session:
            # session.begin() commits on a clean exit and rolls back on error.
            async with session.begin():
                for statement in self._V1_INDEX_STATEMENTS:
                    await session.execute(text(statement))

    async def close(self) -> None:
        """Dispose of the engine and log that the database was closed."""
        await self.engine.dispose()
        logger.info(f"知识库数据库已关闭: {self.db_path}")

    # ===== Knowledge base queries =====

    async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None:
        """Return the knowledge base with this ``kb_id``, or None."""
        async with self.get_db() as session:
            stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None:
        """Return the knowledge base with this ``kb_name``, or None."""
        async with self.get_db() as session:
            stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
        """Return one page of knowledge bases, newest first."""
        async with self.get_db() as session:
            stmt = (
                select(KnowledgeBase)
                .order_by(desc(KnowledgeBase.created_at))
                .offset(offset)
                .limit(limit)
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def count_kbs(self) -> int:
        """Return the number of knowledge bases."""
        async with self.get_db() as session:
            stmt = select(func.count(col(KnowledgeBase.id)))
            result = await session.execute(stmt)
            return result.scalar() or 0

    # ===== Document queries =====

    async def get_document_by_id(self, doc_id: str) -> KBDocument | None:
        """Return the document with this ``doc_id``, or None."""
        async with self.get_db() as session:
            stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def list_documents_by_kb(
        self, kb_id: str, offset: int = 0, limit: int = 100
    ) -> list[KBDocument]:
        """Return one page of a knowledge base's documents, newest first."""
        async with self.get_db() as session:
            stmt = (
                select(KBDocument)
                .where(col(KBDocument.kb_id) == kb_id)
                .order_by(desc(KBDocument.created_at))
                .offset(offset)
                .limit(limit)
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def count_documents_by_kb(self, kb_id: str) -> int:
        """Return the number of documents in a knowledge base."""
        async with self.get_db() as session:
            stmt = select(func.count(col(KBDocument.id))).where(
                col(KBDocument.kb_id) == kb_id
            )
            result = await session.execute(stmt)
            return result.scalar() or 0

    async def get_document_with_metadata(self, doc_id: str) -> dict | None:
        """Return a document together with its knowledge base.

        Returns:
            ``{"document": KBDocument, "knowledge_base": KnowledgeBase}``, or
            None when the join yields no row.
        """
        async with self.get_db() as session:
            stmt = (
                select(KBDocument, KnowledgeBase)
                .join(KnowledgeBase, col(KBDocument.kb_id) == col(KnowledgeBase.kb_id))
                .where(col(KBDocument.doc_id) == doc_id)
            )
            result = await session.execute(stmt)
            row = result.first()
            if not row:
                return None
            return {
                "document": row[0],
                "knowledge_base": row[1],
            }

    async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
        """Delete a document, its media records and its chunks.

        The relational rows are removed in one transaction; afterwards every
        chunk whose metadata has ``kb_doc_id == doc_id`` is removed from
        ``vec_db``.

        Args:
            doc_id: ID of the document to delete.
            vec_db: Vector database holding the document's chunks.
        """
        async with self.get_db() as session:
            async with session.begin():
                # Remove media rows as well so they are not left pointing at
                # a document that no longer exists.
                await session.execute(
                    delete(KBMedia).where(col(KBMedia.doc_id) == doc_id)
                )
                await session.execute(
                    delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
                )

        await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})

    # ===== Media queries =====

    async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
        """Return all media records that belong to a document."""
        async with self.get_db() as session:
            stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id)
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def get_media_by_id(self, media_id: str) -> KBMedia | None:
        """Return the media record with this ``media_id``, or None."""
        async with self.get_db() as session:
            stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
        """Refresh ``doc_count`` and ``chunk_count`` of a knowledge base.

        ``doc_count`` is computed by a scalar subquery over the document
        table; ``chunk_count`` is the unfiltered document count reported by
        ``vec_db``.
        """
        chunk_cnt = await vec_db.count_documents()
        doc_count_subquery = (
            select(func.count(col(KBDocument.id)))
            .where(col(KBDocument.kb_id) == kb_id)
            .scalar_subquery()
        )

        async with self.get_db() as session:
            async with session.begin():
                update_stmt = (
                    update(KnowledgeBase)
                    .where(col(KnowledgeBase.kb_id) == kb_id)
                    .values(doc_count=doc_count_subquery, chunk_count=chunk_cnt)
                )
                await session.execute(update_stmt)
+352
View File
@@ -0,0 +1,352 @@
import uuid
import aiofiles
import json
from pathlib import Path
from .models import KnowledgeBase, KBDocument, KBMedia
from .kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.provider.manager import ProviderManager
from .parsers.base import BaseParser
from .chunking.base import BaseChunker
from astrbot.core import logger
class KBHelper:
vec_db: BaseVecDB
kb: KnowledgeBase
def __init__(
self,
kb_db: KBSQLiteDatabase,
kb: KnowledgeBase,
provider_manager: ProviderManager,
kb_root_dir: str,
chunker: BaseChunker,
parsers: dict[str, BaseParser],
):
self.kb_db = kb_db
self.kb = kb
self.prov_mgr = provider_manager
self.kb_root_dir = kb_root_dir
self.parsers = parsers
self.chunker = chunker
self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
async def initialize(self):
await self._ensure_vec_db()
async def get_ep(self) -> EmbeddingProvider:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
self.kb.embedding_provider_id
) # type: ignore
if not ep:
raise ValueError(
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider"
)
return ep
async def get_rp(self) -> RerankProvider | None:
if not self.kb.rerank_provider_id:
return None
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
self.kb.rerank_provider_id
) # type: ignore
if not rp:
raise ValueError(
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider"
)
return rp
async def _ensure_vec_db(self) -> FaissVecDB:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
ep = await self.get_ep()
rp = await self.get_rp()
vec_db = FaissVecDB(
doc_store_path=str(self.kb_dir / "doc.db"),
index_store_path=str(self.kb_dir / "index.faiss"),
embedding_provider=ep,
rerank_provider=rp,
)
await vec_db.initialize()
self.vec_db = vec_db
return vec_db
async def delete_vec_db(self):
await self.terminate()
if self.kb_dir.exists():
for item in self.kb_dir.iterdir():
if item.is_file():
item.unlink()
self.kb_dir.rmdir()
async def terminate(self):
if self.vec_db:
await self.vec_db.close()
async def upload_document(
self,
file_name: str,
file_content: bytes,
file_type: str,
chunk_size: int = 512,
chunk_overlap: int = 50,
batch_size: int = 32,
tasks_limit: int = 3,
max_retries: int = 3,
progress_callback=None,
) -> KBDocument:
"""上传并处理文档(带原子性保证和失败清理)
流程:
1. 保存原始文件
2. 解析文档内容
3. 提取多媒体资源
4. 分块处理
5. 生成向量并存储
6. 保存元数据(事务)
7. 更新统计
Args:
progress_callback: 进度回调函数,接收参数 (stage, current, total)
- stage: 当前阶段 ('parsing', 'chunking', 'embedding')
- current: 当前进度
- total: 总数
"""
await self._ensure_vec_db()
doc_id = str(uuid.uuid4())
media_paths: list[Path] = []
# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
# async with aiofiles.open(file_path, "wb") as f:
# await f.write(file_content)
try:
# 阶段1: 解析文档
if progress_callback:
await progress_callback("parsing", 0, 100)
parser = self.parsers.get(file_type)
if not parser:
raise ValueError(f"不支持的文件类型: {file_type}")
parse_result = await parser.parse(file_content, file_name)
text_content = parse_result.text
media_items = parse_result.media
if progress_callback:
await progress_callback("parsing", 100, 100)
# 保存媒体文件
saved_media = []
for media_item in media_items:
media = await self._save_media(
doc_id=doc_id,
media_type=media_item.media_type,
file_name=media_item.file_name,
content=media_item.content,
mime_type=media_item.mime_type,
)
saved_media.append(media)
media_paths.append(Path(media.file_path))
# 阶段2: 分块
if progress_callback:
await progress_callback("chunking", 0, 100)
chunks_text = await self.chunker.chunk(
text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
contents = []
metadatas = []
for idx, chunk_text in enumerate(chunks_text):
contents.append(chunk_text)
metadatas.append(
{
"kb_id": self.kb.kb_id,
"kb_doc_id": doc_id,
"chunk_index": idx,
}
)
if progress_callback:
await progress_callback("chunking", 100, 100)
# 阶段3: 生成向量(带进度回调)
async def embedding_progress_callback(current, total):
if progress_callback:
await progress_callback("embedding", current, total)
await self.vec_db.insert_batch(
contents=contents,
metadatas=metadatas,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=embedding_progress_callback,
)
# 保存文档的元数据
doc = KBDocument(
doc_id=doc_id,
kb_id=self.kb.kb_id,
doc_name=file_name,
file_type=file_type,
file_size=len(file_content),
# file_path=str(file_path),
file_path="",
chunk_count=len(chunks_text),
media_count=0,
)
async with self.kb_db.get_db() as session:
async with session.begin():
session.add(doc)
for media in saved_media:
session.add(media)
await session.commit()
await session.refresh(doc)
vec_db: FaissVecDB = self.vec_db # type: ignore
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
await self.refresh_kb()
await self.refresh_document(doc_id)
return doc
except Exception as e:
logger.error(f"上传文档失败: {e}")
# if file_path.exists():
# file_path.unlink()
for media_path in media_paths:
try:
if media_path.exists():
media_path.unlink()
except Exception as me:
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
raise e
async def list_documents(
self, offset: int = 0, limit: int = 100
) -> list[KBDocument]:
"""列出知识库的所有文档"""
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
return docs
async def get_document(self, doc_id: str) -> KBDocument | None:
"""获取单个文档"""
doc = await self.kb_db.get_document_by_id(doc_id)
return doc
async def delete_document(self, doc_id: str):
"""删除单个文档及其相关数据"""
await self.kb_db.delete_document_by_id(
doc_id=doc_id,
vec_db=self.vec_db, # type: ignore
)
await self.kb_db.update_kb_stats(
kb_id=self.kb.kb_id,
vec_db=self.vec_db, # type: ignore
)
await self.refresh_kb()
async def delete_chunk(self, chunk_id: str, doc_id: str):
"""删除单个文本块及其相关数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
await vec_db.delete(chunk_id)
await self.kb_db.update_kb_stats(
kb_id=self.kb.kb_id,
vec_db=self.vec_db, # type: ignore
)
await self.refresh_kb()
await self.refresh_document(doc_id)
async def refresh_kb(self):
if self.kb:
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
if kb:
self.kb = kb
async def refresh_document(self, doc_id: str) -> None:
    """Recompute and persist the chunk count of a document.

    Args:
        doc_id: Id of the document to refresh.

    Raises:
        ValueError: If no document with ``doc_id`` can be found.
    """
    doc = await self.get_document(doc_id)
    if not doc:
        raise ValueError(f"无法找到 ID 为 {doc_id} 的文档")
    doc.chunk_count = await self.get_chunk_count_by_doc_id(doc_id)
    async with self.kb_db.get_db() as session:
        # session.begin() commits on a clean exit, so no explicit commit()
        # is issued inside the block: committing there closes the
        # transaction the context manager still owns.
        async with session.begin():
            session.add(doc)
        # Refresh only after the transaction block has completed.
        await session.refresh(doc)
async def get_chunks_by_doc_id(
    self, doc_id: str, offset: int = 0, limit: int = 100
) -> list[dict]:
    """Return one page of a document's chunks as plain dictionaries.

    Args:
        doc_id: Id of the document; matched against the ``kb_doc_id``
            metadata field of the stored chunks.
        offset: Number of chunks to skip.
        limit: Maximum number of chunks to return.

    Returns:
        One dict per chunk with the keys ``chunk_id``, ``doc_id``,
        ``kb_id``, ``chunk_index``, ``content`` and ``char_count``.
    """
    store: FaissVecDB = self.vec_db  # type: ignore
    rows = await store.document_storage.get_documents(
        metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit
    )

    def to_record(row) -> dict:
        # Each row carries its metadata as a JSON string.
        meta = json.loads(row["metadata"])
        body = row["text"]
        return {
            "chunk_id": row["doc_id"],
            "doc_id": meta["kb_doc_id"],
            "kb_id": meta["kb_id"],
            "chunk_index": meta["chunk_index"],
            "content": body,
            "char_count": len(body),
        }

    return [to_record(row) for row in rows]
async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
    """Count the chunks whose ``kb_doc_id`` metadata equals ``doc_id``."""
    store: FaissVecDB = self.vec_db  # type: ignore
    return await store.count_documents(metadata_filter={"kb_doc_id": doc_id})
async def _save_media(
    self,
    doc_id: str,
    media_type: str,
    file_name: str,
    content: bytes,
    mime_type: str,
) -> "KBMedia":
    """Write a media resource to disk and build its ``KBMedia`` record.

    The record is only constructed here; it is returned to the caller
    without being added to any database session.

    Args:
        doc_id: Id of the owning document; also the sub-directory name.
        media_type: Stored as-is on the record.
        file_name: Original file name; only its suffix is reused on disk.
        content: Raw bytes to write.
        mime_type: Stored as-is on the record.

    Returns:
        The new ``KBMedia`` describing the saved file.
    """
    media_id = str(uuid.uuid4())
    # Files are stored as <kb_medias_dir>/<doc_id>/<uuid><original suffix>.
    target = self.kb_medias_dir / doc_id / (media_id + Path(file_name).suffix)
    target.parent.mkdir(parents=True, exist_ok=True)
    async with aiofiles.open(target, "wb") as out:
        await out.write(content)
    return KBMedia(
        media_id=media_id,
        doc_id=doc_id,
        kb_id=self.kb.kb_id,
        media_type=media_type,
        file_name=file_name,
        file_path=str(target),
        file_size=len(content),
        mime_type=mime_type,
    )
@@ -1,364 +0,0 @@
"""
知识库管理器
负责知识库模块的初始化、配置和资源管理
架构说明:
- 知识库数据存储在独立的数据库 (data/knowledge_base/kb.db)
- 会话配置同样存储在知识库独立数据库 (kb.db),仅会话 ID 来自主数据库 (data/astrbot.db)
"""
from pathlib import Path
from astrbot.core import logger
from astrbot.core.db import BaseDatabase
from astrbot.core.provider.manager import ProviderManager
class KnowledgeBaseManager:
    """Bootstrap and lifecycle owner of the knowledge-base module.

    Responsibilities:
    - initialise the knowledge-base module;
    - pick the Embedding provider and the optional Rerank provider;
    - create and hold the sub-components (databases, vector store, KB
      manager, retrieval manager, context injector);
    - register a session-deleted callback for cascading cleanup.

    Storage layout:
    - knowledge-base data is kept in a dedicated database (``kb.db``);
    - per-session KB configuration is removed through the same KB database
      wrapper (see ``register_session_lifecycle_hooks``).
    """

    def __init__(
        self,
        config: dict,
        main_db: "BaseDatabase",
        provider_manager: "ProviderManager",
    ):
        """Store the configuration; no I/O happens until ``initialize()``.

        Args:
            config: Full configuration dict; only its ``knowledge_base``
                section is kept.
            main_db: Main database instance. Not used by this class.
            provider_manager: Source of embedding / rerank providers.
        """
        self.config = config.get("knowledge_base", {})
        self.provider_manager = provider_manager

        # Dedicated knowledge-base database (raw handle).
        self.kb_db = None

        # Components created by initialize().
        self.kb_database = None
        self.kb_manager = None
        self.kb_vec_db = None
        self.retrieval_manager = None
        self.kb_injector = None

        self._initialized = False
        self._session_deleted_callback_registered = False

    async def initialize(self):
        """Build every component; failures are logged, never raised.

        Does nothing when the module is disabled or no embedding provider
        is available. If a step fails, handles opened by earlier steps are
        closed again so they do not leak.
        """
        if not self.config.get("enabled", False):
            logger.info("知识库功能未启用")
            return

        try:
            logger.info("正在初始化知识库模块...")

            # 1. An embedding provider is mandatory.
            embedding_provider = self._select_embedding_provider()
            if not embedding_provider:
                logger.warning("未配置 Embedding Provider,知识库功能无法使用")
                return

            # 2. Databases.
            await self._init_kb_database()
            await self._init_database()

            # 3. Vector database.
            await self._init_vector_db(embedding_provider)

            # 4. Parsers and chunker.
            parsers = self._init_parsers()
            chunker = self._init_chunker()

            # 5.-7. KB manager, retrieval manager, context injector.
            await self._init_kb_manager(parsers, chunker)
            await self._init_retrieval_manager()
            await self._init_injector()

            self._initialized = True
            logger.info("知识库模块初始化完成")
        except ImportError as e:
            logger.error(f"知识库模块导入失败: {e}")
            logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
            await self._release_resources()
        except Exception as e:
            logger.error(f"知识库模块初始化失败: {e}")
            import traceback

            logger.error(traceback.format_exc())
            await self._release_resources()

    async def _init_kb_database(self):
        """Open the dedicated KB SQLite database and run the v1 migration."""
        from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase

        db_path = self.config.get("storage", {}).get(
            "kb_db_path", "data/knowledge_base/kb.db"
        )
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)
        self.kb_db = KBSQLiteDatabase(db_path)
        await self.kb_db.initialize()
        await self.kb_db.migrate_to_v1()
        logger.info(f"知识库独立数据库已初始化: {db_path}")

    async def _init_database(self):
        """Wrap the raw KB database in the ``KBDatabase`` operations class."""
        from astrbot.core.knowledge_base.database import KBDatabase

        self.kb_database = KBDatabase(self.kb_db)

    async def _init_vector_db(self, embedding_provider):
        """Create and initialise the Faiss vector database."""
        from astrbot.core.db.vec_db.faiss_impl import FaissVecDB

        storage_path = self.config.get("storage", {}).get(
            "vector_db_path", "data/knowledge_base/vectors"
        )
        Path(storage_path).mkdir(parents=True, exist_ok=True)
        self.kb_vec_db = FaissVecDB(
            doc_store_path=f"{storage_path}/documents.db",
            index_store_path=f"{storage_path}/index.faiss",
            embedding_provider=embedding_provider,
        )
        await self.kb_vec_db.initialize()

    def _init_parsers(self) -> dict:
        """Return the mapping from file type to parser instance."""
        from astrbot.core.knowledge_base.parsers.text_parser import TextParser
        from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser

        return {
            "txt": TextParser(),
            "md": TextParser(),
            "markdown": TextParser(),
            "pdf": PDFParser(),
        }

    def _init_chunker(self):
        """Build the fixed-size chunker from the ``chunking`` config section."""
        from astrbot.core.knowledge_base.chunking.fixed_size import FixedSizeChunker

        chunking_config = self.config.get("chunking", {})
        return FixedSizeChunker(
            chunk_size=chunking_config.get("chunk_size", 512),
            chunk_overlap=chunking_config.get("chunk_overlap", 50),
        )

    async def _init_kb_manager(self, parsers: dict, chunker):
        """Create the ``KBManager`` on top of the KB database and vector DB."""
        from astrbot.core.knowledge_base.manager import KBManager

        files_path = self.config.get("storage", {}).get(
            "files_path", "data/knowledge_base"
        )
        self.kb_manager = KBManager(
            db=self.kb_db,  # the dedicated knowledge-base database
            vec_db=self.kb_vec_db,
            storage_path=files_path,
            parsers=parsers,
            chunker=chunker,
            provider_manager=self.provider_manager,
        )

    async def _init_retrieval_manager(self):
        """Create the retrieval manager with an optional rerank provider."""
        from astrbot.core.knowledge_base.retrieval.manager import RetrievalManager
        from astrbot.core.knowledge_base.retrieval.sparse_retriever import (
            SparseRetriever,
        )
        from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion

        sparse_retriever = SparseRetriever(self.kb_database)
        rank_fusion = RankFusion(self.kb_database)

        # The rerank provider is optional and may be None.
        rerank_provider = self._select_rerank_provider()

        self.retrieval_manager = RetrievalManager(
            vec_db=self.kb_vec_db,
            sparse_retriever=sparse_retriever,
            rank_fusion=rank_fusion,
            kb_db=self.kb_database,
            rerank_provider=rerank_provider,
        )

    async def _init_injector(self):
        """Create the context injector."""
        from astrbot.core.knowledge_base.injector import KnowledgeBaseInjector

        self.kb_injector = KnowledgeBaseInjector(
            kb_db=self.kb_database,
            retrieval_manager=self.retrieval_manager,
        )

    def _select_embedding_provider(self):
        """Choose the embedding provider.

        Rules:
        - the provider whose id equals ``embedding_provider_id`` wins;
        - if that id is configured but unknown, warn and use the first one;
        - if nothing is configured, use the first one (with a notice when
          several providers exist).

        Returns:
            The chosen provider, or ``None`` when none is registered.
        """
        providers = self.provider_manager.embedding_provider_insts
        if not providers:
            return None

        wanted = self.config.get("embedding_provider_id")
        if wanted:
            for provider in providers:
                provider_id = provider.meta().id
                if provider_id == wanted:
                    logger.info(f"知识库使用 Embedding Provider: {provider_id}")
                    return provider
            logger.warning(
                f"未找到配置的 Embedding Provider ID: {wanted},"
                f"将使用第一个可用的"
            )

        provider = providers[0]
        provider_id = provider.meta().id
        if len(providers) > 1 and not wanted:
            logger.info(
                f"检测到 {len(providers)} 个 Embedding Provider,"
                f"未在配置文件中指定 embedding_provider_id,将使用第一个: {provider_id}"
            )
        else:
            logger.info(f"知识库使用 Embedding Provider: {provider_id}")
        return provider

    def _select_rerank_provider(self):
        """Choose the optional rerank provider.

        Returns:
            ``None`` when reranking is disabled (``retrieval.enable_rerank``,
            default ``True``) or no provider exists; otherwise the provider
            matching ``rerank_provider_id``, falling back to the first one.
        """
        if not self.config.get("retrieval", {}).get("enable_rerank", True):
            return None

        providers = self.provider_manager.rerank_provider_insts
        if not providers:
            return None

        wanted = self.config.get("rerank_provider_id")
        if wanted:
            for provider in providers:
                provider_id = provider.meta().id
                if provider_id == wanted:
                    logger.info(f"知识库使用 Rerank Provider: {provider_id}")
                    return provider
            logger.warning(f"未找到配置的 Rerank Provider ID: {wanted}")

        # The list is known to be non-empty here, so a fallback always exists.
        provider = providers[0]
        logger.info(f"知识库使用 Rerank Provider: {provider.meta().id}")
        return provider

    @property
    def is_initialized(self) -> bool:
        """Whether ``initialize()`` completed successfully."""
        return self._initialized

    def get_kb_manager(self):
        """Return the KB manager, or ``None`` while not initialised."""
        return self.kb_manager if self._initialized else None

    def get_kb_injector(self):
        """Return the context injector, or ``None`` while not initialised."""
        return self.kb_injector if self._initialized else None

    def register_session_lifecycle_hooks(self, conversation_manager):
        """Register a callback that removes KB config when a session is deleted.

        Registration happens at most once and only after a successful
        initialisation.

        Args:
            conversation_manager: Object exposing
                ``register_on_session_deleted(callback)``.
        """
        if self._session_deleted_callback_registered or not self._initialized:
            return

        async def on_session_deleted(session_id: str):
            """Session-deleted callback; errors are logged, not raised."""
            try:
                await self.kb_database.delete_session_kb_config_by_session_id(
                    session_id
                )
                logger.info(f"已清理会话知识库配置: {session_id}")
            except Exception as e:
                logger.error(f"清理会话知识库配置失败 ({session_id}): {e}")

        conversation_manager.register_on_session_deleted(on_session_deleted)
        self._session_deleted_callback_registered = True
        logger.info("已注册知识库会话删除回调")

    async def reinitialize(self):
        """Tear down and initialise again, e.g. after a provider was added.

        Returns:
            ``True`` when the module is initialised afterwards.
        """
        if self._initialized:
            logger.info("知识库模块已初始化,将重新初始化")
        # terminate() is a no-op when nothing is held, so it is always safe
        # and also cleans up leftovers of an earlier failed initialisation.
        await self.terminate()

        await self.initialize()
        return self._initialized

    async def terminate(self):
        """Close database handles and drop all components.

        Handles are released even when ``initialize()`` never completed, so
        a partially initialised manager does not leak connections.
        """
        if not (self._initialized or self.kb_db or self.kb_vec_db):
            return

        logger.info("正在终止知识库模块...")
        await self._release_resources()
        logger.info("知识库模块已终止")

    async def _release_resources(self):
        """Close whatever is open (best effort) and reset every reference."""
        if self.kb_vec_db:
            try:
                await self.kb_vec_db.close()
                logger.debug("向量数据库已关闭")
            except Exception as e:
                logger.warning(f"关闭向量数据库时出错: {e}")

        if self.kb_db:
            try:
                await self.kb_db.close()
                logger.debug("知识库数据库已关闭")
            except Exception as e:
                logger.warning(f"关闭知识库数据库时出错: {e}")

        self._initialized = False
        self.kb_db = None
        self.kb_database = None
        self.kb_manager = None
        self.kb_vec_db = None
        self.retrieval_manager = None
        self.kb_injector = None
+279
View File
@@ -0,0 +1,279 @@
import traceback
from pathlib import Path
from astrbot.core import logger
from astrbot.core.provider.manager import ProviderManager
from .retrieval.manager import RetrievalManager, RetrievalResult
from .retrieval.sparse_retriever import SparseRetriever
from .retrieval.rank_fusion import RankFusion
from .kb_db_sqlite import KBSQLiteDatabase
from .parsers.text_parser import TextParser
from .parsers.pdf_parser import PDFParser
from .chunking.fixed_size import FixedSizeChunker
from .kb_helper import KBHelper
from .models import KnowledgeBase
# Knowledge-base storage root directory.
FILES_PATH = "data/knowledge_base"
# SQLite database file of the knowledge-base module, kept under FILES_PATH.
DB_PATH = Path(FILES_PATH) / "kb.db"

# File type -> parser instance.
PARSERS = {
    "txt": TextParser(),
    "md": TextParser(),
    "markdown": TextParser(),
    "pdf": PDFParser(),
}
# Chunker built with FixedSizeChunker's default arguments.
# NOTE(review): each KnowledgeBase also stores chunk_size / chunk_overlap;
# confirm the consumer of CHUNKER applies those per-KB values.
CHUNKER = FixedSizeChunker()
class KnowledgeBaseManager:
kb_db: KBSQLiteDatabase
retrieval_manager: RetrievalManager
def __init__(
self,
provider_manager: ProviderManager,
):
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
self.provider_manager = provider_manager
self._session_deleted_callback_registered = False
self.kb_insts: dict[str, KBHelper] = {}
async def initialize(self):
"""初始化知识库模块"""
try:
logger.info("正在初始化知识库模块...")
# 初始化数据库
await self._init_kb_database()
# 初始化检索管理器
sparse_retriever = SparseRetriever(self.kb_db)
rank_fusion = RankFusion(self.kb_db)
self.retrieval_manager = RetrievalManager(
sparse_retriever=sparse_retriever,
rank_fusion=rank_fusion,
kb_db=self.kb_db,
)
await self.load_kbs()
except ImportError as e:
logger.error(f"知识库模块导入失败: {e}")
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
except Exception as e:
logger.error(f"知识库模块初始化失败: {e}")
logger.error(traceback.format_exc())
async def _init_kb_database(self):
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
await self.kb_db.initialize()
await self.kb_db.migrate_to_v1()
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
async def load_kbs(self):
"""加载所有知识库实例"""
kb_records = await self.kb_db.list_kbs()
for record in kb_records:
kb_helper = KBHelper(
kb_db=self.kb_db,
kb=record,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
parsers=PARSERS,
)
await kb_helper.initialize()
self.kb_insts[record.kb_id] = kb_helper
async def create_kb(
self,
kb_name: str,
description: str | None = None,
emoji: str | None = None,
embedding_provider_id: str | None = None,
rerank_provider_id: str | None = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
) -> KBHelper:
"""创建新的知识库实例"""
kb = KnowledgeBase(
kb_name=kb_name,
description=description,
emoji=emoji or "📚",
embedding_provider_id=embedding_provider_id,
rerank_provider_id=rerank_provider_id,
chunk_size=chunk_size if chunk_size is not None else 512,
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
top_k_dense=top_k_dense if top_k_dense is not None else 50,
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
top_m_final=top_m_final if top_m_final is not None else 5,
)
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
kb_helper = KBHelper(
kb_db=self.kb_db,
kb=kb,
provider_manager=self.provider_manager,
kb_root_dir=FILES_PATH,
chunker=CHUNKER,
parsers=PARSERS,
)
await kb_helper.initialize()
self.kb_insts[kb.kb_id] = kb_helper
return kb_helper
async def get_kb(self, kb_id: str) -> KBHelper | None:
"""获取知识库实例"""
if kb_id in self.kb_insts:
return self.kb_insts[kb_id]
async def get_kb_by_name(self, kb_name: str) -> KBHelper | None:
"""通过名称获取知识库实例"""
for kb_helper in self.kb_insts.values():
if kb_helper.kb.kb_name == kb_name:
return kb_helper
return None
async def delete_kb(self, kb_id: str) -> bool:
"""删除知识库实例"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
return False
await kb_helper.delete_vec_db()
async with self.kb_db.get_db() as session:
await session.delete(kb_helper.kb)
await session.commit()
self.kb_insts.pop(kb_id, None)
return True
async def list_kbs(self) -> list[KnowledgeBase]:
"""列出所有知识库实例"""
kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()]
return kbs
async def update_kb(
self,
kb_id: str,
kb_name: str,
description: str | None = None,
emoji: str | None = None,
embedding_provider_id: str | None = None,
rerank_provider_id: str | None = None,
chunk_size: int | None = None,
chunk_overlap: int | None = None,
top_k_dense: int | None = None,
top_k_sparse: int | None = None,
top_m_final: int | None = None,
) -> KBHelper | None:
"""更新知识库实例"""
kb_helper = await self.get_kb(kb_id)
if not kb_helper:
return None
kb = kb_helper.kb
if kb_name is not None:
kb.kb_name = kb_name
if description is not None:
kb.description = description
if emoji is not None:
kb.emoji = emoji
if embedding_provider_id is not None:
kb.embedding_provider_id = embedding_provider_id
kb.rerank_provider_id = rerank_provider_id # 允许设置为 None
if chunk_size is not None:
kb.chunk_size = chunk_size
if chunk_overlap is not None:
kb.chunk_overlap = chunk_overlap
if top_k_dense is not None:
kb.top_k_dense = top_k_dense
if top_k_sparse is not None:
kb.top_k_sparse = top_k_sparse
if top_m_final is not None:
kb.top_m_final = top_m_final
async with self.kb_db.get_db() as session:
session.add(kb)
await session.commit()
await session.refresh(kb)
return kb_helper
async def retrieve(
self,
query: str,
kb_names: list[str],
top_k_fusion: int = 20,
top_m_final: int = 5,
) -> dict | None:
"""从指定知识库中检索相关内容"""
kb_ids = []
kb_id_helper_map = {}
for kb_name in kb_names:
if kb_helper := await self.get_kb_by_name(kb_name):
kb_ids.append(kb_helper.kb.kb_id)
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
if not kb_ids:
return {}
results = await self.retrieval_manager.retrieve(
query=query,
kb_ids=kb_ids,
kb_id_helper_map=kb_id_helper_map,
top_k_fusion=top_k_fusion,
top_m_final=top_m_final,
)
if not results:
return None
context_text = self._format_context(results)
results_dict = [
{
"chunk_id": r.chunk_id,
"doc_id": r.doc_id,
"kb_id": r.kb_id,
"kb_name": r.kb_name,
"doc_name": r.doc_name,
"chunk_index": r.metadata.get("chunk_index", 0),
"content": r.content,
"score": r.score,
"char_count": r.metadata.get("char_count", 0),
}
for r in results
]
return {
"context_text": context_text,
"results": results_dict,
}
def _format_context(self, results: list[RetrievalResult]) -> str:
"""格式化知识上下文
Args:
results: 检索结果列表
Returns:
str: 格式化的上下文文本
"""
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
for i, result in enumerate(results, 1):
lines.append(f"【知识 {i}")
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
lines.append(f"内容: {result.content}")
lines.append(f"相关度: {result.score:.2f}")
lines.append("")
return "\n".join(lines)
-230
View File
@@ -1,230 +0,0 @@
"""
知识库独立 SQLite 数据库
该模块提供知识库专用的独立 SQLite 数据库,与主数据库 (astrbot.db) 完全隔离。
职责:
- 管理知识库相关表 (knowledge_bases, kb_documents, kb_chunks, kb_media)
- 提供数据库连接和会话管理
- 执行数据库迁移和初始化
"""
from contextlib import asynccontextmanager
from pathlib import Path
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core import logger
class KBSQLiteDatabase:
    """Dedicated SQLite database for knowledge-base data.

    The database lives in its own file, separate from the main database
    (astrbot.db), so that knowledge-base data can be backed up on its own
    and its schema does not touch the main database.
    """

    # PRAGMA statements issued by initialize(), in this order.
    _PRAGMAS = (
        "PRAGMA journal_mode=WAL",
        "PRAGMA synchronous=NORMAL",
        "PRAGMA cache_size=20000",
        "PRAGMA temp_store=MEMORY",
        "PRAGMA mmap_size=134217728",
        "PRAGMA optimize",
    )

    # (index name, table, column) for every index created by migrate_to_v1().
    _V1_INDEXES = (
        # knowledge_bases
        ("idx_kb_kb_id", "knowledge_bases", "kb_id"),
        ("idx_kb_name", "knowledge_bases", "kb_name"),
        ("idx_kb_created_at", "knowledge_bases", "created_at"),
        # kb_documents
        ("idx_doc_doc_id", "kb_documents", "doc_id"),
        ("idx_doc_kb_id", "kb_documents", "kb_id"),
        ("idx_doc_name", "kb_documents", "doc_name"),
        ("idx_doc_type", "kb_documents", "file_type"),
        ("idx_doc_created_at", "kb_documents", "created_at"),
        # kb_chunks
        ("idx_chunk_chunk_id", "kb_chunks", "chunk_id"),
        ("idx_chunk_doc_id", "kb_chunks", "doc_id"),
        ("idx_chunk_kb_id", "kb_chunks", "kb_id"),
        ("idx_chunk_vec_doc_id", "kb_chunks", "vec_doc_id"),
        # kb_media
        ("idx_media_media_id", "kb_media", "media_id"),
        ("idx_media_doc_id", "kb_media", "doc_id"),
        ("idx_media_kb_id", "kb_media", "kb_id"),
        ("idx_media_type", "kb_media", "media_type"),
        # kb_session_config
        ("idx_session_config_scope_id", "kb_session_config", "scope_id"),
        ("idx_session_config_scope", "kb_session_config", "scope"),
    )

    def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
        """Create the async engine and session factory for ``db_path``.

        The parent directory is created if needed. No table is created
        here; call ``initialize()`` for that.

        Args:
            db_path: Path of the SQLite file.
        """
        self.db_path = db_path
        self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
        self.inited = False

        # Make sure the directory exists before the engine touches the file.
        Path(db_path).parent.mkdir(parents=True, exist_ok=True)

        self.engine = create_async_engine(
            self.DATABASE_URL,
            echo=False,
            pool_pre_ping=True,
            pool_recycle=3600,
        )

        self.async_session = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

    @asynccontextmanager
    async def get_db(self):
        """Yield a database session that is closed when the block exits.

        Usage:
            async with kb_db.get_db() as session:
                result = await session.execute(stmt)
        """
        async with self.async_session() as session:
            yield session

    async def initialize(self) -> None:
        """Create the tables and apply the SQLite PRAGMA settings."""
        # Importing the models registers their tables on SQLModel.metadata.
        from astrbot.core.knowledge_base.models import (  # noqa: F401
            KBChunk,
            KBDocument,
            KBMedia,
            KBSessionConfig,
            KnowledgeBase,
        )
        from sqlmodel import SQLModel

        # engine.begin() commits when the block exits without an error.
        async with self.engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)
            for pragma in self._PRAGMAS:
                await conn.execute(text(pragma))

        self.inited = True
        logger.info(f"知识库数据库已初始化: {self.db_path}")

    async def migrate_to_v1(self) -> None:
        """Run the v1 migration: create every index in ``_V1_INDEXES``.

        Each statement uses ``IF NOT EXISTS``, so running the migration
        again is harmless.
        """
        async with self.get_db() as session:
            # session.begin() commits on a clean exit of the block.
            async with session.begin():
                for index_name, table, column in self._V1_INDEXES:
                    await session.execute(
                        text(
                            f"CREATE INDEX IF NOT EXISTS {index_name} "
                            f"ON {table}({column})"
                        )
                    )

        logger.info("知识库数据库迁移 v1 完成")

    async def close(self) -> None:
        """Dispose of the engine and its connections."""
        await self.engine.dispose()
        logger.info(f"知识库数据库已关闭: {self.db_path}")
-375
View File
@@ -1,375 +0,0 @@
"""知识库管理器
该模块提供知识库的CRUD操作和文档上传处理流程。
"""
import uuid
from pathlib import Path
from typing import Optional
import aiofiles
from sqlalchemy import func, select, update
from astrbot.core.db import BaseDatabase
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.knowledge_base.chunking.base import BaseChunker
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase
from astrbot.core.knowledge_base.parsers.base import BaseParser
class KBManager:
    """CRUD for knowledge bases plus the document upload pipeline.

    Responsibilities:
    - create / read / update / delete knowledge bases;
    - upload and parse documents;
    - generate and store document chunks;
    - keep the KB statistics up to date.
    """

    def __init__(
        self,
        db: "BaseDatabase",
        vec_db: "BaseVecDB",
        storage_path: str,
        parsers: "dict[str, BaseParser]",
        chunker: "BaseChunker",
        provider_manager=None,
    ):
        """Store the collaborators and create the storage directories.

        Args:
            db: Database exposing ``get_db()`` sessions.
            vec_db: Vector database used for chunk embeddings.
            storage_path: Root directory; ``media/`` and ``files/`` are
                created below it.
            parsers: Mapping from file type to parser.
            chunker: Splits parsed text into chunks.
            provider_manager: Optional; only consulted to decide the
                default of ``enable_rerank`` in ``create_kb``.
        """
        self.db = db
        self.vec_db = vec_db
        self.storage_path = Path(storage_path)
        self.media_path = self.storage_path / "media"
        self.files_path = self.storage_path / "files"
        self.parsers = parsers
        self.chunker = chunker
        self.provider_manager = provider_manager

        self.media_path.mkdir(parents=True, exist_ok=True)
        self.files_path.mkdir(parents=True, exist_ok=True)

    # ===== Knowledge-base operations =====

    async def create_kb(
        self,
        kb_name: str,
        description: Optional[str] = None,
        emoji: Optional[str] = None,
        embedding_provider_id: Optional[str] = None,
        rerank_provider_id: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        top_k_dense: Optional[int] = None,
        top_k_sparse: Optional[int] = None,
        top_m_final: Optional[int] = None,
        enable_rerank: Optional[bool] = None,
    ) -> "KnowledgeBase":
        """Create and persist a knowledge base.

        Args:
            enable_rerank: Used as given when ``True``/``False``. When
                ``None`` it becomes ``True`` only if the provider manager
                currently has at least one rerank provider.

        Numeric options left as ``None`` fall back to: ``chunk_size`` 512,
        ``chunk_overlap`` 50, ``top_k_dense`` 50, ``top_k_sparse`` 50,
        ``top_m_final`` 5.
        """
        if enable_rerank is None:
            # bool() guarantees a real boolean even without a provider
            # manager (the and-chain used before could yield None).
            rerank_providers = getattr(
                self.provider_manager, "rerank_provider_insts", None
            )
            enable_rerank = bool(rerank_providers)

        kb = KnowledgeBase(
            kb_name=kb_name,
            description=description,
            emoji=emoji or "📚",
            embedding_provider_id=embedding_provider_id,
            rerank_provider_id=rerank_provider_id,
            chunk_size=chunk_size if chunk_size is not None else 512,
            chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
            top_k_dense=top_k_dense if top_k_dense is not None else 50,
            top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
            top_m_final=top_m_final if top_m_final is not None else 5,
            enable_rerank=enable_rerank,
        )

        async with self.db.get_db() as session:
            session.add(kb)
            await session.commit()
            await session.refresh(kb)

        return kb

    async def get_kb(self, kb_id: str) -> "Optional[KnowledgeBase]":
        """Return the knowledge base with ``kb_id``, or ``None``."""
        async with self.db.get_db() as session:
            stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()

    async def list_kbs(self, offset: int = 0, limit: int = 100) -> "list[KnowledgeBase]":
        """Return one page of knowledge bases, newest first."""
        async with self.db.get_db() as session:
            stmt = (
                select(KnowledgeBase)
                .offset(offset)
                .limit(limit)
                .order_by(KnowledgeBase.created_at.desc())
            )
            result = await session.execute(stmt)
            return list(result.scalars().all())

    async def update_kb(
        self,
        kb_id: str,
        kb_name: Optional[str] = None,
        description: Optional[str] = None,
        emoji: Optional[str] = None,
        embedding_provider_id: Optional[str] = None,
        rerank_provider_id: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        top_k_dense: Optional[int] = None,
        top_k_sparse: Optional[int] = None,
        top_m_final: Optional[int] = None,
        enable_rerank: Optional[bool] = None,
    ) -> "Optional[KnowledgeBase]":
        """Update the given fields; arguments left as ``None`` are untouched.

        Returns:
            The updated knowledge base, or ``None`` if ``kb_id`` is unknown.
        """
        changes = {
            "kb_name": kb_name,
            "description": description,
            "emoji": emoji,
            "embedding_provider_id": embedding_provider_id,
            "rerank_provider_id": rerank_provider_id,
            "chunk_size": chunk_size,
            "chunk_overlap": chunk_overlap,
            "top_k_dense": top_k_dense,
            "top_k_sparse": top_k_sparse,
            "top_m_final": top_m_final,
            "enable_rerank": enable_rerank,
        }
        async with self.db.get_db() as session:
            stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
            result = await session.execute(stmt)
            kb = result.scalar_one_or_none()

            if not kb:
                return None

            for field, value in changes.items():
                if value is not None:
                    setattr(kb, field, value)

            await session.commit()
            await session.refresh(kb)
            return kb

    async def delete_kb(self, kb_id: str) -> bool:
        """Delete a knowledge base together with all of its documents.

        Returns:
            ``False`` when no knowledge base with ``kb_id`` exists.
        """
        from astrbot.core.knowledge_base.manager_ops import KBManagerOps

        ops = KBManagerOps(self)
        page_size = 100

        # list_documents() is paginated, so keep fetching until every
        # document has been handled; a single call would leave everything
        # beyond the first page behind. Documents whose deletion failed
        # stay in the listing and are remembered so the loop terminates.
        failed_ids = set()
        while True:
            docs = await ops.list_documents(
                kb_id, offset=0, limit=len(failed_ids) + page_size
            )
            pending = [doc for doc in docs if doc.doc_id not in failed_ids]
            if not pending:
                break
            for doc in pending:
                if not await ops.delete_document(doc.doc_id):
                    failed_ids.add(doc.doc_id)

        async with self.db.get_db() as session:
            stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
            result = await session.execute(stmt)
            kb = result.scalar_one_or_none()

            if not kb:
                return False

            await session.delete(kb)
            await session.commit()
            return True

    # ===== Document upload =====

    async def upload_document(
        self,
        kb_id: str,
        file_name: str,
        file_content: bytes,
        file_type: str,
    ) -> "KBDocument":
        """Upload and process a document, cleaning up on failure.

        Steps:
            1. save the raw file;
            2. parse the document;
            3. save extracted media;
            4. split the text into chunks;
            5. embed and store each chunk;
            6. persist document / chunk / media records in one transaction;
            7. update the KB statistics.

        Raises:
            ValueError: If no parser is registered for ``file_type``. This
                is checked before anything is written, so an unknown
                ``file_type`` never ends up in a file-system path.
            Exception: Any error raised by a later step is re-raised after
                the vectors and files created so far have been removed.
        """
        parser = self.parsers.get(file_type)
        if not parser:
            raise ValueError(f"不支持的文件类型: {file_type}")

        doc_id = str(uuid.uuid4())
        file_path = None
        media_paths = []
        vec_doc_ids = []

        try:
            # 1. Save the raw file.
            file_path = self.files_path / kb_id / f"{doc_id}.{file_type}"
            file_path.parent.mkdir(parents=True, exist_ok=True)
            async with aiofiles.open(file_path, "wb") as f:
                await f.write(file_content)

            # 2. Parse the document.
            parse_result = await parser.parse(file_content, file_name)
            text_content = parse_result.text
            media_items = parse_result.media

            # 3. Save extracted media.
            from astrbot.core.knowledge_base.manager_ops import KBManagerOps

            ops = KBManagerOps(self)
            saved_media = []
            for media_item in media_items:
                media = await ops._save_media(
                    kb_id=kb_id,
                    doc_id=doc_id,
                    media_type=media_item.media_type,
                    file_name=media_item.file_name,
                    content=media_item.content,
                    mime_type=media_item.mime_type,
                )
                saved_media.append(media)
                media_paths.append(Path(media.file_path))

            # 4. Chunk the text.
            chunks_text = await self.chunker.chunk(text_content)

            # 5. Embed and store each chunk.
            saved_chunks = []
            for idx, chunk_text in enumerate(chunks_text):
                vec_doc_id = await self.vec_db.insert(
                    content=chunk_text,
                    metadata={
                        "kb_id": kb_id,
                        "doc_id": doc_id,
                        "chunk_index": idx,
                    },
                )
                # Remember the id right away so a later failure can undo it.
                vec_doc_ids.append(str(vec_doc_id))

                saved_chunks.append(
                    KBChunk(
                        doc_id=doc_id,
                        kb_id=kb_id,
                        chunk_index=idx,
                        content=chunk_text,
                        char_count=len(chunk_text),
                        vec_doc_id=str(vec_doc_id),
                    )
                )

            # 6. Persist the metadata in one transaction.
            doc = KBDocument(
                doc_id=doc_id,
                kb_id=kb_id,
                doc_name=file_name,
                file_type=file_type,
                file_size=len(file_content),
                file_path=str(file_path),
                chunk_count=len(saved_chunks),
                media_count=len(saved_media),
            )

            async with self.db.get_db() as session:
                # session.begin() commits when the block exits cleanly.
                async with session.begin():
                    session.add(doc)
                    for chunk in saved_chunks:
                        session.add(chunk)
                    for media in saved_media:
                        session.add(media)
                await session.refresh(doc)

            # 7. Update the KB statistics.
            await self._update_kb_stats(kb_id)

            return doc

        except Exception as e:
            # Best-effort cleanup of everything created so far.
            from astrbot.core import logger

            logger.error(f"文档上传失败,开始清理资源: {e}")

            for vec_id in vec_doc_ids:
                try:
                    await self.vec_db.delete(vec_id)
                except Exception as ve:
                    logger.warning(f"清理向量失败 {vec_id}: {ve}")

            for media_path in media_paths:
                try:
                    if media_path.exists():
                        media_path.unlink()
                except Exception as me:
                    logger.warning(f"清理多媒体文件失败 {media_path}: {me}")

            if file_path and file_path.exists():
                try:
                    file_path.unlink()
                except Exception as fe:
                    logger.warning(f"清理文档文件失败 {file_path}: {fe}")

            # Re-raise the original error.
            raise

    # ===== Statistics =====

    async def _update_kb_stats(self, kb_id: str):
        """Recount documents and chunks of a KB and store both numbers."""
        async with self.db.get_db() as session:
            # All three statements run in one transaction, which
            # session.begin() commits on a clean exit.
            async with session.begin():
                doc_count = (
                    await session.scalar(
                        select(func.count(KBDocument.id)).where(
                            KBDocument.kb_id == kb_id
                        )
                    )
                    or 0
                )

                chunk_count = (
                    await session.scalar(
                        select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)
                    )
                    or 0
                )

                await session.execute(
                    update(KnowledgeBase)
                    .where(KnowledgeBase.kb_id == kb_id)
                    .values(doc_count=doc_count, chunk_count=chunk_count)
                )
-312
View File
@@ -1,312 +0,0 @@
"""知识库管理器辅助操作
该模块提供文档、块和多媒体的管理操作。
"""
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
import aiofiles
from sqlalchemy import delete, func, select
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KBMedia
if TYPE_CHECKING:
from astrbot.core.knowledge_base.manager import KBManager
class KBManagerOps:
"""知识库管理器辅助操作类
职责:
- 文档管理操作
- 块管理操作
- 多媒体管理操作
"""
def __init__(self, manager: "KBManager"):
    """Bind this helper to ``manager`` and reuse the manager's handles."""
    self.manager = manager
    # Database and vector store come straight from the manager ...
    self.db, self.vec_db = manager.db, manager.vec_db
    # ... and so do the two storage directories.
    self.media_path, self.files_path = manager.media_path, manager.files_path
# ===== 文档操作 =====
async def list_documents(
    self, kb_id: str, offset: int = 0, limit: int = 100
) -> "list[KBDocument]":
    """Return one page of a knowledge base's documents, newest first.

    Args:
        kb_id: Id of the knowledge base.
        offset: Number of documents to skip.
        limit: Maximum number of documents to return.
    """
    async with self.db.get_db() as session:
        query = select(KBDocument).where(KBDocument.kb_id == kb_id)
        query = query.offset(offset).limit(limit)
        query = query.order_by(KBDocument.created_at.desc())
        rows = await session.execute(query)
        return list(rows.scalars().all())
async def get_document(self, doc_id: str) -> "KBDocument | None":
    """Fetch the document with ``doc_id``; ``None`` when it does not exist."""
    async with self.db.get_db() as session:
        rows = await session.execute(
            select(KBDocument).where(KBDocument.doc_id == doc_id)
        )
        return rows.scalar_one_or_none()
async def delete_document(self, doc_id: str) -> bool:
    """Delete a document together with its chunks, media and vectors.

    Three phases:
        1. delete the vectors (partial failure tolerated);
        2. delete the SQL records in one transaction;
        3. delete the files on disk (failures are only logged).

    Returns:
        ``False`` when the document does not exist or when more than half
        of the vector deletions failed (nothing else is touched then);
        ``True`` otherwise.
    """
    from astrbot.core import logger

    doc = await self.get_document(doc_id)
    if not doc:
        return False

    # Collect everything that has to go before deleting anything.
    chunks = await self.list_chunks(doc_id)
    media_list = await self.list_media(doc_id)

    # ----- Phase 1: vectors -----
    vec_ids = [chunk.vec_doc_id for chunk in chunks]
    total = len(vec_ids)
    failures = []
    for vec_id in vec_ids:
        try:
            await self.vec_db.delete(vec_id)
        except Exception as e:
            logger.error(f"删除向量失败: {vec_id}, {e}")
            failures.append(vec_id)

    failed = len(failures)
    # Abort when more than 50% of the vector deletions failed.
    if failed > total * 0.5:
        logger.error(f"向量删除失败过多 ({failed}/{total}), 中止文档删除")
        return False
    if failures:
        logger.warning(f"部分向量删除失败 ({failed}/{total}), 但继续执行删除操作")

    # ----- Phase 2: database records, one transaction -----
    async with self.db.get_db() as session:
        async with session.begin():
            # Chunks, then media, then the document row itself.
            for model in (KBChunk, KBMedia, KBDocument):
                await session.execute(delete(model).where(model.doc_id == doc_id))
            await session.commit()

    # ----- Phase 3: files; failures do not affect the result -----
    for media in media_list:
        try:
            target = Path(media.file_path)
            if target.exists():
                target.unlink()
        except Exception as e:
            logger.warning(f"删除多媒体文件失败: {media.file_path}, {e}")

    try:
        target = Path(doc.file_path)
        if target.exists():
            target.unlink()
    except Exception as e:
        logger.warning(f"删除文档文件失败: {doc.file_path}, {e}")

    # ----- Statistics -----
    await self.manager._update_kb_stats(doc.kb_id)
    return True
# ===== 块操作 =====
async def list_chunks(self, doc_id: str) -> list[KBChunk]:
"""列出文档的所有块"""
async with self.db.get_db() as session:
stmt = (
select(KBChunk)
.where(KBChunk.doc_id == doc_id)
.order_by(KBChunk.chunk_index)
)
result = await session.execute(stmt)
return list(result.scalars().all())
async def delete_chunk(self, chunk_id: str) -> bool:
"""删除单个块
流程:
1. 查询块信息
2. 删除向量
3. 删除数据库记录
4. 更新文档统计
"""
from astrbot.core import logger
# 1. 查询块信息
async with self.db.get_db() as session:
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
result = await session.execute(stmt)
chunk = result.scalar_one_or_none()
if not chunk:
return False
doc_id = chunk.doc_id
vec_doc_id = chunk.vec_doc_id
# 2. 删除向量
try:
await self.vec_db.delete(vec_doc_id)
except Exception as e:
logger.error(f"删除向量失败: {vec_doc_id}, {e}")
return False
# 3. 删除数据库记录
async with self.db.get_db() as session:
async with session.begin():
await session.execute(
delete(KBChunk).where(KBChunk.chunk_id == chunk_id)
)
await session.commit()
# 4. 更新文档统计
await self._update_doc_stats(doc_id)
return True
# ===== 多媒体操作 =====
async def list_media(self, doc_id: str) -> list[KBMedia]:
"""列出文档的所有多媒体资源"""
async with self.db.get_db() as session:
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
result = await session.execute(stmt)
return list(result.scalars().all())
async def delete_media(self, media_id: str) -> bool:
"""删除多媒体资源
流程:
1. 查询媒体信息
2. 删除数据库记录
3. 删除文件(失败不影响)
4. 更新文档统计
"""
from astrbot.core import logger
# 1. 查询媒体信息
async with self.db.get_db() as session:
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
result = await session.execute(stmt)
media = result.scalar_one_or_none()
if not media:
return False
doc_id = media.doc_id
file_path_str = media.file_path
# 2. 删除数据库记录
async with self.db.get_db() as session:
async with session.begin():
await session.execute(
delete(KBMedia).where(KBMedia.media_id == media_id)
)
await session.commit()
# 3. 删除文件(失败不影响)
try:
media_path = Path(file_path_str)
if media_path.exists():
media_path.unlink()
except Exception as e:
logger.warning(f"删除多媒体文件失败: {file_path_str}, {e}")
# 4. 更新文档统计
await self._update_doc_stats(doc_id)
return True
# ===== 内部辅助方法 =====
async def _save_media(
self,
kb_id: str,
doc_id: str,
media_type: str,
file_name: str,
content: bytes,
mime_type: str,
) -> KBMedia:
"""保存多媒体资源"""
media_id = str(uuid.uuid4())
ext = Path(file_name).suffix
# 保存文件
file_path = self.media_path / kb_id / doc_id / f"{media_id}{ext}"
file_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.open(file_path, "wb") as f:
await f.write(content)
# 创建记录
media = KBMedia(
media_id=media_id,
doc_id=doc_id,
kb_id=kb_id,
media_type=media_type,
file_name=file_name,
file_path=str(file_path),
file_size=len(content),
mime_type=mime_type,
)
return media
async def _update_doc_stats(self, doc_id: str):
"""更新文档统计信息(事务中执行)"""
async with self.db.get_db() as session:
async with session.begin():
# 统计块数
chunk_count = (
await session.scalar(
select(func.count(KBChunk.id)).where(KBChunk.doc_id == doc_id)
)
) or 0
# 统计多媒体数
media_count = (
await session.scalar(
select(func.count(KBMedia.id)).where(KBMedia.doc_id == doc_id)
)
) or 0
# 更新文档
doc = await session.scalar(
select(KBDocument).where(KBDocument.doc_id == doc_id)
)
if doc:
doc.chunk_count = chunk_count
doc.media_count = media_count
await session.commit()
+27 -95
View File
@@ -1,31 +1,20 @@
"""知识库管理功能的数据模型定义
该模块定义了知识库系统所需的数据模型,包括:
- KnowledgeBase: 知识库表 (存储在独立的 kb.db)
- KBDocument: 文档表 (存储在独立的 kb.db)
- KBChunk: 文档块表 (存储在独立的 kb.db)
- KBMedia: 多媒体资源表 (存储在独立的 kb.db)
- KBSessionConfig: 会话配置表 (存储在独立的 kb.db)
注意:
- 所有模型存储在独立的知识库数据库 (data/knowledge_base/kb.db)
- 与主数据库 (astrbot.db) 完全解耦
"""
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData
class KnowledgeBase(SQLModel, table=True):
class BaseKBModel(SQLModel, table=False):
metadata = MetaData()
class KnowledgeBase(BaseKBModel, table=True):
"""知识库表
存储知识库的基本信息和统计数据。
"""
__tablename__ = "knowledge_bases"
__tablename__ = "knowledge_bases" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
@@ -38,18 +27,17 @@ class KnowledgeBase(SQLModel, table=True):
index=True,
)
kb_name: str = Field(max_length=100, nullable=False)
description: Optional[str] = Field(default=None, sa_type=Text)
emoji: Optional[str] = Field(default="📚", max_length=10)
embedding_provider_id: Optional[str] = Field(default=None, max_length=100)
rerank_provider_id: Optional[str] = Field(default=None, max_length=100)
description: str | None = Field(default=None, sa_type=Text)
emoji: str | None = Field(default="📚", max_length=10)
embedding_provider_id: str | None = Field(default=None, max_length=100)
rerank_provider_id: str | None = Field(default=None, max_length=100)
# 分块配置参数
chunk_size: Optional[int] = Field(default=512, nullable=True)
chunk_overlap: Optional[int] = Field(default=50, nullable=True)
chunk_size: int | None = Field(default=512, nullable=True)
chunk_overlap: int | None = Field(default=50, nullable=True)
# 检索配置参数
top_k_dense: Optional[int] = Field(default=50, nullable=True)
top_k_sparse: Optional[int] = Field(default=50, nullable=True)
top_m_final: Optional[int] = Field(default=5, nullable=True)
enable_rerank: Optional[bool] = Field(default=False, nullable=True)
top_k_dense: int | None = Field(default=50, nullable=True)
top_k_sparse: int | None = Field(default=50, nullable=True)
top_m_final: int | None = Field(default=5, nullable=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
@@ -58,14 +46,21 @@ class KnowledgeBase(SQLModel, table=True):
doc_count: int = Field(default=0, nullable=False)
chunk_count: int = Field(default=0, nullable=False)
__table_args__ = (
UniqueConstraint(
"kb_name",
name="uix_kb_name",
),
)
class KBDocument(SQLModel, table=True):
class KBDocument(BaseKBModel, table=True):
"""文档表
存储上传到知识库的文档元数据。
"""
__tablename__ = "kb_documents"
__tablename__ = "kb_documents" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
@@ -91,40 +86,13 @@ class KBDocument(SQLModel, table=True):
)
class KBChunk(SQLModel, table=True):
"""文档块表
存储文档分块后的文本内容和向量索引关联信息。
"""
__tablename__ = "kb_chunks"
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
chunk_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
index=True,
)
doc_id: str = Field(max_length=36, nullable=False, index=True)
kb_id: str = Field(max_length=36, nullable=False, index=True)
chunk_index: int = Field(nullable=False)
content: str = Field(sa_type=Text, nullable=False)
char_count: int = Field(nullable=False)
vec_doc_id: str = Field(max_length=100, nullable=False, index=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class KBMedia(SQLModel, table=True):
class KBMedia(BaseKBModel, table=True):
"""多媒体资源表
存储从文档中提取的图片、视频等多媒体资源。
"""
__tablename__ = "kb_media"
__tablename__ = "kb_media" # type: ignore
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
@@ -144,39 +112,3 @@ class KBMedia(SQLModel, table=True):
file_size: int = Field(nullable=False)
mime_type: str = Field(max_length=100, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class KBSessionConfig(SQLModel, table=True):
"""会话知识库配置表
存储会话或平台级别的知识库关联配置。
该表存储在知识库独立数据库中,保持完全解耦。
支持两种配置范围:
- platform: 平台级别配置 (如 'qq', 'telegram')
- session: 会话级别配置 (如 'qq:group:12345')
"""
__tablename__ = "kb_session_config"
id: int | None = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
)
config_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4()),
)
scope: str = Field(max_length=20, nullable=False)
scope_id: str = Field(max_length=255, nullable=False, index=True)
kb_ids: str = Field(sa_type=Text, nullable=False)
top_k: Optional[int] = Field(default=None, nullable=True)
enable_rerank: Optional[bool] = Field(default=None, nullable=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),)
@@ -51,10 +51,10 @@ class PDFParser(BaseParser):
continue
resources = page["/Resources"]
if not resources or "/XObject" not in resources:
if not resources or "/XObject" not in resources: # type: ignore
continue
xobjects = resources["/XObject"].get_object()
xobjects = resources["/XObject"].get_object() # type: ignore
if not xobjects:
continue
@@ -0,0 +1,767 @@
———
》),
)÷(1-
”,
)、
=(
:
&
*
一一
~~~~
.
.一
./
--
=″
[*]
}>
[⑤]]
[①D]
c]
ng昉
//
[②e]
[②g]
={
}
,也
[①⑥]
[②B]
[①a]
[④a]
[①③]
[③h]
③]
1.
--
[②b]
’‘
×××
[①⑧]
0:2
=[
[⑤b]
[②c]
[④b]
[②③]
[③a]
[④c]
[①⑤]
[①⑦]
[①g]
∈[
[①⑨]
[①④]
[①c]
[②f]
[②⑧]
[②①]
[①C]
[③c]
[③g]
[②⑤]
[②②]
一.
[①h]
.数
[]
[①B]
数/
[①i]
[③e]
[①①]
[④d]
[④e]
[③b]
[⑤a]
[①A]
[②⑧]
[②⑦]
[①d]
[②j]
〕〔
][
://
′∈
[②④
[⑤e]
12%
b]
...
...................
…………………………………………………③
ZXFITL
[③F]
[①o]
]∧′=[
∪φ∈
′|
{-
②c
[③①]
R.L.
[①E]
Ψ
-[*]-
.日
[②d]
[②
[②⑦]
[②②]
[③e]
[①i]
[①B]
[①h]
[①d]
[①g]
[①②]
[②a]
f]
[⑩]
a]
[①e]
[②h]
[②⑥]
[③d]
[②⑩]
e]
元/吨
[②⑩]
2.3%
5:0
[①]
::
[②]
[③]
[④]
[⑤]
[⑥]
[⑦]
[⑧]
[⑨]
……
——
?
,
'
?
·
———
──
?
<
>
[
]
(
)
-
+
×
/
В
"
;
#
@
γ
μ
φ
φ.
×
Δ
sub
exp
sup
sub
Lex
+ξ
++
-β
<±
<Δ
<λ
<φ
<<
=
=☆
=-
>λ
_
~±
~+
[⑤f]
[⑤d]
[②i]
[②G]
[①f]
LI
[-
......
[③⑩]
第二
一番
一直
一个
一些
许多
有的是
也就是说
末##末
哎呀
哎哟
俺们
按照
吧哒
罢了
本着
比方
比如
鄙人
彼此
别的
别说
并且
不比
不成
不单
不但
不独
不管
不光
不过
不仅
不拘
不论
不怕
不然
不如
不特
不惟
不问
不只
朝着
趁着
除此之外
除非
除了
此间
此外
从而
但是
当着
的话
等等
叮咚
对于
多少
而况
而且
而是
而外
而言
而已
尔后
反过来
反过来说
反之
非但
非徒
否则
嘎登
各个
各位
各种
各自
根据
故此
固然
关于
果然
果真
哈哈
何处
何况
何时
哼唷
呼哧
还是
还有
换句话说
换言之
或是
或者
极了
及其
及至
即便
即或
即令
即若
即使
几时
既然
既是
继而
加之
假如
假若
假使
鉴于
较之
接着
结果
紧接着
进而
尽管
经过
就是
就是说
具体地说
具体说来
开始
开外
可见
可是
可以
况且
来着
例如
连同
两者
另外
另一方面
慢说
漫说
每当
莫若
某个
某些
哪边
哪儿
哪个
哪里
哪年
哪怕
哪天
哪些
哪样
那边
那儿
那个
那会儿
那里
那么
那么些
那么样
那时
那些
那样
乃至
你们
宁可
宁肯
宁愿
啪达
旁人
凭借
其次
其二
其他
其它
其一
其余
其中
起见
起见
岂但
恰恰相反
前后
前者
然而
然后
然则
人家
任何
任凭
如此
如果
如何
如其
如若
如上所述
若非
若是
上下
尚且
设若
设使
甚而
甚么
甚至
省得
时候
什么
什么样
使得
是的
首先
谁知
顺着
似的
虽然
虽说
虽则
随着
所以
他们
他人
它们
她们
倘或
倘然
倘若
倘使
通过
同时
万一
为何
为了
为什么
为着
嗡嗡
我们
呜呼
乌乎
无论
无宁
毋宁
相对而言
向着
沿
沿着
要不
要不然
要不是
要么
要是
也罢
也好
一般
一旦
一方面
一来
一切
一样
一则
依照
以便
以及
以免
以至
以至于
以致
抑或
因此
因而
因为
由此可见
由于
有的
有关
有些
于是
于是乎
与此同时
与否
与其
越是
云云
再说
再者
在下
咱们
怎么
怎么办
怎么样
怎样
照着
这边
这儿
这个
这会儿
这就是说
这里
这么
这么点儿
这么些
这么样
这时
这些
这样
正如
之类
之所以
之一
只是
只限
只要
只有
至于
诸位
着呢
自从
自个儿
自各儿
自己
自家
自身
综上所述
总的来看
总的来说
总的说来
总而言之
总之
纵令
纵然
纵使
遵照
作为
喔唷
@@ -3,15 +3,19 @@
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
"""
import json
from dataclasses import dataclass
from typing import List, Optional
import time
from astrbot.core.db.vec_db.base import BaseVecDB
from astrbot.core.knowledge_base.database import KBDatabase
from dataclasses import dataclass
from typing import List
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
from astrbot.core.provider.provider import RerankProvider
from astrbot.core.db.vec_db.base import Result
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
from ..kb_helper import KBHelper
from astrbot import logger
@dataclass
@@ -38,36 +42,29 @@ class RetrievalManager:
def __init__(
self,
vec_db: BaseVecDB,
sparse_retriever: SparseRetriever,
rank_fusion: RankFusion,
kb_db: KBDatabase,
rerank_provider: Optional[RerankProvider] = None,
kb_db: KBSQLiteDatabase,
):
"""初始化检索管理器
Args:
vec_db: 向量数据库实例
vec_db_factory: 向量数据库工厂
sparse_retriever: 稀疏检索器
rank_fusion: 结果融合器
kb_db: 知识库数据库实例
rerank_provider: Rerank 提供商 (可选)
"""
self.vec_db = vec_db
self.sparse_retriever = sparse_retriever
self.rank_fusion = rank_fusion
self.kb_db = kb_db
self.rerank_provider = rerank_provider
async def retrieve(
self,
query: str,
kb_ids: List[str],
top_k_dense: int = 50,
top_k_sparse: int = 50,
top_n_fusion: int = 20,
kb_id_helper_map: dict[str, KBHelper],
top_k_fusion: int = 20,
top_m_final: int = 5,
enable_rerank: bool = True,
) -> List[RetrievalResult]:
"""混合检索
@@ -80,40 +77,74 @@ class RetrievalManager:
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_k_dense: 稠密检索返回数量
top_k_sparse: 稀疏检索返回数量
top_n_fusion: 融合后返回数量
top_m_final: 最终返回数量
enable_rerank: 是否启用 Rerank
Returns:
List[RetrievalResult]: 检索结果列表
"""
if not kb_ids:
return []
kb_options: dict = {}
new_kb_ids = []
for kb_id in kb_ids:
kb_helper = kb_id_helper_map.get(kb_id)
if kb_helper:
kb = kb_helper.kb
kb_options[kb_id] = {
"top_k_dense": kb.top_k_dense or 50,
"top_k_sparse": kb.top_k_sparse or 50,
"top_m_final": kb.top_m_final or 5,
"vec_db": kb_helper.vec_db,
"rerank_provider_id": kb.rerank_provider_id,
}
new_kb_ids.append(kb_id)
else:
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
kb_ids = new_kb_ids
# 1. 稠密检索
time_start = time.time()
dense_results = await self._dense_retrieve(
query=query,
kb_ids=kb_ids,
top_k=top_k_dense,
kb_options=kb_options,
)
time_end = time.time()
logger.debug(
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results."
)
# 2. 稀疏检索
time_start = time.time()
sparse_results = await self.sparse_retriever.retrieve(
query=query,
kb_ids=kb_ids,
top_k=top_k_sparse,
kb_options=kb_options,
)
time_end = time.time()
logger.debug(
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results."
)
# 3. 结果融合
time_start = time.time()
fused_results = await self.rank_fusion.fuse(
dense_results=dense_results,
sparse_results=sparse_results,
top_k=top_n_fusion,
top_k=top_k_fusion,
)
time_end = time.time()
logger.debug(
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results."
)
# 4. 转换为 RetrievalResult (获取元数据)
retrieval_results = []
for fr in fused_results:
metadata_dict = await self.kb_db.get_chunk_with_metadata(fr.chunk_id)
metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
if metadata_dict:
retrieval_results.append(
RetrievalResult(
@@ -125,33 +156,45 @@ class RetrievalManager:
content=fr.content,
score=fr.score,
metadata={
"chunk_index": metadata_dict["chunk"].chunk_index,
"char_count": metadata_dict["chunk"].char_count,
"chunk_index": fr.chunk_index,
"char_count": len(fr.content),
},
)
)
# 5. Rerank (可选)
if enable_rerank and self.rerank_provider and retrieval_results:
# 5. Rerank
first_rerank = None
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
and vec_db.rerank_provider
and rerank_pi
and rerank_pi == vec_db.rerank_provider.meta().id
):
first_rerank = vec_db.rerank_provider
break
if first_rerank and retrieval_results:
retrieval_results = await self._rerank(
query=query,
results=retrieval_results,
top_k=top_m_final,
rerank_provider=self.rerank_provider,
rerank_provider=first_rerank,
)
else:
retrieval_results = retrieval_results[:top_m_final]
return retrieval_results
return retrieval_results[:top_m_final]
async def _dense_retrieve(
self,
query: str,
kb_ids: List[str],
top_k: int,
kb_options: dict,
):
"""稠密检索 (向量相似度)
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
@@ -160,28 +203,32 @@ class RetrievalManager:
Returns:
List[Result]: 检索结果列表
"""
# 直接调用向量数据库检索
vec_results = await self.vec_db.retrieve(
query=query,
top_k=top_k * len(kb_ids) * 2, # 增加候选数量以便过滤
)
# 过滤:只保留指定知识库的结果
filtered_results = []
for result in vec_results:
metadata_str = result.data.get("metadata", "{}")
all_results: list[Result] = []
for kb_id in kb_ids:
if kb_id not in kb_options:
continue
try:
metadata = json.loads(metadata_str)
except (json.JSONDecodeError, TypeError):
metadata = {}
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
dense_k = int(kb_options[kb_id]["top_k_dense"])
vec_results = await vec_db.retrieve(
query=query,
k=dense_k,
fetch_k=dense_k * 2,
rerank=False, # 稠密检索阶段不进行 rerank
metadata_filters={"kb_id": kb_id},
)
if metadata.get("kb_id") in kb_ids:
filtered_results.append(result)
all_results.extend(vec_results)
except Exception as e:
from astrbot.core import logger
if len(filtered_results) >= top_k:
break
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
continue
return filtered_results[:top_k]
# 按相似度排序并返回 top_k
all_results.sort(key=lambda x: x.similarity, reverse=True)
# return all_results[: len(all_results) // len(kb_ids)]
return all_results
async def _rerank(
self,
@@ -3,11 +3,11 @@
使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
"""
import json
from dataclasses import dataclass
from typing import Dict, List
from astrbot.core.db.vec_db.base import Result
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
@@ -16,6 +16,7 @@ class FusedResult:
"""融合后的检索结果"""
chunk_id: str
chunk_index: int
doc_id: str
kb_id: str
content: str
@@ -30,7 +31,7 @@ class RankFusion:
- 使用 Reciprocal Rank Fusion (RRF) 算法
"""
def __init__(self, kb_db: KBDatabase, k: int = 60):
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
"""初始化结果融合器
Args:
@@ -42,10 +43,10 @@ class RankFusion:
async def fuse(
self,
dense_results: List[Result],
sparse_results: List[SparseResult],
dense_results: list[Result],
sparse_results: list[SparseResult],
top_k: int = 20,
) -> List[FusedResult]:
) -> list[FusedResult]:
"""融合稠密和稀疏检索结果
RRF 公式:
@@ -62,14 +63,14 @@ class RankFusion:
# 1. 构建排名映射
dense_ranks = {
r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)
}
} # 这里的 doc_id 实际上是 chunk_id
sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
# 2. 收集所有唯一的 ID (来自稠密检索的是 vec_doc_id, 稀疏检索的是 chunk_id)
# 2. 收集所有唯一的 ID
# 需要统一为 chunk_id
all_chunk_ids = set()
vec_doc_id_to_dense = {} # vec_doc_id -> Result
chunk_id_to_sparse = {} # chunk_id -> SparseResult
vec_doc_id_to_dense: dict[str, Result] = {} # vec_doc_id -> Result
chunk_id_to_sparse: dict[str, SparseResult] = {} # chunk_id -> SparseResult
# 处理稀疏检索结果
for r in sparse_results:
@@ -83,7 +84,7 @@ class RankFusion:
vec_doc_id_to_dense[vec_doc_id] = r
# 3. 计算 RRF 分数
rrf_scores: Dict[str, float] = {}
rrf_scores: dict[str, float] = {}
for identifier in all_chunk_ids:
score = 0.0
@@ -112,6 +113,7 @@ class RankFusion:
fused_results.append(
FusedResult(
chunk_id=sr.chunk_id,
chunk_index=sr.chunk_index,
doc_id=sr.doc_id,
kb_id=sr.kb_id,
content=sr.content,
@@ -120,16 +122,17 @@ class RankFusion:
)
elif identifier in vec_doc_id_to_dense:
# 从向量检索获取信息,需要从数据库获取块的详细信息
chunk = await self.kb_db.get_chunk_by_vec_doc_id(identifier)
if chunk:
fused_results.append(
FusedResult(
chunk_id=chunk.chunk_id,
doc_id=chunk.doc_id,
kb_id=chunk.kb_id,
content=chunk.content,
score=rrf_scores[identifier],
)
vec_result = vec_doc_id_to_dense[identifier]
chunk_md = json.loads(vec_result.data["metadata"])
fused_results.append(
FusedResult(
chunk_id=identifier,
chunk_index=chunk_md["chunk_index"],
doc_id=chunk_md["kb_doc_id"],
kb_id=chunk_md["kb_id"],
content=vec_result.data["text"],
score=rrf_scores[identifier],
)
)
return fused_results
@@ -3,18 +3,20 @@
使用 BM25 算法进行基于关键词的文档检索
"""
import jieba
import os
import json
from dataclasses import dataclass
from typing import List
from rank_bm25 import BM25Okapi
from astrbot.core.knowledge_base.database import KBDatabase
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
@dataclass
class SparseResult:
"""稀疏检索结果"""
chunk_index: int
chunk_id: str
doc_id: str
kb_id: str
@@ -30,7 +32,7 @@ class SparseRetriever:
- 使用 BM25 算法计算相关度
"""
def __init__(self, kb_db: KBDatabase):
def __init__(self, kb_db: KBSQLiteDatabase):
"""初始化稀疏检索器
Args:
@@ -39,37 +41,73 @@ class SparseRetriever:
self.kb_db = kb_db
self._index_cache = {} # 缓存 BM25 索引
with open(
os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
encoding="utf-8",
) as f:
self.hit_stopwords = {
word.strip() for word in set(f.read().splitlines()) if word.strip()
}
async def retrieve(
self,
query: str,
kb_ids: List[str],
top_k: int = 50,
) -> List[SparseResult]:
kb_ids: list[str],
kb_options: dict,
) -> list[SparseResult]:
"""执行稀疏检索
Args:
query: 查询文本
kb_ids: 知识库 ID 列表
top_k: 返回结果数量
kb_options: 每个知识库的检索选项
Returns:
List[SparseResult]: 检索结果列表
"""
# 1. 获取所有相关块
chunks = await self.kb_db.get_chunks_by_kb_ids(kb_ids)
top_k_sparse = 0
chunks = []
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
if not vec_db:
continue
result = await vec_db.document_storage.get_documents(
metadata_filters={}, limit=None, offset=None
)
chunk_mds = [json.loads(doc["metadata"]) for doc in result]
result = [
{
"chunk_id": doc["doc_id"],
"chunk_index": chunk_md["chunk_index"],
"doc_id": chunk_md["kb_doc_id"],
"kb_id": kb_id,
"text": doc["text"],
}
for doc, chunk_md in zip(result, chunk_mds)
]
chunks.extend(result)
top_k_sparse += kb_options.get(kb_id, {}).get("top_k_sparse", 50)
if not chunks:
return []
# 2. 准备文档和索引
corpus = [chunk.content for chunk in chunks]
tokenized_corpus = [doc.split() for doc in corpus]
corpus = [chunk["text"] for chunk in chunks]
tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
tokenized_corpus = [
[word for word in doc if word not in self.hit_stopwords]
for doc in tokenized_corpus
]
# 3. 构建 BM25 索引
bm25 = BM25Okapi(tokenized_corpus)
# 4. 执行检索
tokenized_query = query.split()
tokenized_query = list(jieba.cut(query))
tokenized_query = [
word for word in tokenized_query if word not in self.hit_stopwords
]
scores = bm25.get_scores(tokenized_query)
# 5. 排序并返回 Top-K
@@ -78,13 +116,15 @@ class SparseRetriever:
chunk = chunks[idx]
results.append(
SparseResult(
chunk_id=chunk.chunk_id,
doc_id=chunk.doc_id,
kb_id=chunk.kb_id,
content=chunk.content,
chunk_id=chunk["chunk_id"],
chunk_index=chunk["chunk_index"],
doc_id=chunk["doc_id"],
kb_id=chunk["kb_id"],
content=chunk["text"],
score=float(score),
)
)
results.sort(key=lambda x: x.score, reverse=True)
return results[:top_k]
# return results[: len(results) // len(kb_ids)]
return results[:top_k_sparse]
@@ -1,157 +0,0 @@
"""会话知识库配置数据库操作
该模块封装会话知识库配置的数据库查询操作
注意: 会话配置表 (kb_session_config) 存储在知识库独立数据库 (kb.db) ,
而不是主数据库 (astrbot.db) ,以实现完全解耦
"""
import json
from typing import Optional
from sqlalchemy import select
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
from astrbot.core.knowledge_base.models import KBSessionConfig
class SessionConfigDB:
    """Database access for session-level knowledge-base bindings.

    Responsibilities:
    - manage the knowledge-base configuration of sessions and platforms
    - keep the query logic in one place

    Note: this class works on the dedicated knowledge-base database, which
    keeps it decoupled from the main database.
    """

    def __init__(self, db: KBSQLiteDatabase):
        """Store the database handle.

        Args:
            db: The dedicated knowledge-base database (kb.db), not the main
                database.
        """
        self.db = db

    @staticmethod
    async def _find_config(session, scope: str, scope_id: str):
        """Return the config row matching ``scope``/``scope_id`` or ``None``."""
        query = select(KBSessionConfig).where(
            KBSessionConfig.scope == scope,
            KBSessionConfig.scope_id == scope_id,
        )
        rows = await session.execute(query)
        return rows.scalar_one_or_none()

    async def get_session_kb_ids(self, session_id: str) -> list[str]:
        """Return the knowledge-base IDs bound to a session.

        Lookup order:
        1. session-level configuration (preferred)
        2. platform-level configuration
        3. an empty list when neither exists
        """
        # The platform ID is the first ":"-separated segment of the session
        # ID; it is only tried when the ID has at least two segments.
        candidates = [("session", session_id)]
        segments = session_id.split(":")
        if len(segments) >= 2:
            candidates.append(("platform", segments[0]))

        async with self.db.get_db() as session:
            for scope, scope_id in candidates:
                found = await self._find_config(session, scope, scope_id)
                if found:
                    return json.loads(found.kb_ids)
            return []

    async def set_session_kb_ids(
        self,
        scope: str,
        scope_id: str,
        kb_ids: list[str],
        top_k: Optional[int] = None,
        enable_rerank: Optional[bool] = None,
    ) -> KBSessionConfig:
        """Create or update a knowledge-base configuration.

        Args:
            scope: Configuration scope (session/platform).
            scope_id: Scope identifier (session ID or platform ID).
            kb_ids: Knowledge-base IDs to bind.
            top_k: Number of results to return (optional).
            enable_rerank: Whether rerank is enabled (optional).
        """
        serialized_ids = json.dumps(kb_ids)
        async with self.db.get_db() as session:
            config = await self._find_config(session, scope, scope_id)
            if not config:
                # No row yet: create one with every value as given.
                config = KBSessionConfig(
                    scope=scope,
                    scope_id=scope_id,
                    kb_ids=serialized_ids,
                    top_k=top_k,
                    enable_rerank=enable_rerank,
                )
                session.add(config)
            else:
                # Existing row: always replace the IDs, but only overwrite
                # the optional settings that were actually supplied.
                config.kb_ids = serialized_ids
                if top_k is not None:
                    config.top_k = top_k
                if enable_rerank is not None:
                    config.enable_rerank = enable_rerank

            await session.commit()
            await session.refresh(config)
            return config

    async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
        """Delete a configuration; return ``False`` when it does not exist."""
        async with self.db.get_db() as session:
            target = await self._find_config(session, scope, scope_id)
            if not target:
                return False

            await session.delete(target)
            await session.commit()
            return True

    async def list_all_session_configs(
        self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
    ) -> list[KBSessionConfig]:
        """Return one page of configurations, newest first.

        When ``scope`` is given (and non-empty) only rows of that scope are
        returned.
        """
        async with self.db.get_db() as session:
            query = select(KBSessionConfig)
            if scope:
                query = query.where(KBSessionConfig.scope == scope)
            query = (
                query.order_by(KBSessionConfig.created_at.desc())
                .offset(offset)
                .limit(limit)
            )
            rows = await session.execute(query)
            return list(rows.scalars().all())
@@ -7,7 +7,7 @@ import copy
import json
import traceback
from datetime import timedelta
from typing import AsyncGenerator, Union
from collections.abc import AsyncGenerator
from astrbot.core.conversation_mgr import Conversation
from astrbot.core import logger
from astrbot.core.message.components import Image
@@ -45,7 +45,7 @@ except (ModuleNotFoundError, ImportError):
AgentContextWrapper = ContextWrapper[AstrAgentContext]
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
@@ -103,7 +103,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
request = ProviderRequest(
prompt=input_,
system_prompt=tool.description,
system_prompt=tool.description or "",
image_urls=[], # 暂时不传递原始 agent 的上下文
contexts=[], # 暂时不传递原始 agent 的上下文
func_tool=toolset,
@@ -240,7 +240,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
yield res
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
await call_event_hook(
@@ -338,7 +338,7 @@ class LLMRequestSubStage(Stage):
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
def _select_provider(self, event: AstrMessageEvent):
"""选择使用的 LLM 提供商"""
sel_provider = event.get_extra("selected_provider")
_ctx = self.ctx.plugin_manager.context
@@ -368,7 +368,7 @@ class LLMRequestSubStage(Stage):
async def process(
self, event: AstrMessageEvent, _nested: bool = False
) -> Union[None, AsyncGenerator[None, None]]:
) -> None | AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
@@ -383,6 +383,9 @@ class LLMRequestSubStage(Stage):
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
@@ -489,6 +492,9 @@ class LLMRequestSubStage(Stage):
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
logger.debug(
@@ -526,8 +532,10 @@ class LLMRequestSubStage(Stage):
chain = (
MessageChain().message(final_llm_resp.completion_text).chain
)
else:
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
@@ -538,6 +546,9 @@ class LLMRequestSubStage(Stage):
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
@@ -556,6 +567,8 @@ class LLMRequestSubStage(Stage):
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
if not req.conversation:
return
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, req.conversation.cid
)
+11 -6
View File
@@ -7,7 +7,6 @@ async def inject_kb_context(
umo: str,
p_ctx: PipelineContext,
req: ProviderRequest,
top_k: int = 5,
) -> None:
"""inject knowledge base context into the provider request
@@ -15,13 +14,19 @@ async def inject_kb_context(
p_ctx: Pipeline context
req: Provider request
"""
kb_injector = p_ctx.plugin_manager.context.kb_manager.get_kb_injector()
if not kb_injector:
kb_mgr = p_ctx.plugin_manager.context.kb_manager
kb_names = p_ctx.astrbot_config.get("kb_names", [])
top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20)
top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5)
if not kb_names:
return
kb_context = await kb_injector.retrieve_and_inject(
unified_msg_origin=umo,
kb_context = await kb_mgr.retrieve(
query=req.prompt,
top_k=top_k,
kb_names=kb_names,
top_k_fusion=top_k_fusion,
top_m_final=top_k,
)
if not kb_context:
return
+8 -7
View File
@@ -4,7 +4,7 @@ import re
import hashlib
import uuid
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
from typing import List, Union, Optional, AsyncGenerator, Any
from astrbot import logger
from astrbot.core.db.po import Conversation
@@ -26,8 +26,6 @@ from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
from .message_session import MessageSession, MessageSesion # noqa
_VT = TypeVar("_VT")
class AstrMessageEvent(abc.ABC):
def __init__(
@@ -92,8 +90,10 @@ class AstrMessageEvent(abc.ABC):
"""
return self.message_str
def _outline_chain(self, chain: List[BaseMessageComponent]) -> str:
def _outline_chain(self, chain: Optional[List[BaseMessageComponent]]) -> str:
outline = ""
if not chain:
return outline
for i in chain:
if isinstance(i, Plain):
outline += i.text
@@ -175,9 +175,7 @@ class AstrMessageEvent(abc.ABC):
"""
self._extras[key] = value
def get_extra(
self, key: str | None = None, default: _VT = None
) -> dict[str, Any] | _VT:
def get_extra(self, key: str | None = None, default=None) -> Any:
"""
获取额外的信息
"""
@@ -265,6 +263,9 @@ class AstrMessageEvent(abc.ABC):
"""
if isinstance(result, str):
result = MessageEventResult().message(result)
# 兼容外部插件或调用方传入的 chain=None 的情况,确保为可迭代列表
if isinstance(result, MessageEventResult) and result.chain is None:
result.chain = []
self._result = result
def stop_event(self):
@@ -1,6 +1,6 @@
import asyncio
import json
from typing import Dict, Any, Optional, Awaitable
import random
from typing import Dict, Any, Optional, Awaitable, List
from astrbot.api import logger
from astrbot.api.event import MessageChain
@@ -14,6 +14,13 @@ from astrbot.core.platform.astr_message_event import MessageSession
import astrbot.api.message_components as Comp
from .misskey_api import MisskeyAPI
import os
try:
import magic # type: ignore
except Exception:
magic = None
from .misskey_event import MisskeyPlatformEvent
from .misskey_utils import (
serialize_message_chain,
@@ -25,9 +32,15 @@ from .misskey_utils import (
extract_sender_info,
create_base_message,
process_at_mention,
format_poll,
cache_user_info,
cache_room_info,
)
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
# Constants
MAX_FILE_UPLOAD_COUNT = 16
DEFAULT_UPLOAD_CONCURRENCY = 3
@register_platform_adapter("misskey", "Misskey 平台适配器")
@@ -46,6 +59,31 @@ class MisskeyPlatformAdapter(Platform):
)
self.local_only = self.config.get("misskey_local_only", False)
self.enable_chat = self.config.get("misskey_enable_chat", True)
self.enable_file_upload = self.config.get("misskey_enable_file_upload", True)
self.upload_folder = self.config.get("misskey_upload_folder")
# download / security related options (exposed to platform_config)
self.allow_insecure_downloads = bool(
self.config.get("misskey_allow_insecure_downloads", False)
)
# parse download timeout and chunk size safely
_dt = self.config.get("misskey_download_timeout")
try:
self.download_timeout = int(_dt) if _dt is not None else 15
except Exception:
self.download_timeout = 15
_chunk = self.config.get("misskey_download_chunk_size")
try:
self.download_chunk_size = int(_chunk) if _chunk is not None else 64 * 1024
except Exception:
self.download_chunk_size = 64 * 1024
# parse max download bytes safely
_md_bytes = self.config.get("misskey_max_download_bytes")
try:
self.max_download_bytes = int(_md_bytes) if _md_bytes is not None else None
except Exception:
self.max_download_bytes = None
self.unique_session = platform_settings["unique_session"]
@@ -63,6 +101,11 @@ class MisskeyPlatformAdapter(Platform):
"misskey_default_visibility": "public",
"misskey_local_only": False,
"misskey_enable_chat": True,
# download / security options
"misskey_allow_insecure_downloads": False,
"misskey_download_timeout": 15,
"misskey_download_chunk_size": 65536,
"misskey_max_download_bytes": None,
}
default_config.update(self.config)
@@ -78,7 +121,14 @@ class MisskeyPlatformAdapter(Platform):
logger.error("[Misskey] 配置不完整,无法启动")
return
self.api = MisskeyAPI(self.instance_url, self.access_token)
self.api = MisskeyAPI(
self.instance_url,
self.access_token,
allow_insecure_downloads=self.allow_insecure_downloads,
download_timeout=self.download_timeout,
chunk_size=self.download_chunk_size,
max_download_bytes=self.max_download_bytes,
)
self._running = True
try:
@@ -95,6 +145,80 @@ class MisskeyPlatformAdapter(Platform):
await self._start_websocket_connection()
def _register_event_handlers(self, streaming):
    """Attach this adapter's callbacks to a streaming client.

    Notification events are always routed to ``_handle_notification``.
    Chat events are routed to ``_handle_chat_message`` only when
    ``self.enable_chat`` is truthy. The ``_debug`` event is routed to
    ``_debug_handler``.

    Args:
        streaming: Object exposing ``add_message_handler(event, callback)``.
    """
    routes = [
        ("notification", self._handle_notification),
        ("main:notification", self._handle_notification),
    ]
    if self.enable_chat:
        routes.append(("newChatMessage", self._handle_chat_message))
        routes.append(("messaging:newChatMessage", self._handle_chat_message))
    routes.append(("_debug", self._debug_handler))

    # Registration order matches the order the events are listed above.
    for event_name, callback in routes:
        streaming.add_message_handler(event_name, callback)
async def _send_text_only_message(
    self, session_id: str, text: str, session, message_chain
):
    """Send a plain-text chat message (no file upload).

    When an API client is available, the text is sent to the user or to
    the room identified by ``session_id``; any other session id sends
    nothing. In every case the call finishes by delegating to the parent
    class's ``send_by_session`` and returning its result.
    """
    if self.api and session_id:
        if is_valid_user_session_id(session_id):
            from .misskey_utils import extract_user_id_from_session_id

            target_user = extract_user_id_from_session_id(session_id)
            user_payload: Dict[str, Any] = {"toUserId": target_user, "text": text}
            await self.api.send_message(user_payload)
        elif is_valid_room_session_id(session_id):
            from .misskey_utils import extract_room_id_from_session_id

            target_room = extract_room_id_from_session_id(session_id)
            room_payload: Dict[str, Any] = {"toRoomId": target_room, "text": text}
            await self.api.send_room_message(room_payload)

    return await super().send_by_session(session, message_chain)
def _process_poll_data(
    self, message: AstrBotMessage, poll: Dict[str, Any], message_parts: List[str]
):
    """Record poll data on the message and append its text rendering.

    Args:
        message: Message being built; receives ``raw_message["poll"]`` and
            a ``poll`` attribute, plus a ``Plain`` component when the poll
            renders to non-empty text.
        poll: Poll payload taken from the incoming event.
        message_parts: List of text fragments; the rendered poll text is
            appended to it when non-empty.
    """
    # Attaching the raw poll is best-effort: a failure here must not stop
    # the text rendering below.
    try:
        if not isinstance(message.raw_message, dict):
            message.raw_message = {}
        message.raw_message["poll"] = poll
        message.poll = poll
    except Exception:
        pass

    rendered = format_poll(poll)
    if not rendered:
        return
    message.message.append(Comp.Plain(rendered))
    message_parts.append(rendered)
def _extract_additional_fields(self, session, message_chain) -> Dict[str, Any]:
    """Collect optional note fields from a session and a message chain.

    Args:
        session: Object that may carry an ``extra_data`` dict with the keys
            ``poll``, ``renote_id`` and ``channel_id``.
        message_chain: Object whose ``chain`` is an iterable of components;
            a missing or ``None`` chain is treated as empty.

    Returns:
        A dict with the keys ``cw``, ``poll``, ``renote_id`` and
        ``channel_id``; each value is ``None`` when not provided.
    """
    fields: Dict[str, Any] = {
        "cw": None,
        "poll": None,
        "renote_id": None,
        "channel_id": None,
    }

    # A chain of None would raise TypeError on iteration, so fall back to
    # an empty list.
    for comp in getattr(message_chain, "chain", None) or []:
        cw = getattr(comp, "cw", None)
        if cw:
            # Only the first component with a truthy ``cw`` is used.
            fields["cw"] = cw
            break

    extra_data = getattr(session, "extra_data", None)
    if isinstance(extra_data, dict):
        fields.update(
            {
                "poll": extra_data.get("poll"),
                "renote_id": extra_data.get("renote_id"),
                "channel_id": extra_data.get("channel_id"),
            }
        )

    return fields
async def _start_websocket_connection(self):
backoff_delay = 1.0
max_backoff = 300.0
@@ -109,25 +233,20 @@ class MisskeyPlatformAdapter(Platform):
break
streaming = self.api.get_streaming_client()
streaming.add_message_handler("notification", self._handle_notification)
if self.enable_chat:
streaming.add_message_handler(
"newChatMessage", self._handle_chat_message
)
streaming.add_message_handler("_debug", self._debug_handler)
self._register_event_handlers(streaming)
if await streaming.connect():
logger.info(
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
)
connection_attempts = 0 # 重置计数器
connection_attempts = 0
await streaming.subscribe_channel("main")
if self.enable_chat:
await streaming.subscribe_channel("messaging")
await streaming.subscribe_channel("messagingIndex")
logger.info("[Misskey] 聊天频道已订阅")
backoff_delay = 1.0 # 重置延迟
backoff_delay = 1.0
await streaming.listen()
else:
logger.error(
@@ -140,18 +259,20 @@ class MisskeyPlatformAdapter(Platform):
)
if self._running:
jitter = random.uniform(0, 1.0)
sleep_time = backoff_delay + jitter
logger.info(
f"[Misskey] {backoff_delay:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
f"[Misskey] {sleep_time:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
)
await asyncio.sleep(backoff_delay)
await asyncio.sleep(sleep_time)
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
async def _handle_notification(self, data: Dict[str, Any]):
try:
logger.debug(
f"[Misskey] 收到通知事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
notification_type = data.get("type")
logger.debug(
f"[Misskey] 收到通知事件: type={notification_type}, user_id={data.get('userId', 'unknown')}"
)
if notification_type in ["mention", "reply", "quote"]:
note = data.get("note")
if note and self._is_bot_mentioned(note):
@@ -164,7 +285,7 @@ class MisskeyPlatformAdapter(Platform):
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.api,
client=self,
)
self.commit_event(event)
except Exception as e:
@@ -172,17 +293,16 @@ class MisskeyPlatformAdapter(Platform):
async def _handle_chat_message(self, data: Dict[str, Any]):
try:
logger.debug(
f"[Misskey] 收到聊天事件数据:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
)
sender_id = str(
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
)
room_id = data.get("toRoomId")
logger.debug(
f"[Misskey] 收到聊天事件: sender_id={sender_id}, room_id={room_id}, is_self={sender_id == self.client_self_id}"
)
if sender_id == self.client_self_id:
return
room_id = data.get("toRoomId")
if room_id:
raw_text = data.get("text", "")
logger.debug(
@@ -200,15 +320,16 @@ class MisskeyPlatformAdapter(Platform):
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.api,
client=self,
)
self.commit_event(event)
except Exception as e:
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
async def _debug_handler(self, data: Dict[str, Any]):
    """Log a one-line debug summary (type and channel) of an unhandled event."""
    event_type = data.get("type", "unknown")
    summary = f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}"
    logger.debug(summary)
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
@@ -239,43 +360,250 @@ class MisskeyPlatformAdapter(Platform):
try:
session_id = session.session_id
text, has_at_user = serialize_message_chain(message_chain.chain)
if not has_at_user and session_id:
user_info = self._user_cache.get(session_id)
# 从session_id中提取用户ID用于缓存查询
# session_id格式为: "chat%<user_id>" 或 "room%<room_id>" 或 "note%<user_id>"
user_id_for_cache = None
if "%" in session_id:
parts = session_id.split("%")
if len(parts) >= 2:
user_id_for_cache = parts[1]
user_info = None
if user_id_for_cache:
user_info = self._user_cache.get(user_id_for_cache)
text = add_at_mention_if_needed(text, user_info, has_at_user)
# 检查是否有文件组件
has_file_components = any(
isinstance(comp, Comp.Image)
or isinstance(comp, Comp.File)
or hasattr(comp, "convert_to_file_path")
or hasattr(comp, "get_file")
or any(
hasattr(comp, a) for a in ("file", "url", "path", "src", "source")
)
for comp in message_chain.chain
)
if not text or not text.strip():
logger.warning("[Misskey] 消息内容为空,跳过发送")
return await super().send_by_session(session, message_chain)
if not has_file_components:
logger.warning("[Misskey] 消息内容为空且无文件组件,跳过发送")
return await super().send_by_session(session, message_chain)
else:
text = ""
if len(text) > self.max_message_length:
text = text[: self.max_message_length] + "..."
if session_id and is_valid_user_session_id(session_id):
from .misskey_utils import extract_user_id_from_session_id
file_ids: List[str] = []
fallback_urls: List[str] = []
user_id = extract_user_id_from_session_id(session_id)
await self.api.send_message(user_id, text)
elif session_id and is_valid_room_session_id(session_id):
if not self.enable_file_upload:
return await self._send_text_only_message(
session_id, text, session, message_chain
)
MAX_UPLOAD_CONCURRENCY = 10
upload_concurrency = int(
self.config.get(
"misskey_upload_concurrency", DEFAULT_UPLOAD_CONCURRENCY
)
)
upload_concurrency = min(upload_concurrency, MAX_UPLOAD_CONCURRENCY)
sem = asyncio.Semaphore(upload_concurrency)
async def _upload_comp(comp) -> Optional[object]:
"""组件上传函数:处理 URL(下载后上传)或本地文件(直接上传)"""
from .misskey_utils import (
resolve_component_url_or_path,
upload_local_with_retries,
)
local_path = None
try:
async with sem:
if not self.api:
return None
# 解析组件的 URL 或本地路径
url_candidate, local_path = await resolve_component_url_or_path(
comp
)
if not url_candidate and not local_path:
return None
preferred_name = getattr(comp, "name", None) or getattr(
comp, "file", None
)
# URL 上传:下载后本地上传
if url_candidate:
result = await self.api.upload_and_find_file(
str(url_candidate),
preferred_name,
folder_id=self.upload_folder,
)
if isinstance(result, dict) and result.get("id"):
return str(result["id"])
# 本地文件上传
if local_path:
file_id = await upload_local_with_retries(
self.api,
str(local_path),
preferred_name,
self.upload_folder,
)
if file_id:
return file_id
# 所有上传都失败,尝试获取 URL 作为回退
if hasattr(comp, "register_to_file_service"):
try:
url = await comp.register_to_file_service()
if url:
return {"fallback_url": url}
except Exception:
pass
return None
finally:
# 清理临时文件
if local_path and isinstance(local_path, str):
data_temp = os.path.join(get_astrbot_data_path(), "temp")
if local_path.startswith(data_temp) and os.path.exists(
local_path
):
try:
os.remove(local_path)
logger.debug(f"[Misskey] 已清理临时文件: {local_path}")
except Exception:
pass
# 收集所有可能包含文件/URL信息的组件:支持异步接口或同步字段
file_components = []
for comp in message_chain.chain:
try:
if (
isinstance(comp, Comp.Image)
or isinstance(comp, Comp.File)
or hasattr(comp, "convert_to_file_path")
or hasattr(comp, "get_file")
or any(
hasattr(comp, a)
for a in ("file", "url", "path", "src", "source")
)
):
file_components.append(comp)
except Exception:
# 保守跳过无法访问属性的组件
continue
if len(file_components) > MAX_FILE_UPLOAD_COUNT:
logger.warning(
f"[Misskey] 文件数量超过限制 ({len(file_components)} > {MAX_FILE_UPLOAD_COUNT}),只上传前{MAX_FILE_UPLOAD_COUNT}个文件"
)
file_components = file_components[:MAX_FILE_UPLOAD_COUNT]
upload_tasks = [_upload_comp(comp) for comp in file_components]
try:
results = await asyncio.gather(*upload_tasks) if upload_tasks else []
for r in results:
if not r:
continue
if isinstance(r, dict) and r.get("fallback_url"):
url = r.get("fallback_url")
if url:
fallback_urls.append(str(url))
else:
try:
fid_str = str(r)
if fid_str:
file_ids.append(fid_str)
except Exception:
pass
except Exception:
logger.debug("[Misskey] 并发上传过程中出现异常,继续发送文本")
if session_id and is_valid_room_session_id(session_id):
from .misskey_utils import extract_room_id_from_session_id
room_id = extract_room_id_from_session_id(session_id)
await self.api.send_room_message(room_id, text)
else:
visibility, visible_user_ids = resolve_message_visibility(
user_id=session_id,
user_cache=self._user_cache,
self_id=self.client_self_id,
default_visibility=self.default_visibility,
if fallback_urls:
appended = "\n" + "\n".join(fallback_urls)
text = (text or "") + appended
payload: Dict[str, Any] = {"toRoomId": room_id, "text": text}
if file_ids:
payload["fileIds"] = file_ids
await self.api.send_room_message(payload)
elif session_id:
from .misskey_utils import (
extract_user_id_from_session_id,
is_valid_chat_session_id,
)
await self.api.create_note(
text,
visibility=visibility,
visible_user_ids=visible_user_ids,
local_only=self.local_only,
)
if is_valid_chat_session_id(session_id):
user_id = extract_user_id_from_session_id(session_id)
if fallback_urls:
appended = "\n" + "\n".join(fallback_urls)
text = (text or "") + appended
payload: Dict[str, Any] = {"toUserId": user_id, "text": text}
if file_ids:
# 聊天消息只支持单个文件,使用 fileId 而不是 fileIds
payload["fileId"] = file_ids[0]
if len(file_ids) > 1:
logger.warning(
f"[Misskey] 聊天消息只支持单个文件,忽略其余 {len(file_ids) - 1} 个文件"
)
await self.api.send_message(payload)
else:
# 回退到发帖逻辑
# 去掉 session_id 中的 note% 前缀以匹配 user_cache 的键格式
user_id_for_cache = (
session_id.split("%")[1] if "%" in session_id else session_id
)
# 获取用户缓存信息(包含reply_to_note_id
user_info_for_reply = self._user_cache.get(user_id_for_cache, {})
visibility, visible_user_ids = resolve_message_visibility(
user_id=user_id_for_cache,
user_cache=self._user_cache,
self_id=self.client_self_id,
default_visibility=self.default_visibility,
)
logger.debug(
f"[Misskey] 解析可见性: visibility={visibility}, visible_user_ids={visible_user_ids}, session_id={session_id}, user_id_for_cache={user_id_for_cache}"
)
fields = self._extract_additional_fields(session, message_chain)
if fallback_urls:
appended = "\n" + "\n".join(fallback_urls)
text = (text or "") + appended
# 从缓存中获取原消息ID作为reply_id
reply_id = user_info_for_reply.get("reply_to_note_id")
await self.api.create_note(
text=text,
visibility=visibility,
visible_user_ids=visible_user_ids,
file_ids=file_ids or None,
local_only=self.local_only,
reply_id=reply_id, # 添加reply_id参数
cw=fields["cw"],
poll=fields["poll"],
renote_id=fields["renote_id"],
channel_id=fields["channel_id"],
)
except Exception as e:
logger.error(f"[Misskey] 发送消息失败: {e}")
@@ -309,6 +637,14 @@ class MisskeyPlatformAdapter(Platform):
file_parts = process_files(message, files)
message_parts.extend(file_parts)
poll = raw_data.get("poll") or (
raw_data.get("note", {}).get("poll")
if isinstance(raw_data.get("note"), dict)
else None
)
if poll and isinstance(poll, dict):
self._process_poll_data(message, poll, message_parts)
message.message_str = (
" ".join(part for part in message_parts if part.strip())
if message_parts
@@ -1,4 +1,6 @@
import json
import random
import asyncio
from typing import Any, Optional, Dict, List, Callable, Awaitable
import uuid
@@ -11,6 +13,7 @@ except ImportError as e:
) from e
from astrbot.api import logger
from .misskey_utils import FileIDExtractor
# Constants
API_MAX_RETRIES = 3
@@ -55,6 +58,7 @@ class StreamingClient:
self.is_connected = False
self.message_handlers: Dict[str, Callable] = {}
self.channels: Dict[str, str] = {}
self.desired_channels: Dict[str, Optional[Dict]] = {}
self._running = False
self._last_pong = None
@@ -72,6 +76,18 @@ class StreamingClient:
self._running = True
logger.info("[Misskey WebSocket] 已连接")
if self.desired_channels:
try:
desired = list(self.desired_channels.items())
for channel_type, params in desired:
try:
await self.subscribe_channel(channel_type, params)
except Exception as e:
logger.warning(
f"[Misskey WebSocket] 重新订阅 {channel_type} 失败: {e}"
)
except Exception:
pass
return True
except Exception as e:
@@ -112,9 +128,12 @@ class StreamingClient:
return
message = {"type": "disconnect", "body": {"id": channel_id}}
await self.websocket.send(json.dumps(message))
del self.channels[channel_id]
channel_type = self.channels.get(channel_id)
if channel_id in self.channels:
del self.channels[channel_id]
if channel_type and channel_type not in self.channels.values():
self.desired_channels.pop(channel_type, None)
def add_message_handler(
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
@@ -141,25 +160,67 @@ class StreamingClient:
except websockets.exceptions.ConnectionClosedError as e:
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
self.is_connected = False
try:
await self.disconnect()
except Exception:
pass
except websockets.exceptions.ConnectionClosed as e:
logger.warning(
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
)
self.is_connected = False
try:
await self.disconnect()
except Exception:
pass
except websockets.exceptions.InvalidHandshake as e:
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
self.is_connected = False
try:
await self.disconnect()
except Exception:
pass
except Exception as e:
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
self.is_connected = False
try:
await self.disconnect()
except Exception:
pass
async def _handle_message(self, data: Dict[str, Any]):
message_type = data.get("type")
body = data.get("body", {})
logger.debug(
f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}"
)
def _build_channel_summary(message_type: Optional[str], body: Any) -> str:
try:
if not isinstance(body, dict):
return f"[Misskey WebSocket] 收到消息类型: {message_type}"
inner = body.get("body") if isinstance(body.get("body"), dict) else body
note = (
inner.get("note")
if isinstance(inner, dict) and isinstance(inner.get("note"), dict)
else None
)
text = note.get("text") if note else None
note_id = note.get("id") if note else None
files = note.get("files") or [] if note else []
has_files = bool(files)
is_hidden = bool(note.get("isHidden")) if note else False
user = note.get("user", {}) if note else None
return (
f"[Misskey WebSocket] 收到消息类型: {message_type} | "
f"note_id={note_id} | user={user.get('username') if user else None} | "
f"text={text[:80] if text else '[no-text]'} | files={has_files} | hidden={is_hidden}"
)
except Exception:
return f"[Misskey WebSocket] 收到消息类型: {message_type}"
channel_summary = _build_channel_summary(message_type, body)
logger.info(channel_summary)
if message_type == "channel":
channel_id = body.get("id")
@@ -202,16 +263,60 @@ class StreamingClient:
await self.message_handlers["_debug"](data)
def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
def retry_async(
max_retries: int = 3,
retryable_exceptions: tuple = (APIConnectionError, APIRateLimitError),
backoff_base: float = 1.0,
max_backoff: float = 30.0,
):
"""
智能异步重试装饰器
Args:
max_retries: 最大重试次数
retryable_exceptions: 可重试的异常类型
backoff_base: 退避基数
max_backoff: 最大退避时间
"""
def decorator(func):
async def wrapper(*args, **kwargs):
last_exc = None
for _ in range(max_retries):
func_name = getattr(func, "__name__", "unknown")
for attempt in range(1, max_retries + 1):
try:
return await func(*args, **kwargs)
except retryable_exceptions as e:
last_exc = e
if attempt == max_retries:
logger.error(
f"[Misskey API] {func_name} 重试 {max_retries} 次后仍失败: {e}"
)
break
# 智能退避策略
if isinstance(e, APIRateLimitError):
# 频率限制用更长的退避时间
backoff = min(backoff_base * (3**attempt), max_backoff)
else:
# 其他错误用指数退避
backoff = min(backoff_base * (2**attempt), max_backoff)
jitter = random.uniform(0.1, 0.5) # 随机抖动
sleep_time = backoff + jitter
logger.warning(
f"[Misskey API] {func_name}{attempt} 次重试失败: {e}"
f"{sleep_time:.1f}s后重试"
)
await asyncio.sleep(sleep_time)
continue
except Exception as e:
# 非可重试异常直接抛出
logger.error(f"[Misskey API] {func_name} 遇到不可重试异常: {e}")
raise
if last_exc:
raise last_exc
@@ -221,11 +326,27 @@ def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
class MisskeyAPI:
def __init__(
    self,
    instance_url: str,
    access_token: str,
    *,
    allow_insecure_downloads: bool = False,
    download_timeout: int = 15,
    chunk_size: int = 64 * 1024,
    max_download_bytes: Optional[int] = None,
):
    """Store connection settings and download options.

    Args:
        instance_url: Base URL of the instance; trailing slashes are removed.
        access_token: Token stored for later API calls.
        allow_insecure_downloads: Stored as-is on the instance.
        download_timeout: Stored as-is on the instance.
        chunk_size: Stored as-is on the instance.
        max_download_bytes: Coerced with ``int`` unless it is ``None``.
    """
    self.instance_url = instance_url.rstrip("/")
    self.access_token = access_token

    # Both are created later; nothing is opened here.
    self._session: Optional[aiohttp.ClientSession] = None
    self.streaming: Optional[StreamingClient] = None

    # Download options.
    self.allow_insecure_downloads = allow_insecure_downloads
    self.download_timeout = download_timeout
    self.chunk_size = chunk_size
    if max_download_bytes is None:
        self.max_download_bytes = None
    else:
        self.max_download_bytes = int(max_download_bytes)
async def __aenter__(self):
return self
@@ -258,16 +379,37 @@ class MisskeyAPI:
def _handle_response_status(self, status: int, endpoint: str):
    """Log an HTTP error status and raise the matching exception.

    This method never returns normally: every status, known or not, ends
    in a raise.

    Raises:
        APIError: For 400, 404 and 413.
        AuthenticationError: For 401 and 403.
        APIRateLimitError: For 429.
        APIConnectionError: For 500, 502, 503, 504 and any other status.
    """
    # status -> (log function, log message, exception type, exception message)
    known_statuses = {
        400: (
            logger.error,
            f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})",
            APIError,
            f"Bad request for {endpoint}",
        ),
        401: (
            logger.error,
            f"[Misskey API] 未授权访问: {endpoint} (HTTP {status})",
            AuthenticationError,
            f"Unauthorized access for {endpoint}",
        ),
        403: (
            logger.error,
            f"[Misskey API] 访问被禁止: {endpoint} (HTTP {status})",
            AuthenticationError,
            f"Forbidden access for {endpoint}",
        ),
        404: (
            logger.error,
            f"[Misskey API] 资源不存在: {endpoint} (HTTP {status})",
            APIError,
            f"Resource not found for {endpoint}",
        ),
        413: (
            logger.error,
            f"[Misskey API] 请求体过大: {endpoint} (HTTP {status})",
            APIError,
            f"Request entity too large for {endpoint}",
        ),
        # Rate limiting is the only status logged at warning level.
        429: (
            logger.warning,
            f"[Misskey API] 请求频率限制: {endpoint} (HTTP {status})",
            APIRateLimitError,
            f"Rate limit exceeded for {endpoint}",
        ),
        500: (
            logger.error,
            f"[Misskey API] 服务器内部错误: {endpoint} (HTTP {status})",
            APIConnectionError,
            f"Internal server error for {endpoint}",
        ),
        502: (
            logger.error,
            f"[Misskey API] 网关错误: {endpoint} (HTTP {status})",
            APIConnectionError,
            f"Bad gateway for {endpoint}",
        ),
        503: (
            logger.error,
            f"[Misskey API] 服务不可用: {endpoint} (HTTP {status})",
            APIConnectionError,
            f"Service unavailable for {endpoint}",
        ),
        504: (
            logger.error,
            f"[Misskey API] 网关超时: {endpoint} (HTTP {status})",
            APIConnectionError,
            f"Gateway timeout for {endpoint}",
        ),
    }

    entry = known_statuses.get(status)
    if entry is None:
        logger.error(f"[Misskey API] 未知错误: {endpoint} (HTTP {status})")
        raise APIConnectionError(f"HTTP {status} for {endpoint}")

    log, log_message, exc_type, exc_message = entry
    log(log_message)
    raise exc_type(exc_message)
async def _process_response(
@@ -286,21 +428,25 @@ class MisskeyAPI:
else []
)
if notifications_data:
logger.debug(f"获取到 {len(notifications_data)} 条新通知")
logger.debug(
f"[Misskey API] 获取到 {len(notifications_data)} 条新通知"
)
else:
logger.debug(f"API 请求成功: {endpoint}")
logger.debug(f"[Misskey API] 请求成功: {endpoint}")
return result
except json.JSONDecodeError as e:
logger.error(f"响应不是有效的 JSON 格式: {e}")
logger.error(f"[Misskey API] 响应格式错误: {e}")
raise APIConnectionError("Invalid JSON response") from e
else:
try:
error_text = await response.text()
logger.error(
f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}"
f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}, 响应: {error_text}"
)
except Exception:
logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}")
logger.error(
f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}"
)
self._handle_response_status(response.status, endpoint)
raise APIConnectionError(f"Request failed for {endpoint}")
@@ -321,53 +467,307 @@ class MisskeyAPI:
async with self.session.post(url, json=payload) as response:
return await self._process_response(response, endpoint)
except aiohttp.ClientError as e:
logger.error(f"HTTP 请求错误: {e}")
logger.error(f"[Misskey API] HTTP 请求错误: {e}")
raise APIConnectionError(f"HTTP request failed: {e}") from e
async def create_note(
    self,
    text: Optional[str] = None,
    visibility: str = "public",
    reply_id: Optional[str] = None,
    visible_user_ids: Optional[List[str]] = None,
    file_ids: Optional[List[str]] = None,
    local_only: bool = False,
    cw: Optional[str] = None,
    poll: Optional[Dict[str, Any]] = None,
    renote_id: Optional[str] = None,
    channel_id: Optional[str] = None,
    reaction_acceptance: Optional[str] = None,
    no_extract_mentions: Optional[bool] = None,
    no_extract_hashtags: Optional[bool] = None,
    no_extract_emojis: Optional[bool] = None,
    media_ids: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Create a note by posting to the ``notes/create`` endpoint.

    ``visibility`` and ``local_only`` are always sent. ``text`` is sent
    when it is not ``None``. ``visible_user_ids`` is sent only when it is
    non-empty and ``visibility`` is ``"specified"``. The remaining
    arguments are optional and are sent only when provided.

    Returns:
        The result of the API request, unchanged.
    """
    payload: Dict[str, Any] = {} if text is None else {"text": text}
    payload["visibility"] = visibility
    payload["localOnly"] = local_only

    # Fields that are sent only when truthy (empty lists/strings are dropped).
    sent_when_truthy = (
        ("replyId", reply_id),
        ("visibleUserIds", visible_user_ids if visibility == "specified" else None),
        ("fileIds", file_ids),
        ("mediaIds", media_ids),
    )
    for key, value in sent_when_truthy:
        if value:
            payload[key] = value

    # Fields that are sent whenever they are not None (empty values included).
    sent_when_set = (
        ("cw", cw),
        ("poll", poll),
        ("renoteId", renote_id),
        ("channelId", channel_id),
        ("reactionAcceptance", reaction_acceptance),
    )
    for key, value in sent_when_set:
        if value is not None:
            payload[key] = value

    # Flags are coerced to real booleans before being sent.
    flags = (
        ("noExtractMentions", no_extract_mentions),
        ("noExtractHashtags", no_extract_hashtags),
        ("noExtractEmojis", no_extract_emojis),
    )
    for key, value in flags:
        if value is not None:
            payload[key] = bool(value)

    result = await self._make_request("notes/create", payload)

    note_id = "unknown"
    if isinstance(result, dict):
        note_id = result.get("createdNote", {}).get("id", "unknown")
    logger.debug(f"[Misskey API] 发帖成功: {note_id}")
    return result
async def upload_file(
    self,
    file_path: str,
    name: Optional[str] = None,
    folder_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Upload a local file via ``drive/files/create``.

    Args:
        file_path: Path of the file to upload.
        name: File name to send; defaults to the base name of ``file_path``.
        folder_id: Optional folder id, sent as ``folderId`` when truthy.

    Returns:
        ``{"id": <extracted file id>, "raw": <processed API response>}``.

    Raises:
        APIError: If ``file_path`` is empty or the file does not exist.
        APIConnectionError: If the HTTP request fails with a client error.
    """
    import os

    if not file_path:
        raise APIError("No file path provided for upload")

    url = f"{self.instance_url}/api/drive/files/create"
    # basename understands the platform's separators; splitting on "/"
    # returned the whole path for backslash-separated Windows paths.
    filename = name or os.path.basename(file_path)

    form = aiohttp.FormData()
    form.add_field("i", self.access_token)
    if folder_id:
        form.add_field("folderId", str(folder_id))

    try:
        # The handle must stay open until the request has been sent; the
        # with-block closes it on every exit path.
        with open(file_path, "rb") as f:
            form.add_field("file", f, filename=filename)
            async with self.session.post(url, data=form) as resp:
                result = await self._process_response(resp, "drive/files/create")
    except FileNotFoundError as e:
        logger.error(f"[Misskey API] 本地文件不存在: {file_path}")
        raise APIError(f"File not found: {file_path}") from e
    except aiohttp.ClientError as e:
        logger.error(f"[Misskey API] 文件上传网络错误: {e}")
        raise APIConnectionError(f"Upload failed: {e}") from e

    file_id = FileIDExtractor.extract_file_id(result)
    logger.debug(f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}")
    return {"id": file_id, "raw": result}
async def find_files_by_hash(self, md5_hash: str) -> List[Dict[str, Any]]:
    """Look up drive files by MD5 hash via ``drive/files/find-by-hash``.

    Returns:
        The API result when it is a list, otherwise an empty list.

    Raises:
        APIError: If ``md5_hash`` is empty.
        Exception: Any request failure is logged and re-raised unchanged.
    """
    if not md5_hash:
        raise APIError("No MD5 hash provided for find-by-hash")

    try:
        logger.debug(f"[Misskey API] find-by-hash 请求: md5={md5_hash}")
        result = await self._make_request(
            "drive/files/find-by-hash", {"md5": md5_hash}
        )
        logger.debug(
            f"[Misskey API] find-by-hash 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
        )
        if isinstance(result, list):
            return result
        return []
    except Exception as e:
        logger.error(f"[Misskey API] 根据哈希查找文件失败: {e}")
        raise
async def find_files_by_name(
    self, name: str, folder_id: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Look up drive files by name via ``drive/files/find``.

    Args:
        name: File name to search for.
        folder_id: Optional folder id, sent as ``folderId`` when truthy.

    Returns:
        The API result when it is a list, otherwise an empty list.

    Raises:
        APIError: If ``name`` is empty.
        Exception: Any request failure is logged and re-raised unchanged.
    """
    if not name:
        raise APIError("No name provided for find")

    query: Dict[str, Any] = {"name": name}
    if folder_id:
        query["folderId"] = folder_id

    try:
        logger.debug(f"[Misskey API] find 请求: name={name}, folder_id={folder_id}")
        result = await self._make_request("drive/files/find", query)
        logger.debug(
            f"[Misskey API] find 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
        )
        if isinstance(result, list):
            return result
        return []
    except Exception as e:
        logger.error(f"[Misskey API] 根据名称查找文件失败: {e}")
        raise
async def find_files(
    self,
    limit: int = 10,
    folder_id: Optional[str] = None,
    type: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """List drive files via ``drive/files`` with optional filters.

    Args:
        limit: Value sent as ``limit``.
        folder_id: Sent as ``folderId`` when not ``None``.
        type: Sent as ``type`` when not ``None``.

    Returns:
        The API result when it is a list, otherwise an empty list.

    Raises:
        Exception: Any request failure is logged and re-raised unchanged.
    """
    query: Dict[str, Any] = {"limit": limit}
    optional_filters = (("folderId", folder_id), ("type", type))
    for key, value in optional_filters:
        if value is not None:
            query[key] = value

    try:
        logger.debug(
            f"[Misskey API] 列表文件请求: limit={limit}, folder_id={folder_id}, type={type}"
        )
        result = await self._make_request("drive/files", query)
        logger.debug(
            f"[Misskey API] 列表文件响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
        )
        if isinstance(result, list):
            return result
        return []
    except Exception as e:
        logger.error(f"[Misskey API] 列表文件失败: {e}")
        raise
async def _download_with_existing_session(
    self, url: str, ssl_verify: bool = True
) -> Optional[bytes]:
    """Download ``url`` through the client's existing session.

    Args:
        url: Address to fetch.
        ssl_verify: Passed to the request as its ``ssl`` argument.

    Returns:
        The response body on HTTP 200, otherwise ``None``.

    Raises:
        APIConnectionError: If no session is available.
    """
    if not (hasattr(self, "session") and self.session):
        raise APIConnectionError("No existing session available")

    # Use the timeout stored by __init__ instead of a fixed 15 seconds;
    # fall back to 15 if the attribute is missing.
    timeout = aiohttp.ClientTimeout(total=getattr(self, "download_timeout", 15))
    async with self.session.get(url, timeout=timeout, ssl=ssl_verify) as response:
        if response.status != 200:
            return None
        return await response.read()
async def _download_with_temp_session(
    self, url: str, ssl_verify: bool = True
) -> Optional[bytes]:
    """Download ``url`` through a throwaway session.

    Args:
        url: Address to fetch.
        ssl_verify: Passed to the temporary connector as its ``ssl`` argument.

    Returns:
        The response body on HTTP 200, otherwise ``None``.
    """
    # Use the timeout stored by __init__ instead of a fixed 15 seconds;
    # fall back to 15 if the attribute is missing.
    timeout = aiohttp.ClientTimeout(total=getattr(self, "download_timeout", 15))
    connector = aiohttp.TCPConnector(ssl=ssl_verify)
    async with aiohttp.ClientSession(connector=connector) as temp_session:
        async with temp_session.get(url, timeout=timeout) as response:
            if response.status != 200:
                return None
            return await response.read()
async def upload_and_find_file(
    self,
    url: str,
    name: Optional[str] = None,
    folder_id: Optional[str] = None,
    max_wait_time: float = 30.0,
    check_interval: float = 2.0,
) -> Optional[Dict[str, Any]]:
    """Download a remote file and upload it as a local file.

    The file is fetched with TLS verification, first through the existing
    session and then through a temporary one. A retry without
    verification happens only when ``allow_insecure_downloads`` is
    enabled on this client.

    Args:
        url: Address of the file to fetch.
        name: Optional file name passed on to ``upload_file``.
        folder_id: Optional folder id passed on to ``upload_file``.
        max_wait_time: Reserved; not used.
        check_interval: Reserved; not used.

    Returns:
        The dict returned by ``upload_file``, or ``None`` when the
        download or the upload fails.

    Raises:
        APIError: If ``url`` is empty.
    """
    import contextlib
    import os
    import tempfile

    if not url:
        raise APIError("URL不能为空")

    async def _fetch(verify: bool) -> Optional[bytes]:
        # A failure of the shared session must not skip the temporary
        # session that uses the same verification setting.
        try:
            body = await self._download_with_existing_session(url, ssl_verify=verify)
        except Exception as session_error:
            logger.debug(
                f"[Misskey API] 现有会话下载失败,改用临时会话: {session_error}"
            )
            body = None
        return body or await self._download_with_temp_session(url, ssl_verify=verify)

    try:
        tmp_bytes = None
        try:
            tmp_bytes = await _fetch(True)
        except Exception as ssl_error:
            if getattr(self, "allow_insecure_downloads", False):
                logger.debug(
                    f"[Misskey API] SSL 验证下载失败: {ssl_error},重试不验证 SSL"
                )
                try:
                    tmp_bytes = await _fetch(False)
                except Exception as insecure_error:
                    logger.debug(
                        f"[Misskey API] 不验证 SSL 的下载同样失败: {insecure_error}"
                    )
            else:
                # Verification is not dropped unless the operator opted in.
                logger.warning(
                    f"[Misskey API] 下载失败,且未启用 allow_insecure_downloads,跳过不验证 SSL 的重试: {ssl_error}"
                )

        if not tmp_bytes:
            return None

        # delete=False keeps the file on disk after the handle closes so
        # upload_file can reopen it by path.
        with tempfile.NamedTemporaryFile(delete=False) as tmpf:
            tmpf.write(tmp_bytes)
            tmp_path = tmpf.name

        try:
            result = await self.upload_file(tmp_path, name, folder_id)
            logger.debug(f"[Misskey API] 本地上传成功: {result.get('id')}")
            return result
        finally:
            # Cleanup is best-effort; a leftover temp file is not an error.
            with contextlib.suppress(OSError):
                os.unlink(tmp_path)
    except Exception as e:
        logger.error(f"[Misskey API] 本地上传失败: {e}")

    return None
async def get_current_user(self) -> Dict[str, Any]:
    """Return the result of the ``i`` endpoint, called with an empty payload."""
    profile = await self._make_request("i", {})
    return profile
async def send_message(
    self, user_id_or_payload: Any, text: Optional[str] = None
) -> Dict[str, Any]:
    """Send a chat message to a user.

    Accepts either ``(user_id, text)`` or a single dict payload prepared by
    the caller; a dict is sent as-is.

    Args:
        user_id_or_payload: Target user ID, or a ready-made request payload.
        text: Message text; only used when a user ID is given.

    Returns:
        The response of the ``chat/messages/create-to-user`` request.
    """
    if isinstance(user_id_or_payload, dict):
        data = user_id_or_payload
    else:
        data = {"toUserId": user_id_or_payload, "text": text}

    result = await self._make_request("chat/messages/create-to-user", data)
    message_id = result.get("id", "unknown")
    logger.debug(f"[Misskey API] 聊天消息发送成功: {message_id}")
    return result
async def send_room_message(
    self, room_id_or_payload: Any, text: Optional[str] = None
) -> Dict[str, Any]:
    """Send a message to a chat room.

    Accepts either ``(room_id, text)`` or a single dict payload; a dict is
    sent as-is.

    Args:
        room_id_or_payload: Target room ID, or a ready-made request payload.
        text: Message text; only used when a room ID is given.

    Returns:
        The response of the ``chat/messages/create-to-room`` request.
    """
    if isinstance(room_id_or_payload, dict):
        data = room_id_or_payload
    else:
        data = {"toRoomId": room_id_or_payload, "text": text}

    result = await self._make_request("chat/messages/create-to-room", data)
    message_id = result.get("id", "unknown")
    logger.debug(f"[Misskey API] 房间消息发送成功: {message_id}")
    return result
async def get_messages(
@@ -381,9 +781,8 @@ class MisskeyAPI:
result = await self._make_request("chat/messages/user-timeline", data)
if isinstance(result, list):
return result
else:
logger.warning(f"获取聊天消息响应格式异常: {type(result)}")
return []
logger.warning(f"[Misskey API] 聊天消息响应格式异常: {type(result)}")
return []
async def get_mentions(
self, limit: int = 10, since_id: Optional[str] = None
@@ -400,5 +799,142 @@ class MisskeyAPI:
elif isinstance(result, dict) and "notifications" in result:
return result["notifications"]
else:
logger.warning(f"获取提及通知响应格式异常: {type(result)}")
logger.warning(f"[Misskey API] 提及通知响应格式异常: {type(result)}")
return []
async def send_message_with_media(
    self,
    message_type: str,
    target_id: str,
    text: Optional[str] = None,
    media_urls: Optional[List[str]] = None,
    local_files: Optional[List[str]] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Send text plus optional media through one entry point.

    Remote URLs are uploaded first, then local files; the collected file IDs
    are handed to ``_dispatch_message`` together with the text.

    Args:
        message_type: Message kind (``'chat'``, ``'room'`` or ``'note'``).
        target_id: Target ID (user, room, channel, ...).
        text: Text content.
        media_urls: URLs of remote media files.
        local_files: Paths of local files.
        **kwargs: Extra arguments forwarded to ``_dispatch_message``.

    Returns:
        The result of ``_dispatch_message``.

    Raises:
        APIError: If no text, media URL or local file is supplied.
    """
    if not (text or media_urls or local_files):
        raise APIError("消息内容不能为空:需要文本或媒体文件")

    uploaded_ids = []
    if media_urls:
        uploaded_ids += await self._process_media_urls(media_urls)
    if local_files:
        uploaded_ids += await self._process_local_files(local_files)

    return await self._dispatch_message(
        message_type, target_id, text, uploaded_ids, **kwargs
    )
async def _process_media_urls(self, urls: List[str]) -> List[str]:
    """Upload every remote media URL and collect the resulting file IDs.

    A URL whose upload yields no ID, or whose processing raises, is logged
    and skipped so the remaining URLs are still handled.

    Args:
        urls: Remote media URLs to upload via ``upload_and_find_file``.

    Returns:
        IDs of the uploads that succeeded, in input order.
    """
    collected = []
    for media_url in urls:
        try:
            uploaded = await self.upload_and_find_file(media_url)
            if not (uploaded and uploaded.get("id")):
                logger.error(f"[Misskey API] URL媒体上传失败: {media_url}")
                continue
            file_id = uploaded["id"]
            collected.append(file_id)
            logger.debug(f"[Misskey API] URL媒体上传成功: {file_id}")
        except Exception as e:
            # Keep going: one bad URL must not abort the whole batch.
            logger.error(f"[Misskey API] URL媒体处理失败 {media_url}: {e}")
    return collected
async def _process_local_files(self, file_paths: List[str]) -> List[str]:
    """Upload every local file and collect the resulting file IDs.

    A path whose upload yields no ID, or whose processing raises, is logged
    and skipped so the remaining paths are still handled.

    Args:
        file_paths: Local paths to upload via ``upload_file``.

    Returns:
        IDs of the uploads that succeeded, in input order.
    """
    collected = []
    for local_path in file_paths:
        try:
            uploaded = await self.upload_file(local_path)
            if not (uploaded and uploaded.get("id")):
                logger.error(f"[Misskey API] 本地文件上传失败: {local_path}")
                continue
            file_id = uploaded["id"]
            collected.append(file_id)
            logger.debug(f"[Misskey API] 本地文件上传成功: {file_id}")
        except Exception as e:
            logger.error(f"[Misskey API] 本地文件处理失败 {local_path}: {e}")
    return collected
async def _dispatch_message(
    self,
    message_type: str,
    target_id: str,
    text: Optional[str],
    file_ids: List[str],
    **kwargs,
) -> Dict[str, Any]:
    """Route a prepared message to the sender matching ``message_type``.

    ``chat`` and ``room`` messages carry a single ``fileId``. With more than
    one file, one message per file is sent and the text is attached to the
    first message only, so it is not repeated. ``note`` passes all IDs to
    ``create_note`` as ``file_ids``.

    Args:
        message_type: ``'chat'``, ``'room'`` or ``'note'``.
        target_id: User ID for ``chat``, room ID for ``room``; unused for
            ``note``.
        text: Optional text content.
        file_ids: IDs of already uploaded files.
        **kwargs: Extra arguments merged into the ``create_note`` call; they
            override ``text`` / ``file_ids`` when they use the same names.

    Returns:
        The sender's result, or ``{"multiple": True, "results": [...]}`` when
        several chat/room messages were sent.

    Raises:
        APIError: If ``message_type`` is not supported.
    """
    ids = list(file_ids or [])

    async def _send_direct(target_key: str, sender) -> Dict[str, Any]:
        """Send a chat/room message, splitting it up when there are several files."""
        if len(ids) <= 1:
            payload = {target_key: target_id}
            if text:
                payload["text"] = text
            if ids:
                payload["fileId"] = ids[0]
            return await sender(payload)

        results = []
        for position, file_id in enumerate(ids):
            payload = {target_key: target_id}
            # Only the first message carries the text; repeating it with
            # every attachment would post the same text several times.
            if text and position == 0:
                payload["text"] = text
            payload["fileId"] = file_id
            results.append(await sender(payload))
        return {"multiple": True, "results": results}

    if message_type == "chat":
        return await _send_direct("toUserId", self.send_message)
    if message_type == "room":
        return await _send_direct("toRoomId", self.send_room_message)
    if message_type == "note":
        note_kwargs = {"text": text, "file_ids": ids or None}
        note_kwargs.update(kwargs)
        return await self.create_note(**note_kwargs)
    raise APIError(f"不支持的消息类型: {message_type}")
@@ -40,48 +40,83 @@ class MisskeyPlatformEvent(AstrMessageEvent):
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
async def send(self, message: MessageChain):
content, has_at = serialize_message_chain(message.chain)
if not content:
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
return
"""发送消息,使用适配器的完整上传和发送逻辑"""
try:
original_message_id = getattr(self.message_obj, "message_id", None)
raw_message = getattr(self.message_obj, "raw_message", {})
logger.debug(
f"[MisskeyEvent] send 方法被调用,消息链包含 {len(message.chain)} 个组件"
)
if raw_message and not has_at:
user_data = raw_message.get("user", {})
user_info = {
"username": user_data.get("username", ""),
"nickname": user_data.get("name", user_data.get("username", "")),
}
content = add_at_mention_if_needed(content, user_info, has_at)
# 使用适配器的 send_by_session 方法,它包含文件上传逻辑
from astrbot.core.platform.message_session import MessageSession
from astrbot.core.platform.message_type import MessageType
# 根据会话类型选择发送方式
if hasattr(self.client, "send_message") and is_valid_user_session_id(
self.session_id
):
user_id = extract_user_id_from_session_id(self.session_id)
await self.client.send_message(user_id, content)
elif hasattr(self.client, "send_room_message") and is_valid_room_session_id(
self.session_id
):
room_id = extract_room_id_from_session_id(self.session_id)
await self.client.send_room_message(room_id, content)
elif original_message_id and hasattr(self.client, "create_note"):
visibility, visible_user_ids = resolve_visibility_from_raw_message(
raw_message
)
await self.client.create_note(
content,
reply_id=original_message_id,
visibility=visibility,
visible_user_ids=visible_user_ids,
)
elif hasattr(self.client, "create_note"):
logger.debug("[MisskeyEvent] 创建新帖子")
await self.client.create_note(content)
# 根据session_id类型确定消息类型
if is_valid_user_session_id(self.session_id):
message_type = MessageType.FRIEND_MESSAGE
elif is_valid_room_session_id(self.session_id):
message_type = MessageType.GROUP_MESSAGE
else:
message_type = MessageType.FRIEND_MESSAGE # 默认
session = MessageSession(
platform_name=self.platform_meta.name,
message_type=message_type,
session_id=self.session_id,
)
logger.debug(
f"[MisskeyEvent] 检查适配器方法: hasattr(self.client, 'send_by_session') = {hasattr(self.client, 'send_by_session')}"
)
# 调用适配器的 send_by_session 方法
if hasattr(self.client, "send_by_session"):
logger.debug("[MisskeyEvent] 调用适配器的 send_by_session 方法")
await self.client.send_by_session(session, message)
else:
# 回退到原来的简化发送逻辑
content, has_at = serialize_message_chain(message.chain)
if not content:
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
return
original_message_id = getattr(self.message_obj, "message_id", None)
raw_message = getattr(self.message_obj, "raw_message", {})
if raw_message and not has_at:
user_data = raw_message.get("user", {})
user_info = {
"username": user_data.get("username", ""),
"nickname": user_data.get(
"name", user_data.get("username", "")
),
}
content = add_at_mention_if_needed(content, user_info, has_at)
# 根据会话类型选择发送方式
if hasattr(self.client, "send_message") and is_valid_user_session_id(
self.session_id
):
user_id = extract_user_id_from_session_id(self.session_id)
await self.client.send_message(user_id, content)
elif hasattr(
self.client, "send_room_message"
) and is_valid_room_session_id(self.session_id):
room_id = extract_room_id_from_session_id(self.session_id)
await self.client.send_room_message(room_id, content)
elif original_message_id and hasattr(self.client, "create_note"):
visibility, visible_user_ids = resolve_visibility_from_raw_message(
raw_message
)
await self.client.create_note(
content,
reply_id=original_message_id,
visibility=visibility,
visible_user_ids=visible_user_ids,
)
elif hasattr(self.client, "create_note"):
logger.debug("[MisskeyEvent] 创建新帖子")
await self.client.create_note(content)
await super().send(message)
@@ -5,6 +5,68 @@ import astrbot.api.message_components as Comp
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
class FileIDExtractor:
    """Stateless helper that pulls a file ID out of an API response."""

    @staticmethod
    def extract_file_id(result: Any) -> Optional[str]:
        """Return the first truthy file ID found in ``result``.

        Lookups are tried in this order: ``result["createdFile"]["id"]``,
        ``result["file"]["id"]``, then ``result["id"]``. A lookup that raises
        is skipped.

        Returns:
            The ID, or ``None`` when ``result`` is not a dict or no lookup
            yields a truthy value.
        """
        if not isinstance(result, dict):
            return None

        key_paths = (("createdFile", "id"), ("file", "id"), ("id",))
        for *containers, leaf in key_paths:
            try:
                node = result
                for container_key in containers:
                    node = node.get(container_key, {})
                found = node.get(leaf)
                if found:
                    return found
            except Exception:
                continue
        return None
class MessagePayloadBuilder:
    """Stateless helpers that assemble request payloads per message kind."""

    @staticmethod
    def build_chat_payload(
        user_id: str, text: Optional[str], file_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build a chat payload: ``toUserId`` plus truthy ``text`` / ``fileId``."""
        optional_fields = (("text", text), ("fileId", file_id))
        return {
            "toUserId": user_id,
            **{key: value for key, value in optional_fields if value},
        }

    @staticmethod
    def build_room_payload(
        room_id: str, text: Optional[str], file_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Build a room payload: ``toRoomId`` plus truthy ``text`` / ``fileId``."""
        optional_fields = (("text", text), ("fileId", file_id))
        return {
            "toRoomId": room_id,
            **{key: value for key, value in optional_fields if value},
        }

    @staticmethod
    def build_note_payload(
        text: Optional[str], file_ids: Optional[List[str]] = None, **kwargs
    ) -> Dict[str, Any]:
        """Build a note payload from truthy ``text`` / ``fileIds`` plus ``kwargs``.

        ``kwargs`` are merged last, so they win over keys set before them.
        """
        note_payload: Dict[str, Any] = {}
        for key, value in (("text", text), ("fileIds", file_ids)):
            if value:
                note_payload[key] = value
        note_payload.update(kwargs)
        return note_payload
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
"""将消息链序列化为文本字符串"""
text_parts = []
@@ -15,11 +77,19 @@ def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
if isinstance(component, Comp.Plain):
return component.text
elif isinstance(component, Comp.File):
file_name = getattr(component, "name", "文件")
return f"[文件: {file_name}]"
# 为文件组件返回占位符,但适配器仍会处理原组件
return "[文件]"
elif isinstance(component, Comp.Image):
# 为图片组件返回占位符,但适配器仍会处理原组件
return "[图片]"
elif isinstance(component, Comp.At):
has_at = True
return f"@{component.qq}"
# 优先使用name字段(用户名),如果没有则使用qq字段
# 这样可以避免在Misskey中生成 @<user_id> 这样的无效提及
if hasattr(component, "name") and component.name:
return f"@{component.name}"
else:
return f"@{component.qq}"
elif hasattr(component, "text"):
text = getattr(component, "text", "")
if "@" in text:
@@ -43,15 +113,22 @@ def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
def resolve_message_visibility(
user_id: Optional[str],
user_cache: Dict[str, Any],
self_id: Optional[str],
user_id: Optional[str] = None,
user_cache: Optional[Dict[str, Any]] = None,
self_id: Optional[str] = None,
raw_message: Optional[Dict[str, Any]] = None,
default_visibility: str = "public",
) -> Tuple[str, Optional[List[str]]]:
"""解析 Misskey 消息的可见性设置"""
"""解析 Misskey 消息的可见性设置
可以从 user_cache raw_message 中解析支持两种调用方式
1. 基于 user_cache: resolve_message_visibility(user_id, user_cache, self_id)
2. 基于 raw_message: resolve_message_visibility(raw_message=raw_message, self_id=self_id)
"""
visibility = default_visibility
visible_user_ids = None
# 优先从 user_cache 解析
if user_id and user_cache:
user_info = user_cache.get(user_id)
if user_info:
@@ -66,38 +143,36 @@ def resolve_message_visibility(
visible_user_ids = [uid for uid in visible_user_ids if uid]
else:
visibility = original_visibility
return visibility, visible_user_ids
# 回退到从 raw_message 解析
if raw_message:
original_visibility = raw_message.get("visibility", default_visibility)
if original_visibility == "specified":
visibility = "specified"
original_visible_users = raw_message.get("visibleUserIds", [])
sender_id = raw_message.get("userId", "")
users_to_include = []
if sender_id:
users_to_include.append(sender_id)
if self_id:
users_to_include.append(self_id)
visible_user_ids = list(set(original_visible_users + users_to_include))
visible_user_ids = [uid for uid in visible_user_ids if uid]
else:
visibility = original_visibility
return visibility, visible_user_ids
# 保留旧函数名作为向后兼容的别名
def resolve_visibility_from_raw_message(
    raw_message: Dict[str, Any], self_id: Optional[str] = None
) -> Tuple[str, Optional[List[str]]]:
    """Resolve visibility settings from a raw message (deprecated).

    Kept as a backward-compatible alias; it forwards to
    ``resolve_message_visibility`` with ``raw_message`` and ``self_id``
    passed by keyword. New code should call that function directly.

    Args:
        raw_message: Raw message data.
        self_id: Optional ID forwarded unchanged.

    Returns:
        The ``(visibility, visible_user_ids)`` tuple produced by
        ``resolve_message_visibility``.
    """
    return resolve_message_visibility(raw_message=raw_message, self_id=self_id)
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
@@ -128,6 +203,20 @@ def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
)
def is_valid_chat_session_id(session_id: Union[str, Any]) -> bool:
    """Tell whether ``session_id`` is a valid chat session ID.

    A valid ID is a string of the form ``chat%<id>`` containing exactly one
    ``%``, where ``<id>`` is non-empty and not the literal ``unknown``.
    """
    if not isinstance(session_id, str):
        return False

    prefix, separator, identifier = session_id.partition("%")
    # No separator at all, or a second one in the remainder, means the ID
    # does not consist of exactly two parts.
    if not separator or "%" in identifier:
        return False
    return prefix == "chat" and identifier not in ("", "unknown")
def extract_user_id_from_session_id(session_id: str) -> str:
"""从 session_id 中提取用户 ID"""
if "%" in session_id:
@@ -149,21 +238,22 @@ def extract_room_id_from_session_id(session_id: str) -> str:
def add_at_mention_if_needed(
    text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
) -> str:
    """Prefix ``text`` with an ``@username`` mention when one is missing.

    A mention is only added when ``user_info`` contains a non-empty
    ``username``; without it the text is returned unchanged, which avoids
    emitting an invalid mention such as ``@<user_id>``.

    Args:
        text: Message text.
        user_info: Mapping that may contain a ``username`` entry.
        has_at: ``True`` when the message already contains a mention.

    Returns:
        ``text`` unchanged, or ``"@username\\n" + text`` stripped of
        surrounding whitespace.
    """
    if has_at or not user_info:
        return text

    username = user_info.get("username")
    if not username:
        return text

    mention = f"@{username}"
    if text.startswith(mention):
        return text
    return f"{mention}\n{text}".strip()
@@ -197,6 +287,22 @@ def process_files(
return file_parts
def format_poll(poll: Dict[str, Any]) -> str:
    """Render a Misskey poll object as one human-readable line.

    Returns an empty string when ``poll`` is falsy or not a dict. Otherwise
    the line holds a poll tag, a single/multiple-choice marker and, when
    there are choices, a numbered list with each choice's text and votes.
    """
    if not (poll and isinstance(poll, dict)):
        return ""

    segments = ["[投票]"]
    segments.append("允许多选" if poll.get("multiple", False) else "单选")

    rendered_choices = []
    for number, choice in enumerate(poll.get("choices", []), start=1):
        rendered_choices.append(
            f"({number}) {choice.get('text', '')} [{choice.get('votes', 0)}票]"
        )
    if rendered_choices:
        segments.append("选项: " + ", ".join(rendered_choices))

    return " ".join(segments)
def extract_sender_info(
raw_data: Dict[str, Any], is_chat: bool = False
) -> Dict[str, Any]:
@@ -248,7 +354,7 @@ def create_base_message(
else:
session_prefix = "note"
session_id = f"{session_prefix}%{sender_info['sender_id']}"
message.type = MessageType.FRIEND_MESSAGE
message.type = MessageType.OTHER_MESSAGE
message.session_id = (
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
@@ -303,6 +409,8 @@ def cache_user_info(
"nickname": sender_info["nickname"],
"visibility": raw_data.get("visibility", "public"),
"visible_user_ids": raw_data.get("visibleUserIds", []),
# 保存原消息ID,用于回复时作为reply_id
"reply_to_note_id": raw_data.get("id"),
}
user_cache[sender_info["sender_id"]] = user_cache_data
@@ -325,3 +433,106 @@ def cache_room_info(
"visibility": "specified",
"visible_user_ids": [client_self_id],
}
async def resolve_component_url_or_path(
    comp: Any,
) -> Tuple[Optional[str], Optional[str]]:
    """Work out an uploadable remote URL or local path for a component.

    Three sources are consulted in order:

    1. The methods ``convert_to_file_path``, ``get_file`` and
       ``register_to_file_service`` (called without arguments). A string
       starting with ``http`` becomes the URL and ends this step; any other
       non-empty string becomes the local path.
    2. ``get_file(True)``, when no URL has been found yet.
    3. The attributes ``file``, ``url``, ``path``, ``src`` and ``source``,
       when neither a URL nor a path has been found; the first non-empty
       string wins.

    Errors are swallowed throughout; callers only need to handle ``None``.

    Returns:
        ``(url_candidate, local_path)``; either or both may be ``None``.
    """
    remote_url = None
    file_path = None

    async def _as_str(maybe_awaitable):
        """Await the value if it is awaitable and return it only when it is a str."""
        try:
            resolved = maybe_awaitable
            if hasattr(maybe_awaitable, "__await__"):
                resolved = await maybe_awaitable
            return resolved if isinstance(resolved, str) else None
        except Exception:
            return None

    try:
        # 1. Zero-argument accessor methods (sync or async).
        for method_name in (
            "convert_to_file_path",
            "get_file",
            "register_to_file_service",
        ):
            if not hasattr(comp, method_name):
                continue
            try:
                resolved = await _as_str(getattr(comp, method_name)())
            except Exception:
                continue
            if not resolved:
                continue
            if resolved.startswith("http"):
                remote_url = resolved
                break
            # A local path does not stop the search: a later method may
            # still provide a URL.
            file_path = resolved

        # 2. get_file(True) for a directly reachable URL.
        if not remote_url and hasattr(comp, "get_file"):
            try:
                resolved = await _as_str(comp.get_file(True))
                if resolved and resolved.startswith("http"):
                    remote_url = resolved
            except Exception:
                pass

        # 3. Plain attributes as the last resort.
        if not remote_url and not file_path:
            for attr_name in ("file", "url", "path", "src", "source"):
                try:
                    attr_value = getattr(comp, attr_name, None)
                except Exception:
                    continue
                if not (attr_value and isinstance(attr_value, str)):
                    continue
                if attr_value.startswith("http"):
                    remote_url = attr_value
                else:
                    file_path = attr_value
                break
    except Exception:
        pass

    return remote_url, file_path
def summarize_component_for_log(comp: Any) -> Dict[str, Any]:
    """Collect a component's loggable attributes into a dict.

    Reads ``file``, ``url``, ``path``, ``src``, ``source`` and ``name`` and
    keeps those whose value is not ``None``. An attribute whose access
    raises is left out.
    """

    def _read(attr_name):
        try:
            return getattr(comp, attr_name, None)
        except Exception:
            return None

    candidates = ("file", "url", "path", "src", "source", "name")
    return {
        attr_name: value
        for attr_name in candidates
        if (value := _read(attr_name)) is not None
    }
async def upload_local_with_retries(
    api: Any,
    local_path: str,
    preferred_name: Optional[str],
    folder_id: Optional[str],
) -> Optional[str]:
    """Upload a local file once and return its file ID as a string.

    The ID is read from ``id`` in the response, falling back to
    ``raw.createdFile.id``. Any exception is swallowed so the caller can
    decide how to handle the failure.

    Args:
        api: Object exposing an awaitable ``upload_file(path, name, folder_id)``.
        local_path: Path of the file to upload.
        preferred_name: Name forwarded to ``upload_file``.
        folder_id: Folder ID forwarded to ``upload_file``.

    Returns:
        The file ID, or ``None`` when the upload raises, the response is not
        a dict, or no ID is present.
    """
    try:
        response = await api.upload_file(local_path, preferred_name, folder_id)
        if not isinstance(response, dict):
            return None
        file_id = response.get("id")
        if not file_id:
            raw_section = response.get("raw") or {}
            file_id = raw_section.get("createdFile", {}).get("id")
        return str(file_id) if file_id else None
    except Exception:
        # Upload failed: report "no ID" and leave error handling to the caller.
        return None
@@ -15,12 +15,13 @@ class QQOfficialWebhook:
self.appid = config["appid"]
self.secret = config["secret"]
self.port = config.get("port", 6196)
self.is_sandbox = config.get("is_sandbox", False)
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
if isinstance(self.port, str):
self.port = int(self.port)
self.http: BotHttp = BotHttp(timeout=300)
self.http: BotHttp = BotHttp(timeout=300, is_sandbox=self.is_sandbox)
self.api: BotAPI = BotAPI(http=self.http)
self.token = Token(self.appid, self.secret)
@@ -499,10 +499,36 @@ class SatoriPlatformAdapter(Platform):
}
return None
except ET.ParseError as e:
logger.warning(f"XML解析失败,使用正则提取: {e}")
return await self._extract_quote_with_regex(content)
except Exception as e:
logger.error(f"提取<quote>标签时发生错误: {e}")
return None
async def _extract_quote_with_regex(self, content: str) -> Optional[dict]:
    """Extract the first ``<quote ...>...</quote>`` element with regexes.

    Args:
        content: Message content that may contain a ``<quote>`` element.

    Returns:
        ``None`` when no quote element with attributes is found. Otherwise a
        dict with ``quote`` (its ``id`` attribute, ``""`` when absent, and
        its inner ``content``) and ``content_without_quote`` (the input with
        the matched element text removed and surrounding whitespace
        stripped).
    """
    import re

    match = re.search(r"<quote\s+([^>]*)>(.*?)</quote>", content, re.DOTALL)
    if not match:
        return None

    attrs_str, inner_content = match.group(1), match.group(2)

    # The lookbehind stops attributes that merely end in "id" (``msgid``,
    # ``user-id``, ``xml:id``) from being taken for the quote's own ``id``.
    id_match = re.search(
        r'(?<![\w:.\-])id\s*=\s*["\']([^"\']*)["\']', attrs_str
    )
    quote_id = id_match.group(1) if id_match else ""

    content_without_quote = content.replace(match.group(0), "").strip()

    return {
        "quote": {"id": quote_id, "content": inner_content},
        "content_without_quote": content_without_quote,
    }
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
"""转换引用消息"""
try:
@@ -574,7 +600,7 @@ class SatoriPlatformAdapter(Platform):
root = ET.fromstring(processed_content)
await self._parse_xml_node(root, elements)
except ET.ParseError as e:
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
logger.warning(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
# 如果解析失败,将整个内容当作纯文本
if content.strip():
elements.append(Plain(text=content))
@@ -2,7 +2,18 @@ from typing import TYPE_CHECKING
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, At, File, Record, Video, Reply
from astrbot.api.message_components import (
Plain,
Image,
At,
File,
Record,
Video,
Reply,
Forward,
Node,
Nodes,
)
if TYPE_CHECKING:
from .satori_adapter import SatoriPlatformAdapter
@@ -48,55 +59,24 @@ class SatoriPlatformEvent(AstrMessageEvent):
content_parts = []
for component in message.chain:
if isinstance(component, Plain):
text = (
component.text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
content_parts.append(text)
component_content = await cls._convert_component_to_satori_static(
component
)
if component_content:
content_parts.append(component_content)
elif isinstance(component, At):
if component.qq:
content_parts.append(f'<at id="{component.qq}"/>')
elif component.name:
content_parts.append(f'<at name="{component.name}"/>')
# 特殊处理 Node 和 Nodes 组件
if isinstance(component, Node):
# 单个转发节点
node_content = await cls._convert_node_to_satori_static(component)
if node_content:
content_parts.append(node_content)
elif isinstance(component, Image):
try:
image_base64 = await component.convert_to_base64()
if image_base64:
content_parts.append(
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
)
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
elif isinstance(component, File):
content_parts.append(
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
)
elif isinstance(component, Record):
try:
record_base64 = await component.convert_to_base64()
if record_base64:
content_parts.append(
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
)
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
elif isinstance(component, Reply):
content_parts.append(f'<reply id="{component.id}"/>')
elif isinstance(component, Video):
try:
video_path_url = await component.convert_to_file_path()
if video_path_url:
content_parts.append(f'<video src="{video_path_url}"/>')
except Exception as e:
logger.error(f"视频文件转换失败: {e}")
elif isinstance(component, Nodes):
# 合并转发消息
node_content = await cls._convert_nodes_to_satori_static(component)
if node_content:
content_parts.append(node_content)
content = "".join(content_parts)
channel_id = session_id
@@ -138,55 +118,22 @@ class SatoriPlatformEvent(AstrMessageEvent):
content_parts = []
for component in message.chain:
if isinstance(component, Plain):
text = (
component.text.replace("&", "&amp;")
.replace("<", "&lt;")
.replace(">", "&gt;")
)
content_parts.append(text)
component_content = await self._convert_component_to_satori(component)
if component_content:
content_parts.append(component_content)
elif isinstance(component, At):
if component.qq:
content_parts.append(f'<at id="{component.qq}"/>')
elif component.name:
content_parts.append(f'<at name="{component.name}"/>')
# 特殊处理 Node 和 Nodes 组件
if isinstance(component, Node):
# 单个转发节点
node_content = await self._convert_node_to_satori(component)
if node_content:
content_parts.append(node_content)
elif isinstance(component, Image):
try:
image_base64 = await component.convert_to_base64()
if image_base64:
content_parts.append(
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
)
except Exception as e:
logger.error(f"图片转换为base64失败: {e}")
elif isinstance(component, File):
content_parts.append(
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
)
elif isinstance(component, Record):
try:
record_base64 = await component.convert_to_base64()
if record_base64:
content_parts.append(
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
)
except Exception as e:
logger.error(f"语音转换为base64失败: {e}")
elif isinstance(component, Reply):
content_parts.append(f'<reply id="{component.id}"/>')
elif isinstance(component, Video):
try:
video_path_url = await component.convert_to_file_path()
if video_path_url:
content_parts.append(f'<video src="{video_path_url}"/>')
except Exception as e:
logger.error(f"视频文件转换失败: {e}")
elif isinstance(component, Nodes):
# 合并转发消息
node_content = await self._convert_nodes_to_satori(component)
if node_content:
content_parts.append(node_content)
content = "".join(content_parts)
channel_id = self.session_id
@@ -250,3 +197,227 @@ class SatoriPlatformEvent(AstrMessageEvent):
logger.error(f"Satori 流式消息发送异常: {e}")
return await super().send_streaming(generator, use_fallback)
async def _convert_component_to_satori(self, component) -> str:
    """Convert one message component to its Satori markup.

    The conversion rules are the same as those of the classmethod variant,
    so this instance method forwards to
    ``_convert_component_to_satori_static`` instead of keeping a second
    copy of them.

    Args:
        component: The message component to convert.

    Returns:
        The string produced by ``_convert_component_to_satori_static``.
    """
    return await self._convert_component_to_satori_static(component)
async def _convert_node_to_satori(self, node: Node) -> str:
    """Convert a single forward node to Satori markup.

    Each entry of ``node.content`` is converted with
    ``_convert_component_to_satori`` and the results are concatenated; blank
    content is replaced by a placeholder. The node's ``uin`` and ``name``
    become the ``id`` / ``name`` attributes of an ``<author/>`` element and
    are escaped so quotes or angle brackets cannot break the markup.

    Returns:
        ``<message><author .../>content</message>``, or ``""`` when the
        conversion raises.
    """
    from html import escape

    try:
        content_parts = []
        if node.content:
            for content_component in node.content:
                component_content = await self._convert_component_to_satori(
                    content_component
                )
                if component_content:
                    content_parts.append(component_content)

        content = "".join(content_parts)
        if not content.strip():
            # Never emit an empty forward node.
            content = "[转发消息]"

        author_attrs = []
        if node.uin:
            author_attrs.append(f'id="{escape(str(node.uin), quote=True)}"')
        if node.name:
            author_attrs.append(f'name="{escape(str(node.name), quote=True)}"')
        author_attr_str = " ".join(author_attrs)

        return f"<message><author {author_attr_str}/>{content}</message>"
    except Exception as e:
        logger.error(f"转换转发节点失败: {e}")
        return ""
@classmethod
async def _convert_component_to_satori_static(cls, component) -> str:
    """Convert one message component to its Satori markup.

    Handles ``Plain``, ``At``, ``Image``, ``File``, ``Record``, ``Reply``,
    ``Video`` and ``Forward``. Plain text has ``&``, ``<`` and ``>``
    escaped; values interpolated into attributes are additionally
    quote-escaped so they cannot terminate the attribute or inject elements.

    Returns:
        The markup, or ``""`` for unhandled component types, for components
        that yield no data, and when the conversion raises.
    """
    from html import escape

    def _attr(value) -> str:
        """Escape a value for use inside a double-quoted attribute."""
        return escape(str(value), quote=True)

    try:
        if isinstance(component, Plain):
            # Same replacements as before: & first, then < and >.
            return escape(component.text, quote=False)

        elif isinstance(component, At):
            if component.qq:
                return f'<at id="{_attr(component.qq)}"/>'
            elif component.name:
                return f'<at name="{_attr(component.name)}"/>'

        elif isinstance(component, Image):
            try:
                image_base64 = await component.convert_to_base64()
                if image_base64:
                    return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
            except Exception as e:
                logger.error(f"图片转换为base64失败: {e}")

        elif isinstance(component, File):
            file_src = _attr(component.file)
            file_name = _attr(component.name or "文件")
            return f'<file src="{file_src}" name="{file_name}"/>'

        elif isinstance(component, Record):
            try:
                record_base64 = await component.convert_to_base64()
                if record_base64:
                    return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
            except Exception as e:
                logger.error(f"语音转换为base64失败: {e}")

        elif isinstance(component, Reply):
            return f'<reply id="{_attr(component.id)}"/>'

        elif isinstance(component, Video):
            try:
                video_path_url = await component.convert_to_file_path()
                if video_path_url:
                    return f'<video src="{_attr(video_path_url)}"/>'
            except Exception as e:
                logger.error(f"视频文件转换失败: {e}")

        elif isinstance(component, Forward):
            return f'<message id="{_attr(component.id)}" forward/>'

        # Any other component type produces no markup.
        return ""

    except Exception as e:
        logger.error(f"转换消息组件失败: {e}")
        return ""
@classmethod
async def _convert_node_to_satori_static(cls, node: Node) -> str:
    """Convert a single forward node to Satori markup (classmethod variant).

    Each entry of ``node.content`` is converted with
    ``_convert_component_to_satori_static`` and the results are
    concatenated; blank content is replaced by a placeholder. The node's
    ``uin`` and ``name`` become the ``id`` / ``name`` attributes of an
    ``<author/>`` element and are escaped so quotes or angle brackets cannot
    break the markup.

    Returns:
        ``<message><author .../>content</message>``, or ``""`` when the
        conversion raises.
    """
    from html import escape

    try:
        content_parts = []
        if node.content:
            for content_component in node.content:
                component_content = await cls._convert_component_to_satori_static(
                    content_component
                )
                if component_content:
                    content_parts.append(component_content)

        content = "".join(content_parts)
        if not content.strip():
            # Never emit an empty forward node.
            content = "[转发消息]"

        author_attrs = []
        if node.uin:
            author_attrs.append(f'id="{escape(str(node.uin), quote=True)}"')
        if node.name:
            author_attrs.append(f'name="{escape(str(node.name), quote=True)}"')
        author_attr_str = " ".join(author_attrs)

        return f"<message><author {author_attr_str}/>{content}</message>"
    except Exception as e:
        logger.error(f"转换转发节点失败: {e}")
        return ""
async def _convert_nodes_to_satori(self, nodes: Nodes) -> str:
    """Convert several forward nodes to one merged-forward Satori message.

    Every node in ``nodes.nodes`` is converted with
    ``_convert_node_to_satori``; empty results are dropped.

    Returns:
        ``<message forward>...</message>`` wrapping the converted nodes, or
        ``""`` when nothing was converted or the conversion raises.
    """
    try:
        converted = [
            await self._convert_node_to_satori(node) for node in nodes.nodes
        ]
        merged = "".join(part for part in converted if part)
        return f"<message forward>{merged}</message>" if merged else ""
    except Exception as e:
        logger.error(f"转换合并转发消息失败: {e}")
        return ""
@classmethod
async def _convert_nodes_to_satori_static(cls, nodes: Nodes) -> str:
    """Convert several forward nodes to one merged-forward Satori message.

    Classmethod variant: every node in ``nodes.nodes`` is converted with
    ``_convert_node_to_satori_static``; empty results are dropped.

    Returns:
        ``<message forward>...</message>`` wrapping the converted nodes, or
        ``""`` when nothing was converted or the conversion raises.
    """
    try:
        converted = [
            await cls._convert_node_to_satori_static(node) for node in nodes.nodes
        ]
        merged = "".join(part for part in converted if part)
        return f"<message forward>{merged}</message>" if merged else ""
    except Exception as e:
        logger.error(f"转换合并转发消息失败: {e}")
        return ""
+67
View File
@@ -1,4 +1,5 @@
import abc
import asyncio
from typing import List
from typing import AsyncGenerator
from astrbot.core.agent.tool import ToolSet
@@ -203,6 +204,72 @@ class EmbeddingProvider(AbstractProvider):
"""获取向量的维度"""
...
async def get_embeddings_batch(
    self,
    texts: list[str],
    batch_size: int = 16,
    tasks_limit: int = 3,
    max_retries: int = 3,
    progress_callback=None,
) -> list[list[float]]:
    """Embed ``texts`` in concurrent batches and return the vectors in input order.

    The texts are split into batches of ``batch_size`` and each batch is
    passed to ``get_embeddings``. At most ``tasks_limit`` batches run at the
    same time. A failing batch is retried with exponential backoff
    (``2**attempt`` seconds). Results are stored per batch index and
    flattened at the end, so the returned vectors line up with ``texts``
    regardless of which batch finishes first.

    Args:
        texts: Texts to embed.
        batch_size: Number of texts per batch.
        tasks_limit: Maximum number of batches processed concurrently.
        max_retries: Attempts per batch; values below 1 are treated as 1.
        progress_callback: Optional async callable awaited as
            ``progress_callback(completed, total)`` after each finished batch.

    Returns:
        One vector per text, in the order of ``texts``.

    Raises:
        Exception: If any batch still fails after its last attempt (or its
            progress callback raises); the message aggregates all failures.
    """
    semaphore = asyncio.Semaphore(tasks_limit)
    total_count = len(texts)
    completed_count = 0
    attempts = max(1, max_retries)

    batches = [
        texts[start : start + batch_size]
        for start in range(0, total_count, batch_size)
    ]
    # One slot per batch keeps the output aligned with the input even though
    # batches complete in arbitrary order.
    embeddings_by_batch: list[list[list[float]]] = [[] for _ in batches]

    async def embed_with_retry(
        batch_idx: int, batch_texts: list[str]
    ) -> list[list[float]]:
        """Call ``get_embeddings`` for one batch, retrying on failure."""
        for attempt in range(attempts):
            try:
                return await self.get_embeddings(batch_texts)
            except Exception as e:
                if attempt == attempts - 1:
                    raise Exception(
                        f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {str(e)}"
                    ) from e
                # Exponential backoff before the next attempt.
                await asyncio.sleep(2**attempt)
        raise AssertionError("unreachable: the retry loop returns or raises")

    async def process_batch(batch_idx: int, batch_texts: list[str]) -> None:
        nonlocal completed_count
        async with semaphore:
            embeddings_by_batch[batch_idx] = list(
                await embed_with_retry(batch_idx, batch_texts)
            )
            completed_count += len(batch_texts)
            # Reported outside the retry loop so a failing callback cannot
            # cause the batch to be embedded and stored a second time.
            if progress_callback:
                await progress_callback(completed_count, total_count)

    # Collect every outcome, including failures, before deciding what to do.
    results = await asyncio.gather(
        *(process_batch(idx, batch) for idx, batch in enumerate(batches)),
        return_exceptions=True,
    )

    errors = [r for r in results if isinstance(r, Exception)]
    if errors:
        error_msg = (
            f"{len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}"
        )
        raise Exception(error_msg)

    return [vector for batch in embeddings_by_batch for vector in batch]
class RerankProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -10,7 +10,7 @@ from anthropic.types import Message
from astrbot.core.utils.io import download_image_by_url
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.func_tool_manager import ToolSet
from ..register import register_provider_adapter
from astrbot.core.provider.entities import LLMResponse
from typing import AsyncGenerator
@@ -104,7 +104,7 @@ class ProviderAnthropic(Provider):
return system_prompt, new_messages
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
@@ -135,7 +135,7 @@ class ProviderAnthropic(Provider):
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall
self, payloads: dict, tools: ToolSet | None
) -> AsyncGenerator[LLMResponse, None]:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
@@ -326,7 +326,7 @@ class ProviderAnthropic(Provider):
async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response
async def assemble_context(self, text: str, image_urls: List[str] = None):
async def assemble_context(self, text: str, image_urls: List[str] | None = None):
"""组装上下文,支持文本和图片"""
if not image_urls:
return {"role": "user", "content": text}
@@ -1,15 +1,14 @@
import re
import asyncio
import functools
from typing import List
from .. import Provider, Personality
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.message.message_event_result import MessageChain
from .openai_source import ProviderOpenAIOfficial
from astrbot.core import logger, sp
from dashscope import Application
from dashscope.app.application_response import ApplicationResponse
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
@@ -62,11 +61,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
session_id=None,
image_urls=[],
func_tool=None,
contexts=None,
system_prompt=None,
model=None,
**kwargs,
) -> LLMResponse:
@@ -122,6 +121,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
assert isinstance(response, ApplicationResponse)
logger.debug(f"dashscope resp: {response}")
if response.status_code != 200:
@@ -135,12 +136,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
),
)
output_text = response.output.get("text", "")
output_text = response.output.get("text", "") or ""
# RAG 引用脚标格式化
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
if self.output_reference and response.output.get("doc_references", None):
ref_str = ""
for ref in response.output.get("doc_references", []):
for ref in response.output.get("doc_references", []) or []:
ref_title = (
ref.get("title", "")
if ref.get("title")
+6 -8
View File
@@ -1,9 +1,7 @@
import astrbot.core.message.components as Comp
import os
from typing import List
from .. import Provider
from ..entities import LLMResponse
from ..func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url, download_file
@@ -55,11 +53,11 @@ class ProviderDify(Provider):
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
@@ -223,7 +221,7 @@ class ProviderDify(Provider):
# Chat
return MessageChain(chain=[Comp.Plain(chunk)])
async def parse_file(item: dict) -> Comp:
async def parse_file(item: dict):
match item["type"]:
case "image":
return Comp.Image(file=item["url"], url=item["url"])
+18 -15
View File
@@ -16,7 +16,7 @@ from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.func_tool_manager import ToolSet
from typing import List, AsyncGenerator
from ..register import register_provider_adapter
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
@@ -49,7 +49,7 @@ class ProviderOpenAIOfficial(Provider):
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
base_url=provider_config.get("api_base", None),
base_url=provider_config.get("api_base", ""),
timeout=self.timeout,
)
else:
@@ -79,7 +79,7 @@ class ProviderOpenAIOfficial(Provider):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
@@ -126,7 +126,7 @@ class ProviderOpenAIOfficial(Provider):
return llm_response
async def _query_stream(
self, payloads: dict, tools: FuncCall
self, payloads: dict, tools: ToolSet
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API,逐步返回结果"""
if tools:
@@ -183,9 +183,7 @@ class ProviderOpenAIOfficial(Provider):
yield llm_response
async def parse_openai_completion(
self, completion: ChatCompletion, tools: FuncCall
):
async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
"""解析 OpenAI 的 ChatCompletion 响应"""
llm_response = LLMResponse("assistant")
@@ -208,7 +206,10 @@ class ProviderOpenAIOfficial(Provider):
# workaround for #1359
tool_call = json.loads(tool_call)
for tool in tools.func_list:
if tool.name == tool_call.function.name:
if (
tool_call.type == "function"
and tool.name == tool_call.function.name
):
# workaround for #1454
if isinstance(tool_call.function.arguments, str):
args = json.loads(tool_call.function.arguments)
@@ -277,7 +278,7 @@ class ProviderOpenAIOfficial(Provider):
e: Exception,
payloads: dict,
context_query: list,
func_tool: FuncCall,
func_tool: ToolSet,
chosen_key: str,
available_api_keys: List[str],
retry_cnt: int,
@@ -420,7 +421,7 @@ class ProviderOpenAIOfficial(Provider):
if success:
break
if retry_cnt == max_retries - 1:
if retry_cnt == max_retries - 1 or llm_response is None:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
if last_exception is None:
raise Exception("未知错误")
@@ -430,10 +431,10 @@ class ProviderOpenAIOfficial(Provider):
async def text_chat_stream(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = [],
func_tool: FuncCall = None,
contexts=[],
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
@@ -526,7 +527,9 @@ class ProviderOpenAIOfficial(Provider):
def set_key(self, key):
self.client.api_key = key
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
async def assemble_context(
self, text: str, image_urls: List[str] | None = None
) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
if image_urls:
user_content = {
@@ -30,7 +30,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
timeout=timeout,
)
self.set_model(provider_config.get("model", None))
self.set_model(provider_config.get("model", ""))
async def get_audio(self, text: str) -> str:
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
+1 -1
View File
@@ -19,7 +19,7 @@ from astrbot.core.platform import Platform
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
+81
View File
@@ -0,0 +1,81 @@
from astrbot.core.utils.shared_preferences import SharedPreferences
class UmopConfigRouter:
    """UMOP config router: maps UMOP patterns to configuration-file IDs.

    A pattern has three colon-separated parts,
    ``[platform_id]:[message_type]:[session_id]``; an empty part or ``*``
    matches any value in that position.
    """

    def __init__(self, sp: SharedPreferences):
        self.sp = sp
        # Mapping from UMOP pattern to configuration-file ID.
        self.umop_to_conf_id: dict[str, str] = {}
        self._load_routing_table()

    def _load_routing_table(self):
        """Load the routing table from SharedPreferences into memory."""
        self.umop_to_conf_id = self.sp.get(
            "umop_config_routing", {}, scope="global", scope_id="global"
        )

    def _is_umo_match(self, p1: str, p2: str) -> bool:
        """Return True when UMO ``p2`` is logically covered by pattern ``p1``."""
        pattern_parts = p1.split(":")
        target_parts = p2.split(":")

        # Both sides must have exactly three parts; anything else is malformed.
        if not (len(pattern_parts) == 3 and len(target_parts) == 3):
            return False

        for wanted, actual in zip(pattern_parts, target_parts):
            if wanted not in ("", "*") and wanted != actual:
                return False
        return True

    def get_conf_id_for_umop(self, umo: str) -> str | None:
        """Look up the configuration-file ID for a UMO.

        Args:
            umo (str): UMO string.

        Returns:
            str | None: ID of the first pattern (in table order) that matches,
            or None when no pattern matches.
        """
        matches = (
            conf_id
            for pattern, conf_id in self.umop_to_conf_id.items()
            if self._is_umo_match(pattern, umo)
        )
        return next(matches, None)

    async def update_routing_data(self, new_routing: dict[str, str]):
        """Replace the whole routing table and persist it.

        Args:
            new_routing (dict[str, str]): New UMOP -> config-ID mapping. Each
                key is ``[platform_id]:[message_type]:[session_id]``; ``"::"``
                stands for everything and ``"[platform_id]::"`` for every
                message type and session on that platform.

        Raises:
            ValueError: If any key of ``new_routing`` is malformed.
        """
        # Validate every key first so a bad table never replaces the current one.
        for key in new_routing:
            well_formed = isinstance(key, str) and key.count(":") == 2
            if not well_formed:
                raise ValueError(
                    "umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all"
                )

        self.umop_to_conf_id = new_routing
        await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)

    async def update_route(self, umo: str, conf_id: str):
        """Insert or overwrite a single route and persist the table.

        Args:
            umo (str): UMO string.
            conf_id (str): Configuration-file ID.

        Raises:
            ValueError: If ``umo`` is malformed.
        """
        well_formed = isinstance(umo, str) and umo.count(":") == 2
        if not well_formed:
            raise ValueError(
                "umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all"
            )

        self.umop_to_conf_id[umo] = conf_id
        await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
+79 -4
View File
@@ -6,6 +6,7 @@ from .route import Route, Response, RouteContext
from astrbot.core.provider.entities import ProviderType
from quart import request
from astrbot.core.config.default import (
DEFAULT_CONFIG,
CONFIG_METADATA_2,
DEFAULT_VALUE_MAP,
CONFIG_METADATA_3,
@@ -152,13 +153,19 @@ class ConfigRoute(Route):
self.config: AstrBotConfig = core_lifecycle.astrbot_config
self._logo_token_cache = {} # 缓存logo token,避免重复注册
self.acm = core_lifecycle.astrbot_config_mgr
self.ucr = core_lifecycle.umop_config_router
self.routes = {
"/config/abconf/new": ("POST", self.create_abconf),
"/config/abconf": ("GET", self.get_abconf),
"/config/abconfs": ("GET", self.get_abconf_list),
"/config/abconf/delete": ("POST", self.delete_abconf),
"/config/abconf/update": ("POST", self.update_abconf),
"/config/umo_abconf_routes": ("GET", self.get_uc_table),
"/config/umo_abconf_route/update_all": ("POST", self.update_ucr_all),
"/config/umo_abconf_route/update": ("POST", self.update_ucr),
"/config/umo_abconf_route/delete": ("POST", self.delete_ucr),
"/config/get": ("GET", self.get_configs),
"/config/default": ("GET", self.get_default_config),
"/config/astrbot/update": ("POST", self.post_astrbot_configs),
"/config/plugin/update": ("POST", self.post_plugin_configs),
"/config/platform/new": ("POST", self.post_new_platform),
@@ -174,6 +181,75 @@ class ConfigRoute(Route):
}
self.register_routes()
async def get_uc_table(self):
    """Return the current UMOP -> config-ID routing table."""
    payload = {"routing": self.ucr.umop_to_conf_id}
    return Response().ok(payload).__dict__
async def update_ucr_all(self):
    """Replace the entire UMOP routing table with the posted ``routing`` dict."""
    body = await request.json
    if not body:
        return Response().error("缺少配置数据").__dict__

    routing = body.get("routing", None)
    # Reject a missing, empty or non-dict table.
    if not (routing and isinstance(routing, dict)):
        return Response().error("缺少或错误的路由表数据").__dict__

    try:
        await self.ucr.update_routing_data(routing)
        result = Response().ok(message="更新成功").__dict__
    except Exception as exc:
        logger.error(traceback.format_exc())
        result = Response().error(f"更新路由表失败: {str(exc)}").__dict__
    return result
async def update_ucr(self):
    """Insert or overwrite one UMOP route from the posted ``umo`` / ``conf_id``."""
    body = await request.json
    if not body:
        return Response().error("缺少配置数据").__dict__

    umo = body.get("umo", None)
    conf_id = body.get("conf_id", None)
    # Both fields are mandatory and must be non-empty.
    if not (umo and conf_id):
        return Response().error("缺少 UMO 或配置文件 ID").__dict__

    try:
        await self.ucr.update_route(umo, conf_id)
        result = Response().ok(message="更新成功").__dict__
    except Exception as exc:
        logger.error(traceback.format_exc())
        result = Response().error(f"更新路由表失败: {str(exc)}").__dict__
    return result
async def delete_ucr(self):
    """Delete one entry from the UMOP routing table.

    Expects a JSON body with a non-empty ``umo`` key. Deleting a key that is
    not in the table is not an error; the table is re-persisted unchanged.

    Returns:
        dict: Serialized ``Response`` (success or error).
    """
    post_data = await request.json
    if not post_data:
        return Response().error("缺少配置数据").__dict__

    umo = post_data.get("umo", None)
    if not umo:
        return Response().error("缺少 UMO").__dict__

    try:
        # Build the new table as a copy instead of mutating the router's live
        # dict: update_routing_data validates before assigning, so if it
        # rejects the table the in-memory routes stay exactly as they were.
        remaining = {
            pattern: conf_id
            for pattern, conf_id in self.ucr.umop_to_conf_id.items()
            if pattern != umo
        }
        await self.ucr.update_routing_data(remaining)
        return Response().ok(message="删除成功").__dict__
    except Exception as e:
        logger.error(traceback.format_exc())
        return Response().error(f"删除路由表项失败: {str(e)}").__dict__
async def get_default_config(self):
    """Return the default configuration together with its metadata."""
    payload = {"config": DEFAULT_CONFIG, "metadata": CONFIG_METADATA_3}
    return Response().ok(payload).__dict__
async def get_abconf_list(self):
"""获取所有 AstrBot 配置文件的列表"""
abconf_list = self.acm.get_conf_list()
@@ -184,11 +260,11 @@ class ConfigRoute(Route):
post_data = await request.json
if not post_data:
return Response().error("缺少配置数据").__dict__
umo_parts = post_data["umo_parts"]
name = post_data.get("name", None)
config = post_data.get("config", DEFAULT_CONFIG)
try:
conf_id = self.acm.create_conf(umo_parts=umo_parts, name=name)
conf_id = self.acm.create_conf(name=name, config=config)
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
@@ -250,10 +326,9 @@ class ConfigRoute(Route):
return Response().error("缺少配置文件 ID").__dict__
name = post_data.get("name")
umo_parts = post_data.get("umo_parts")
try:
success = self.acm.update_conf_info(conf_id, name=name, umo_parts=umo_parts)
success = self.acm.update_conf_info(conf_id, name=name)
if success:
return Response().ok(message="更新成功").__dict__
else:
File diff suppressed because it is too large Load Diff
+10 -2
View File
@@ -52,10 +52,17 @@ class UpdateRoute(Route):
try:
dv = await get_dashboard_version()
# WebUI 版本独立于核心版本:不再用 dv 与 v{VERSION} 比较,避免误报
if type_ == "dashboard":
return (
Response()
.ok({"has_new_version": dv != f"v{VERSION}", "current_version": dv})
.ok(
{
"has_new_version": False,
"current_version": dv,
"installed": bool(dv),
}
)
.__dict__
)
else:
@@ -67,7 +74,8 @@ class UpdateRoute(Route):
"version": f"v{VERSION}",
"has_new_version": ret is not None,
"dashboard_version": dv,
"dashboard_has_new_version": dv and dv != f"v{VERSION}",
# dv正常获取则不会提示需要更新
"dashboard_has_new_version": not bool(dv),
},
).__dict__
except Exception as e:
+161
View File
@@ -0,0 +1,161 @@
import base64
import os
import traceback
from io import BytesIO
from astrbot.api import logger
from astrbot.core.knowledge_base.kb_helper import KBHelper
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
async def generate_tsne_visualization(
    query: str, kb_names: list[str], kb_manager: KnowledgeBaseManager
) -> str | None:
    """Render a t-SNE plot of a knowledge base's vectors together with a query.

    The first name in ``kb_names`` that resolves to a knowledge base is used.
    All vectors of its FAISS index plus the embedding of ``query`` are reduced
    to two dimensions and drawn as a scatter plot.

    Args:
        query: Query text to embed and highlight in the plot.
        kb_names: Candidate knowledge-base names, tried in order.
        kb_manager: Knowledge-base manager used to resolve the names.

    Returns:
        The PNG image as a base64-encoded string, or None when no knowledge
        base is found, the index file is missing or empty, or any error occurs
        while building the plot (errors are logged).

    Raises:
        Exception: If faiss, numpy, matplotlib or scikit-learn cannot be imported.
    """
    try:
        import faiss
        import numpy as np
        import matplotlib

        matplotlib.use("Agg")  # non-interactive backend
        import matplotlib.pyplot as plt
        from sklearn.manifold import TSNE
    except ImportError as e:
        # f-string so the underlying import error actually appears in the message.
        raise Exception(
            f"缺少必要的库以生成 t-SNE 可视化。请安装 matplotlib 和 scikit-learn: {e}"
        ) from e

    try:
        # Use the first knowledge base that can be resolved.
        kb_helper: KBHelper | None = None
        for kb_name in kb_names:
            kb_helper = await kb_manager.get_kb_by_name(kb_name)
            if kb_helper:
                break

        if not kb_helper:
            logger.warning("未找到知识库")
            return None

        kb = kb_helper.kb
        # NOTE(review): path is relative to the process working directory —
        # confirm it matches where the knowledge base writes its index.
        index_path = f"data/knowledge_base/{kb.kb_id}/index.faiss"

        if not os.path.exists(index_path):
            logger.warning(f"FAISS 索引不存在: {index_path}")
            return None

        index = faiss.read_index(index_path)

        if index.ntotal == 0:
            logger.warning("索引为空")
            return None

        # Extract every stored vector. For an IndexIDMap the vectors live in
        # the wrapped index, so reconstruct from that one.
        logger.info(f"提取 {index.ntotal} 个向量用于可视化...")
        if isinstance(index, faiss.IndexIDMap):
            source_index = faiss.downcast_index(index.index)
        else:
            source_index = index

        if hasattr(source_index, "reconstruct_n"):
            vectors = source_index.reconstruct_n(0, index.ntotal)
        else:
            vectors = np.zeros((index.ntotal, index.d), dtype=np.float32)
            for i in range(index.ntotal):
                source_index.reconstruct(i, vectors[i])

        # Embed the query with the knowledge base's embedding provider.
        vec_db: FaissVecDB = kb_helper.vec_db  # type: ignore
        embedding_provider = vec_db.embedding_provider
        query_embedding = await embedding_provider.get_embedding(query)
        query_vector = np.array([query_embedding], dtype=np.float32)

        # The query is appended as the last row so it can be split off again.
        all_vectors = np.vstack([vectors, query_vector])

        logger.info("开始 t-SNE 降维...")
        # Perplexity is capped so it stays below the number of samples.
        perplexity = min(30, all_vectors.shape[0] - 1)
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        vectors_2d = tsne.fit_transform(all_vectors)

        kb_vectors_2d = vectors_2d[:-1]
        query_vector_2d = vectors_2d[-1]

        logger.info("生成可视化图表...")
        fig = plt.figure(figsize=(14, 10))
        try:
            # Knowledge-base vectors, coloured by their position in the index.
            scatter = plt.scatter(
                kb_vectors_2d[:, 0],
                kb_vectors_2d[:, 1],
                alpha=0.5,
                s=40,
                c=range(len(kb_vectors_2d)),
                cmap="viridis",
                label="Knowledge Base Vectors",
            )

            # Query vector (red X).
            plt.scatter(
                query_vector_2d[0],
                query_vector_2d[1],
                c="red",
                s=300,
                marker="X",
                edgecolors="black",
                linewidths=2,
                label="Query",
                zorder=5,
            )

            # Text label pointing at the query marker.
            plt.annotate(
                "Query",
                (query_vector_2d[0], query_vector_2d[1]),
                xytext=(10, 10),
                textcoords="offset points",
                fontsize=10,
                bbox={"boxstyle": "round,pad=0.5", "fc": "yellow", "alpha": 0.7},
                arrowprops={"arrowstyle": "->", "connectionstyle": "arc3,rad=0"},
            )

            plt.colorbar(scatter, label="Vector Index")
            plt.title(
                f"t-SNE Visualization: Query in Knowledge Base\n"
                f"({index.ntotal} vectors, {index.d} dimensions, KB: {kb.kb_name})",
                fontsize=14,
                pad=20,
            )
            plt.xlabel("t-SNE Dimension 1", fontsize=12)
            plt.ylabel("t-SNE Dimension 2", fontsize=12)
            plt.grid(True, alpha=0.3)
            plt.legend(fontsize=10, loc="upper right")

            # Serialize the figure to PNG in memory.
            buffer = BytesIO()
            plt.savefig(buffer, format="png", dpi=150, bbox_inches="tight")
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

        img_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return img_base64

    except Exception as e:
        logger.error(f"生成 t-SNE 可视化时出错: {e}")
        logger.error(traceback.format_exc())
        return None
+1
View File
@@ -44,6 +44,7 @@
"@mdi/font": "7.2.96",
"@rushstack/eslint-patch": "1.3.3",
"@types/chance": "1.1.3",
"@types/markdown-it": "^14.1.2",
"@types/node": "^20.5.7",
"@vitejs/plugin-vue": "4.3.3",
"@vue/eslint-config-prettier": "8.0.0",
@@ -0,0 +1,105 @@
<template>
<div :class="$vuetify.display.mobile ? '' : 'd-flex'">
<v-tabs v-model="tab" :direction="$vuetify.display.mobile ? 'horizontal' : 'vertical'"
:align-tabs="$vuetify.display.mobile ? 'left' : 'start'" color="deep-purple-accent-4" class="config-tabs">
<v-tab v-for="(val, key, index) in metadata" :key="index" :value="index"
style="font-weight: 1000; font-size: 15px">
{{ metadata[key]['name'] }}
</v-tab>
</v-tabs>
<v-tabs-window v-model="tab" class="config-tabs-window" :style="readonly ? 'pointer-events: none; opacity: 0.6;' : ''">
<v-tabs-window-item v-for="(val, key, index) in metadata" v-show="index == tab" :key="index">
<v-container fluid>
<div v-for="(val2, key2, index2) in metadata[key]['metadata']" :key="key2">
<!-- Support both traditional and JSON selector metadata -->
<AstrBotConfigV4 :metadata="{ [key2]: metadata[key]['metadata'][key2] }" :iterable="config_data"
:metadataKey="key2">
</AstrBotConfigV4>
</div>
</v-container>
</v-tabs-window-item>
<div style="margin-left: 16px; padding-bottom: 16px">
<small>{{ tm('help.helpPrefix') }}
<a href="https://astrbot.app/" target="_blank">{{ tm('help.documentation') }}</a>
{{ tm('help.helpMiddle') }}
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=EYGsuUTfe00_iOu9JTXS7_TEpMkXOvwv&jump_from=webapi&authKey=uUEMKCROfsseS+8IzqPjzV3y1tzy4AkykwTib2jNkOFdzezF9s9XknqnIaf3CDft"
target="_blank">{{ tm('help.support') }}</a>{{ tm('help.helpSuffix') }}
</small>
</div>
</v-tabs-window>
</div>
</template>
<script>
import AstrBotConfigV4 from '@/components/shared/AstrBotConfigV4.vue';
import { useModuleI18n } from '@/i18n/composables';
// Wrapper component: renders one tab per top-level metadata section and, inside
// each tab, one AstrBotConfigV4 editor per entry of that section's `metadata`.
export default {
name: 'AstrBotCoreConfigWrapper',
components: {
AstrBotConfigV4
},
props: {
// Config schema keyed by section; each section supplies `name` (tab label)
// and `metadata` (the entries rendered in that tab).
metadata: {
type: Object,
required: true,
default: () => ({})
},
// Config values; handed to every AstrBotConfigV4 as its `iterable` prop.
config_data: {
type: Object,
required: true,
default: () => ({})
},
// When true, the tab content is dimmed and pointer events are disabled.
readonly: {
type: Boolean,
default: false
}
},
setup() {
// `tm` resolves the help-text strings used in the template.
const { tm } = useModuleI18n('features/config');
return {
tm
};
},
data() {
return {
tab: 0, // index of the active tab (v-model of v-tabs and v-tabs-window)
}
},
methods: {
// no methods defined
}
}
</script>
<style>
@media (min-width: 768px) {
.config-tabs {
display: flex;
margin: 16px 16px 0 0;
}
.config-tabs-window {
flex: 1;
}
.config-tabs .v-tab {
justify-content: flex-start !important;
text-align: left;
min-height: 48px;
}
}
@media (max-width: 767px) {
.config-tabs {
width: 100%;
}
.config-tabs-window {
margin-top: 16px;
}
}
</style>
File diff suppressed because it is too large Load Diff
@@ -135,7 +135,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
<!-- Regular Property -->
<template v-else>
<v-row v-if="!metadata[metadataKey].items[key]?.invisible && shouldShowItem(metadata[metadataKey].items[key], key)" class="config-row">
<v-col cols="12" sm="6" class="property-info">
<v-col cols="12" sm="7" class="property-info">
<v-list-item density="compact">
<v-list-item-title class="property-name">
<span v-if="metadata[metadataKey].items[key]?.description">
@@ -153,16 +153,6 @@ function hasVisibleItemsAfter(items, currentIndex) {
</v-list-item>
</v-col>
<v-col cols="12" sm="1" class="d-flex align-center type-indicator">
<v-chip v-if="!metadata[metadataKey].items[key]?.invisible"
color="primary"
label
size="x-small"
variant="flat">
{{ metadata[metadataKey].items[key]?.type || 'string' }}
</v-chip>
</v-col>
<v-col cols="12" sm="5" class="config-input">
<div v-if="metadata[metadataKey].items[key]" class="w-100">
<!-- Special handling for specific metadata types -->
@@ -335,7 +325,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
<!-- Simple Value Configuration -->
<div v-else class="simple-config">
<v-row class="config-row">
<v-col cols="12" sm="6" class="property-info">
<v-col cols="12" sm="7" class="property-info">
<v-list-item density="compact">
<v-list-item-title class="property-name">
{{ metadata[metadataKey]?.description }}
@@ -349,16 +339,6 @@ function hasVisibleItemsAfter(items, currentIndex) {
</v-list-item>
</v-col>
<v-col cols="12" sm="1" class="d-flex align-center type-indicator">
<v-chip v-if="!metadata[metadataKey]?.invisible"
color="primary"
label
size="x-small"
variant="flat">
{{ metadata[metadataKey]?.type }}
</v-chip>
</v-col>
<v-col cols="12" sm="5" class="config-input">
<div class="w-100">
<!-- Select input -->
@@ -548,8 +528,8 @@ function hasVisibleItemsAfter(items, currentIndex) {
}
.config-divider {
border-color: rgba(0, 0, 0, 0.1);
margin: 4px 0;
border-color: rgba(0, 0, 0, 0.05);
margin: 0px 16px;
}
.editor-container {
@@ -120,7 +120,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
<v-card style="margin-bottom: 16px; padding-bottom: 8px; background-color: rgb(var(--v-theme-background));" rounded="md" variant="outlined">
<v-card-text class="config-section" v-if="metadata[metadataKey]?.type === 'object'">
<v-card-text class="config-section" v-if="metadata[metadataKey]?.type === 'object'" style="padding-bottom: 8px;">
<v-list-item-title class="config-title">
{{ metadata[metadataKey]?.description }}
</v-list-item-title>
@@ -365,7 +365,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
.config-row {
margin: 0;
align-items: center;
padding: 10px 8px;
padding: 8px 8px;
border-radius: 4px;
}
+115 -115
View File
@@ -1,6 +1,6 @@
<script setup lang="ts">
import { ref, computed, inject } from 'vue';
import {useCustomizerStore} from "@/stores/customizer";
import { useCustomizerStore } from "@/stores/customizer";
import { useModuleI18n } from '@/i18n/composables';
const props = defineProps({
@@ -84,130 +84,130 @@ const viewReadme = () => {
</script>
<template>
<v-card class="mx-auto d-flex flex-column" :elevation="highlight ? 0 : 1"
:style="{ height: $vuetify.display.xs ? '250px' : '220px',
backgroundColor: useCustomizerStore().uiTheme==='PurpleTheme' ? marketMode ? '#f8f0dd' : '#ffffff' : '#282833',
color: useCustomizerStore().uiTheme==='PurpleTheme' ? '#000000dd' : '#ffffff'}">
<v-card-text style="padding: 16px; padding-bottom: 0px; display: flex; justify-content: space-between;">
<v-card class="mx-auto d-flex flex-column" elevation="2" :style="{
position: 'relative',
backgroundColor: useCustomizerStore().uiTheme === 'PurpleTheme' ? marketMode ? '#f8f0dd' : '#ffffff' : '#282833',
color: useCustomizerStore().uiTheme === 'PurpleTheme' ? '#000000dd' : '#ffffff'
}">
<v-card-text style="padding: 16px; padding-bottom: 0px; display: flex; gap: 16px; width: 100%;">
<div class="flex-grow-1">
<div>{{ extension.author }} /</div>
<div v-if="extension?.icon">
<v-avatar size="65">
<v-img :src="extension.icon"
:alt="extension.name" cover></v-img>
</v-avatar>
</div>
<p class="text-h4 font-weight-black" :class="{ 'text-h4': $vuetify.display.xs }">
{{ extension.name }}
<v-tooltip location="top" v-if="extension?.has_update && !marketMode">
<template v-slot:activator="{ props: tooltipProps }">
<v-icon v-bind="tooltipProps" color="warning" class="ml-2" icon="mdi-update" size="small"></v-icon>
<div style="width: 100%;">
<!-- Top-right three-dot menu -->
<div style="position: absolute; right: 8px; top: 8px; z-index: 5;">
<v-menu offset-y>
<template v-slot:activator="{ props: menuProps }">
<v-btn v-bind="menuProps" icon variant="text" aria-label="more">
<v-icon icon="mdi-dots-vertical"></v-icon>
</v-btn>
</template>
<span>{{ tm("card.status.hasUpdate") }}: {{ extension.online_version }}</span>
</v-tooltip>
<v-tooltip location="top" v-if="!extension.activated && !marketMode">
<template v-slot:activator="{ props: tooltipProps }">
<v-icon v-bind="tooltipProps" color="error" class="ml-2" icon="mdi-cancel" size="small"></v-icon>
</template>
<span>{{ tm("card.status.disabled") }}</span>
</v-tooltip>
</p>
<div class="mt-1 d-flex flex-wrap">
<v-chip color="primary" label size="small">
<v-icon icon="mdi-source-branch" start></v-icon>
{{ extension.version }}
</v-chip>
<v-chip v-if="extension?.has_update " color="warning" label size="small" class="ml-2">
<v-icon icon="mdi-arrow-up-bold" start></v-icon>
{{ extension.online_version }}
</v-chip>
<v-chip color="primary" label size="small" class="ml-2" v-if="extension.handlers?.length">
<v-icon icon="mdi-cogs" start></v-icon>
{{ extension.handlers?.length }}{{ tm("card.status.handlersCount") }}
</v-chip>
<v-chip v-for="tag in extension.tags" :key="tag" :color="tag === 'danger' ? 'error' : 'primary'" label
size="small" class="ml-2">
{{ tag === 'danger' ? tm('tags.danger') : tag }}
</v-chip>
<v-list>
<v-list-item @click="viewReadme">
<v-list-item-title>📄 {{ tm('buttons.viewDocs') }}</v-list-item-title>
</v-list-item>
<v-list-item v-if="marketMode && !extension?.installed" @click="installExtension">
<v-list-item-title>
{{ tm('buttons.install') }}</v-list-item-title>
</v-list-item>
<v-list-item v-if="marketMode && extension?.installed">
<v-list-item-title class="text--disabled">{{ tm('status.installed') }}</v-list-item-title>
</v-list-item>
<!-- Divider between market actions and plugin actions -->
<v-divider v-if="!marketMode" />
<template v-if="!marketMode">
<v-list-item @click="configure">
<v-list-item-title>
{{ tm('card.actions.pluginConfig') }}</v-list-item-title>
</v-list-item>
<v-list-item @click="uninstallExtension">
<v-list-item-title class="text-error">{{ tm('card.actions.uninstallPlugin') }}</v-list-item-title>
</v-list-item>
<v-list-item @click="reloadExtension">
<v-list-item-title>{{ tm('card.actions.reloadPlugin') }}</v-list-item-title>
</v-list-item>
<v-list-item @click="toggleActivation">
<v-list-item-title>
{{ extension.activated ? tm('buttons.disable') : tm('buttons.enable') }}{{
tm('card.actions.togglePlugin') }}
</v-list-item-title>
</v-list-item>
<v-list-item @click="viewHandlers">
<v-list-item-title>{{ tm('card.actions.viewHandlers') }} ({{ extension.handlers.length
}})</v-list-item-title>
</v-list-item>
<v-list-item @click="updateExtension" :disabled="!extension?.has_update">
<v-list-item-title>
{{ tm('card.actions.updateTo') }} {{ extension.online_version || extension.version }}
</v-list-item-title>
</v-list-item>
</template>
</v-list>
</v-menu>
</div>
<div class="mt-2" :class="{ 'text-caption': $vuetify.display.xs }" style="max-height: 65px; overflow-y: auto;">
{{ extension.desc }}
<div style="width: 100%; margin-bottom: 24px;">
<!-- 最多一行 -->
<div class="text-caption" style="color: gray; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; margin-right: 36px;">
{{ extension.author }} / {{ extension.name }}
</div>
<p class="text-h3 font-weight-black" :class="{ 'text-h4': $vuetify.display.xs }">
{{ extension.name }}
<v-tooltip location="top" v-if="extension?.has_update && !marketMode">
<template v-slot:activator="{ props: tooltipProps }">
<v-icon v-bind="tooltipProps" color="warning" class="ml-2" icon="mdi-update" size="small"></v-icon>
</template>
<span>{{ tm("card.status.hasUpdate") }}: {{ extension.online_version }}</span>
</v-tooltip>
<v-tooltip location="top" v-if="!extension.activated && !marketMode">
<template v-slot:activator="{ props: tooltipProps }">
<v-icon v-bind="tooltipProps" color="error" class="ml-2" icon="mdi-cancel" size="small"></v-icon>
</template>
<span>{{ tm("card.status.disabled") }}</span>
</v-tooltip>
</p>
<div class="mt-1 d-flex flex-wrap">
<v-chip color="primary" label size="small">
<v-icon icon="mdi-source-branch" start></v-icon>
{{ extension.version }}
</v-chip>
<v-chip v-if="extension?.has_update" color="warning" label size="small" class="ml-2">
<v-icon icon="mdi-arrow-up-bold" start></v-icon>
{{ extension.online_version }}
</v-chip>
<v-chip color="primary" label size="small" class="ml-2" v-if="extension.handlers?.length">
<v-icon icon="mdi-cogs" start></v-icon>
{{ extension.handlers?.length }}{{ tm("card.status.handlersCount") }}
</v-chip>
<v-chip v-for="tag in extension.tags" :key="tag" :color="tag === 'danger' ? 'error' : 'primary'" label
size="small" class="ml-2">
{{ tag === 'danger' ? tm('tags.danger') : tag }}
</v-chip>
</div>
<div class="mt-2" :class="{ 'text-caption': $vuetify.display.xs }" style="overflow-y: auto; height: 60px;">
{{ extension.desc }}
</div>
</div>
</div>
<div class="extension-image-container" v-if="extension.logo">
<img :src="extension.logo" :style="{
height: $vuetify.display.xs ? '75px' : '100px',
width: $vuetify.display.xs ? '75px' : '100px',
borderRadius: '8px',
objectFit: 'cover',
objectPosition: 'center'
}" :alt="tm('card.alt.logo')" />
</div>
</v-card-text>
<v-card-actions style="margin-left: 0px; gap: 2px;">
<v-btn color="teal-accent-4" :text="tm('buttons.viewDocs')" variant="text" @click="viewReadme"></v-btn>
<v-btn v-if="!marketMode" color="teal-accent-4" :text="tm('buttons.actions')" variant="text" @click="reveal = true"></v-btn>
<v-btn v-if="marketMode && !extension?.installed" color="teal-accent-4" :text="tm('buttons.install')" variant="text"
@click="installExtension"></v-btn>
<v-btn v-if="marketMode && extension?.installed" color="teal-accent-4" :text="tm('status.installed')" variant="text" disabled></v-btn>
</v-card-actions>
<v-expand-transition v-if="!marketMode">
<v-card v-if="reveal" class="position-absolute w-100" height="100%"
style="bottom: 0; display: flex; flex-direction: column;">
<v-card-text style="overflow-y: auto;">
<div class="d-flex align-center mb-4">
<img v-if="extension.logo" :src="extension.logo"
style="height: 50px; width: 50px; border-radius: 8px; margin-right: 16px;" :alt="tm('card.alt.extensionIcon')" />
<h3>{{ extension.name }}</h3>
</div>
<div class="mt-4" :style="{
justifyContent: 'center',
display: 'flex',
alignItems: 'center',
flexWrap: 'wrap',
gap: '8px',
flexDirection: $vuetify.display.xs ? 'column' : 'row'
}">
<v-btn prepend-icon="mdi-cog" color="primary" variant="tonal" @click="configure"
:block="$vuetify.display.xs">
{{ tm("card.actions.pluginConfig") }}
</v-btn>
<v-btn prepend-icon="mdi-delete" color="error" variant="tonal" @click="uninstallExtension"
:block="$vuetify.display.xs">
{{ tm("card.actions.uninstallPlugin") }}
</v-btn>
<v-btn prepend-icon="mdi-reload" color="primary" variant="tonal" @click="reloadExtension"
:block="$vuetify.display.xs">
{{ tm("card.actions.reloadPlugin") }}
</v-btn>
<v-btn :prepend-icon="extension.activated ? 'mdi-cancel' : 'mdi-check-circle'"
:color="extension.activated ? 'error' : 'success'" variant="tonal" @click="toggleActivation"
:block="$vuetify.display.xs">
{{ extension.activated ? tm('buttons.disable') : tm('buttons.enable') }}{{ tm("card.actions.togglePlugin") }}
</v-btn>
<v-btn prepend-icon="mdi-cogs" color="info" variant="tonal" @click="viewHandlers"
:block="$vuetify.display.xs">
{{ tm("card.actions.viewHandlers") }} ({{ extension.handlers.length }})
</v-btn>
<v-btn prepend-icon="mdi-update" color="primary" variant="tonal" :disabled="!extension?.has_update "
@click="updateExtension" :block="$vuetify.display.xs">
{{ tm("card.actions.updateTo") }} {{ extension.online_version || extension.version }}
</v-btn>
</div>
</v-card-text>
<v-card-actions class="pt-0 d-flex justify-center">
<v-btn color="teal-accent-4" :text="tm('buttons.back')" variant="text" @click="reveal = false"></v-btn>
</v-card-actions>
</v-card>
</v-expand-transition>
</v-card>
</template>
@@ -127,7 +127,6 @@ export default {
transition: all 0.3s ease;
overflow: hidden;
min-height: 220px;
margin-bottom: 16px;
display: flex;
flex-direction: column;
justify-content: space-between;
@@ -1,11 +1,21 @@
<template>
<div class="d-flex align-center justify-space-between">
<span v-if="!modelValue" style="color: rgb(var(--v-theme-primaryText));">
<span v-if="!modelValue || (Array.isArray(modelValue) && modelValue.length === 0)"
style="color: rgb(var(--v-theme-primaryText));">
未选择
</span>
<span v-else>
{{ modelValue }}
</span>
<div v-else class="d-flex flex-wrap gap-1">
<v-chip
v-for="name in modelValue"
:key="name"
size="small"
color="primary"
variant="tonal"
closable
@click:close="removeKnowledgeBase(name)">
{{ name }}
</v-chip>
</div>
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
{{ buttonText }}
</v-btn>
@@ -21,86 +31,57 @@
<v-card-text class="pa-0" style="max-height: 400px; overflow-y: auto;">
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
<!-- 插件未安装提示 -->
<div v-if="!loading && !pluginInstalled" class="text-center py-8">
<v-icon size="64" color="grey-lighten-1">mdi-puzzle-outline</v-icon>
<p class="text-grey mt-4 mb-4">知识库插件未安装</p>
<v-btn color="primary" variant="tonal" @click="goToKnowledgeBasePage">
前往知识库页面
</v-btn>
</div>
<!-- 知识库列表 -->
<v-list v-else-if="!loading && pluginInstalled" density="compact">
<!-- 不使用选项 -->
<v-list-item
:value="''"
@click="selectKnowledgeBase({ collection_name: '' })"
:active="selectedKnowledgeBase === ''"
rounded="md"
class="ma-1">
<template v-slot:prepend>
<v-icon color="grey-lighten-1">mdi-close-circle-outline</v-icon>
</template>
<v-list-item-title>不使用</v-list-item-title>
<v-list-item-subtitle>不使用任何知识库</v-list-item-subtitle>
<template v-slot:append>
<v-icon v-if="selectedKnowledgeBase === ''" color="primary">mdi-check-circle</v-icon>
</template>
</v-list-item>
<v-divider v-if="knowledgeBaseList.length > 0" class="my-2"></v-divider>
<v-list v-if="!loading" density="compact">
<!-- 知识库选项 -->
<v-list-item
v-for="kb in knowledgeBaseList"
:key="kb.collection_name"
:value="kb.collection_name"
@click="selectKnowledgeBase(kb)"
:active="selectedKnowledgeBase === kb.collection_name"
:key="kb.kb_id"
:value="kb.kb_name"
@click="selectKnowledgeBase(kb.kb_name)"
:active="isSelected(kb.kb_name)"
rounded="md"
class="ma-1">
<template v-slot:prepend>
<span class="emoji-icon">{{ kb.emoji || '🙂' }}</span>
<span class="emoji-icon">{{ kb.emoji || '📚' }}</span>
</template>
<v-list-item-title>{{ kb.collection_name }}</v-list-item-title>
<v-list-item-title>{{ kb.kb_name }}</v-list-item-title>
<v-list-item-subtitle>
{{ kb.description || '无描述' }}
<span v-if="kb.count !== undefined"> - {{ kb.count }} 项知识</span>
<span v-if="kb.doc_count !== undefined"> - {{ kb.doc_count }} 个文档</span>
<span v-if="kb.chunk_count !== undefined"> - {{ kb.chunk_count }} 个块</span>
</v-list-item-subtitle>
<template v-slot:append>
<v-icon v-if="selectedKnowledgeBase === kb.collection_name" color="primary">mdi-check-circle</v-icon>
<v-icon v-if="isSelected(kb.kb_name)" color="primary">
mdi-checkbox-marked
</v-icon>
<v-icon v-else color="grey-lighten-1">
mdi-checkbox-blank-outline
</v-icon>
</template>
</v-list-item>
<!-- 当没有知识库时显示创建提示 -->
<div v-if="knowledgeBaseList.length === 0" class="text-center py-4">
<p class="text-grey mb-4">暂无知识库</p>
<v-btn color="primary" variant="tonal" size="small" @click="goToKnowledgeBasePage">
<div v-if="knowledgeBaseList.length === 0" class="text-center py-8">
<v-icon size="64" color="grey-lighten-1">mdi-database-off</v-icon>
<p class="text-grey mt-4 mb-4">暂无知识库</p>
<v-btn color="primary" variant="tonal" @click="goToKnowledgeBasePage">
创建知识库
</v-btn>
</div>
</v-list>
<!-- 空状态插件未安装时保留原有逻辑 -->
<div v-else-if="!loading && !pluginInstalled && knowledgeBaseList.length === 0" class="text-center py-8">
<v-icon size="64" color="grey-lighten-1">mdi-database-off</v-icon>
<p class="text-grey mt-4 mb-4">暂无知识库</p>
<v-btn color="primary" variant="tonal" @click="goToKnowledgeBasePage">
创建知识库
</v-btn>
</div>
</v-card-text>
<v-card-actions class="pa-4">
<div v-if="selectedKnowledgeBases.length > 0" class="text-caption text-grey">
已选择 {{ selectedKnowledgeBases.length }} 个知识库
</div>
<v-spacer></v-spacer>
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
<v-btn
color="primary"
@click="confirmSelection"
:disabled="selectedKnowledgeBase === null || selectedKnowledgeBase === undefined">
@click="confirmSelection">
确认选择
</v-btn>
</v-card-actions>
@@ -115,8 +96,8 @@ import { useRouter } from 'vue-router'
const props = defineProps({
modelValue: {
type: String,
default: ''
type: Array,
default: () => []
},
buttonText: {
type: String,
@@ -130,74 +111,87 @@ const router = useRouter()
const dialog = ref(false)
const knowledgeBaseList = ref([])
const loading = ref(false)
const selectedKnowledgeBase = ref('')
const pluginInstalled = ref(false)
const selectedKnowledgeBases = ref([])
// modelValue selectedKnowledgeBase
// modelValue selectedKnowledgeBases
watch(() => props.modelValue, (newValue) => {
selectedKnowledgeBase.value = newValue || ''
selectedKnowledgeBases.value = Array.isArray(newValue) ? [...newValue] : []
}, { immediate: true })
async function openDialog() {
selectedKnowledgeBase.value = props.modelValue || ''
//
selectedKnowledgeBases.value = Array.isArray(props.modelValue)
? [...props.modelValue]
: []
dialog.value = true
await checkPluginAndLoadKnowledgeBases()
await loadKnowledgeBases()
}
async function checkPluginAndLoadKnowledgeBases() {
async function loadKnowledgeBases() {
loading.value = true
try {
//
const pluginResponse = await axios.get('/api/plugin/get?name=astrbot_plugin_knowledge_base')
const response = await axios.get('/api/kb/list', {
params: {
page: 1,
page_size: 100
}
})
if (pluginResponse.data.status === 'ok' && pluginResponse.data.data.length > 0) {
pluginInstalled.value = true
//
await loadKnowledgeBases()
if (response.data.status === 'ok') {
knowledgeBaseList.value = response.data.data.items || []
} else {
pluginInstalled.value = false
console.error('加载知识库列表失败:', response.data.message)
knowledgeBaseList.value = []
}
} catch (error) {
console.error('检查知识库插件失败:', error)
pluginInstalled.value = false
console.error('加载知识库列表失败:', error)
knowledgeBaseList.value = []
} finally {
loading.value = false
}
}
async function loadKnowledgeBases() {
try {
const response = await axios.get('/api/plug/alkaid/kb/collections')
if (response.data.status === 'ok') {
knowledgeBaseList.value = response.data.data || []
} else {
knowledgeBaseList.value = []
}
} catch (error) {
console.error('加载知识库列表失败:', error)
knowledgeBaseList.value = []
function isSelected(kbName) {
return selectedKnowledgeBases.value.includes(kbName)
}
function selectKnowledgeBase(kbName) {
//
const index = selectedKnowledgeBases.value.indexOf(kbName)
if (index > -1) {
selectedKnowledgeBases.value.splice(index, 1)
} else {
selectedKnowledgeBases.value.push(kbName)
}
}
function selectKnowledgeBase(kb) {
selectedKnowledgeBase.value = kb.collection_name
function removeKnowledgeBase(kbName) {
const index = selectedKnowledgeBases.value.indexOf(kbName)
if (index > -1) {
selectedKnowledgeBases.value.splice(index, 1)
}
//
emit('update:modelValue', [...selectedKnowledgeBases.value])
}
function confirmSelection() {
emit('update:modelValue', selectedKnowledgeBase.value)
emit('update:modelValue', [...selectedKnowledgeBases.value])
dialog.value = false
}
function cancelSelection() {
selectedKnowledgeBase.value = props.modelValue || ''
//
selectedKnowledgeBases.value = Array.isArray(props.modelValue)
? [...props.modelValue]
: []
dialog.value = false
}
function goToKnowledgeBasePage() {
dialog.value = false
router.push('/alkaid/knowledge-base')
router.push('/knowledge-base')
}
</script>
@@ -222,4 +216,8 @@ function goToKnowledgeBasePage() {
align-items: center;
justify-content: center;
}
.gap-1 {
gap: 4px;
}
</style>
@@ -0,0 +1,536 @@
<template>
<v-dialog v-model="showDialog" max-width="500px" persistent>
<v-card>
<v-card-title class="text-h2">
{{ editingPersona ? tm('dialog.edit.title') : tm('dialog.create.title') }}
</v-card-title>
<v-card-text>
<v-form ref="personaForm" v-model="formValid">
<v-text-field v-model="personaForm.persona_id" :label="tm('form.personaId')"
:rules="personaIdRules" :disabled="editingPersona" variant="outlined" density="comfortable"
class="mb-4" />
<v-textarea v-model="personaForm.system_prompt" :label="tm('form.systemPrompt')"
:rules="systemPromptRules" variant="outlined" rows="6" class="mb-4" />
<v-expansion-panels v-model="expandedPanels" multiple>
<!-- 工具选择面板 -->
<v-expansion-panel value="tools">
<v-expansion-panel-title>
<v-icon class="mr-2">mdi-tools</v-icon>
{{ tm('form.tools') }}
<v-chip v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
size="small" color="primary" variant="tonal" class="ml-2">
{{ personaForm.tools.length }}
</v-chip>
</v-expansion-panel-title>
<v-expansion-panel-text>
<div class="mb-3">
<p class="text-body-2 text-medium-emphasis">
{{ tm('form.toolsHelp') }}
</p>
</div>
<v-radio-group class="mt-2" v-model="toolSelectValue" hide-details="true">
<v-radio label="默认使用全部函数工具" value="0"></v-radio>
<v-radio label="选择指定函数工具" value="1">
</v-radio>
</v-radio-group>
<div v-if="toolSelectValue === '1'" class="mt-3 ml-8">
<!-- 工具搜索 -->
<v-text-field v-model="toolSearch" :label="tm('form.searchTools')"
prepend-inner-icon="mdi-magnify" variant="outlined" density="compact"
hide-details clearable class="mb-3" />
<!-- MCP 服务器 -->
<div v-if="mcpServers.length > 0" class="mb-4">
<h4 class="text-subtitle-2 mb-2">{{ tm('form.mcpServersQuickSelect') }}</h4>
<div class="d-flex flex-wrap ga-2">
<v-chip v-for="server in mcpServers" :key="server.name"
:color="isServerSelected(server) ? 'primary' : 'default'"
:variant="isServerSelected(server) ? 'flat' : 'outlined'"
size="small" clickable @click="toggleMcpServer(server)"
:disabled="!server.tools || server.tools.length === 0">
<v-icon start size="small">mdi-server</v-icon>
{{ server.name }}
<v-chip-text v-if="server.tools" class="ml-1">
({{ server.tools.length }})
</v-chip-text>
</v-chip>
</div>
</div>
<!-- 工具选择列表 -->
<div v-if="filteredTools.length > 0" class="tools-selection">
<v-virtual-scroll :items="filteredTools" height="300" item-height="48">
<template v-slot:default="{ item }">
<v-list-item :key="item.name" density="comfortable"
@click="toggleTool(item.name)">
<template v-slot:prepend>
<v-checkbox-btn :model-value="isToolSelected(item.name)"
@click.stop="toggleTool(item.name)" />
</template>
<v-list-item-title>
{{ item.name }}
<v-chip v-if="item.mcp_server_name" size="x-small"
color="secondary" variant="tonal" class="ml-2">
{{ item.mcp_server_name }}
</v-chip>
</v-list-item-title>
<v-list-item-subtitle v-if="item.description">
{{ truncateText(item.description, 100) }}
</v-list-item-subtitle>
</v-list-item>
</template>
</v-virtual-scroll>
</div>
<div v-else-if="!loadingTools && availableTools.length === 0"
class="text-center pa-4">
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-tools</v-icon>
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsAvailable')
}}
</p>
</div>
<div v-else-if="!loadingTools && filteredTools.length === 0"
class="text-center pa-4">
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-magnify</v-icon>
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsFound') }}
</p>
</div>
<!-- 加载状态 -->
<div v-if="loadingTools" class="text-center pa-4">
<v-progress-circular indeterminate color="primary" />
<p class="text-body-2 text-medium-emphasis mt-2">{{ tm('form.loadingTools')
}}
</p>
</div>
<!-- 已选择的工具 -->
<div class="mt-4">
<h4 class="text-subtitle-2 mb-2">
{{ tm('form.selectedTools') }}
<span v-if="personaForm.tools === null" class="text-success">
({{ tm('form.allSelected') }})
</span>
<span v-else-if="Array.isArray(personaForm.tools)">
({{ personaForm.tools.length }})
</span>
</h4>
<div v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
class="d-flex flex-wrap ga-1" style="max-height: 100px; overflow-y: auto;">
<v-chip v-for="toolName in personaForm.tools" :key="toolName"
size="small" color="primary" variant="tonal" closable
@click:close="removeTool(toolName)">
{{ toolName }}
</v-chip>
</div>
<div v-else class="text-body-2 text-medium-emphasis">
{{ tm('form.noToolsSelected') }}
</div>
</div>
</div>
</v-expansion-panel-text>
</v-expansion-panel>
<!-- 预设对话面板 -->
<v-expansion-panel value="dialogs">
<v-expansion-panel-title>
<v-icon class="mr-2">mdi-chat</v-icon>
{{ tm('form.presetDialogs') }}
<v-chip v-if="personaForm.begin_dialogs.length > 0" size="small" color="primary"
variant="tonal" class="ml-2">
{{ personaForm.begin_dialogs.length / 2 }}
</v-chip>
</v-expansion-panel-title>
<v-expansion-panel-text>
<div class="mb-3">
<p class="text-body-2 text-medium-emphasis">
{{ tm('form.presetDialogsHelp') }}
</p>
</div>
<div v-for="(dialog, index) in personaForm.begin_dialogs" :key="index" class="mb-3">
<v-textarea v-model="personaForm.begin_dialogs[index]"
:label="index % 2 === 0 ? tm('form.userMessage') : tm('form.assistantMessage')"
:rules="getDialogRules(index)" variant="outlined" rows="2"
density="comfortable">
<template v-slot:append>
<v-btn icon="mdi-delete" variant="text" size="small" color="error"
@click="removeDialog(index)" />
</template>
</v-textarea>
</div>
<v-btn variant="outlined" prepend-icon="mdi-plus" @click="addDialogPair" block>
{{ tm('buttons.addDialogPair') }}
</v-btn>
</v-expansion-panel-text>
</v-expansion-panel>
</v-expansion-panels>
</v-form>
</v-card-text>
<v-card-actions>
<v-spacer />
<v-btn color="grey" variant="text" @click="closeDialog">
{{ tm('buttons.cancel') }}
</v-btn>
<v-btn color="primary" variant="flat" @click="savePersona" :loading="saving" :disabled="!formValid">
{{ tm('buttons.save') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</template>
<script>
import axios from 'axios';
import { useModuleI18n } from '@/i18n/composables';
export default {
name: 'PersonaForm',
props: {
modelValue: {
type: Boolean,
default: false
},
editingPersona: {
type: Object,
default: null
}
},
emits: ['update:modelValue', 'saved', 'error'],
setup() {
const { tm } = useModuleI18n('features/persona');
return { tm };
},
// Reactive state of the persona create/edit dialog.
data() {
return {
toolSelectValue: '0', // radio value: '0' = use all function tools, '1' = select specific tools
// True while the save request is in flight (bound to the Save button's `loading`).
saving: false,
// Values of the currently expanded expansion panels ('tools' / 'dialogs').
expandedPanels: [],
// Validation state written by <v-form v-model="formValid">.
formValid: false,
// Results of GET /api/tools/mcp/servers and GET /api/tools/list.
mcpServers: [],
availableTools: [],
loadingTools: false,
// Payload posted to /api/persona/create or /api/persona/update.
// begin_dialogs alternates user message (even index) / assistant message (odd index).
// tools === null is treated throughout this component as "all tools selected".
personaForm: {
persona_id: '',
system_prompt: '',
begin_dialogs: [],
tools: []
},
personaIdRules: [
v => !!v || this.tm('validation.required'),
// NOTE(review): for any non-empty string `v.length >= 0` is always true, so this
// rule can never fail even though its message advertises a minimum length of 2.
// Confirm the intended minimum before tightening the check.
v => (v && v.length >= 0) || this.tm('validation.minLength', { min: 2 }),
],
systemPromptRules: [
v => !!v || this.tm('validation.required'),
v => (v && v.length >= 10) || this.tm('validation.minLength', { min: 10 })
],
// Free-text filter applied by the `filteredTools` computed property.
toolSearch: ''
}
},
computed: {
showDialog: {
get() {
return this.modelValue;
},
set(value) {
this.$emit('update:modelValue', value);
}
},
filteredTools() {
if (!this.toolSearch) {
return this.availableTools;
}
const search = this.toolSearch.toLowerCase();
return this.availableTools.filter(tool =>
tool.name.toLowerCase().includes(search) ||
(tool.description && tool.description.toLowerCase().includes(search)) ||
(tool.mcp_server_name && tool.mcp_server_name.toLowerCase().includes(search))
);
}
},
watch: {
modelValue(newValue) {
if (newValue) {
//
if (this.editingPersona) {
this.initFormWithPersona(this.editingPersona);
} else {
this.initForm();
}
this.loadMcpServers();
this.loadTools();
}
},
editingPersona: {
immediate: true,
handler(newPersona) {
// editingPersona
if (this.modelValue) {
if (newPersona) {
this.initFormWithPersona(newPersona);
} else {
this.initForm();
}
}
}
},
toolSelectValue(newValue) {
if (newValue === '0') {
//
this.personaForm.tools = null;
} else if (newValue === '1') {
// null
if (this.personaForm.tools === null) {
this.personaForm.tools = [];
}
}
}
},
methods: {
initForm() {
this.personaForm = {
persona_id: '',
system_prompt: '',
begin_dialogs: [],
tools: []
};
this.toolSelectValue = '0';
this.expandedPanels = [];
},
initFormWithPersona(persona) {
this.personaForm = {
persona_id: persona.persona_id,
system_prompt: persona.system_prompt,
begin_dialogs: [...(persona.begin_dialogs || [])],
tools: persona.tools === null ? null : [...(persona.tools || [])]
};
// tools toolSelectValue
this.toolSelectValue = persona.tools === null ? '0' : '1';
this.expandedPanels = [];
},
closeDialog() {
this.showDialog = false;
},
async loadMcpServers() {
try {
const response = await axios.get('/api/tools/mcp/servers');
if (response.data.status === 'ok') {
this.mcpServers = response.data.data || [];
} else {
this.$emit('error', response.data.message || 'Failed to load MCP servers');
}
} catch (error) {
this.$emit('error', error.response?.data?.message || 'Failed to load MCP servers');
this.mcpServers = [];
}
},
async loadTools() {
this.loadingTools = true;
try {
const response = await axios.get('/api/tools/list');
if (response.data.status === 'ok') {
this.availableTools = response.data.data || [];
} else {
this.$emit('error', response.data.message || 'Failed to load tools');
}
} catch (error) {
this.$emit('error', error.response?.data?.message || 'Failed to load tools');
this.availableTools = [];
} finally {
this.loadingTools = false;
}
},
async savePersona() {
if (!this.formValid) return;
//
if (this.personaForm.begin_dialogs.length > 0) {
for (let i = 0; i < this.personaForm.begin_dialogs.length; i++) {
if (!this.personaForm.begin_dialogs[i] || this.personaForm.begin_dialogs[i].trim() === '') {
const dialogType = i % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
this.$emit('error', this.tm('validation.dialogRequired', { type: dialogType }));
return;
}
}
}
this.saving = true;
try {
const url = this.editingPersona ? '/api/persona/update' : '/api/persona/create';
const response = await axios.post(url, this.personaForm);
if (response.data.status === 'ok') {
this.$emit('saved', response.data.message || this.tm('messages.saveSuccess'));
this.closeDialog();
} else {
this.$emit('error', response.data.message || this.tm('messages.saveError'));
}
} catch (error) {
this.$emit('error', error.response?.data?.message || this.tm('messages.saveError'));
}
this.saving = false;
},
addDialogPair() {
this.personaForm.begin_dialogs.push('', '');
//
if (!this.expandedPanels.includes('dialogs')) {
this.expandedPanels.push('dialogs');
}
},
removeDialog(index) {
//
if (index % 2 === 0 && index + 1 < this.personaForm.begin_dialogs.length) {
this.personaForm.begin_dialogs.splice(index, 2);
}
//
else if (index % 2 === 1 && index - 1 >= 0) {
this.personaForm.begin_dialogs.splice(index - 1, 2);
}
},
toggleMcpServer(server) {
if (!server.tools || server.tools.length === 0) return;
//
if (this.personaForm.tools === null) {
//
this.personaForm.tools = this.availableTools.map(tool => tool.name)
.filter(toolName => !server.tools.includes(toolName));
this.toolSelectValue = '1'; //
return;
}
// tools
if (!Array.isArray(this.personaForm.tools)) {
this.personaForm.tools = [];
this.toolSelectValue = '1';
}
//
const serverTools = server.tools;
const allSelected = serverTools.every(toolName => this.personaForm.tools.includes(toolName));
if (allSelected) {
//
this.personaForm.tools = this.personaForm.tools.filter(
toolName => !serverTools.includes(toolName)
);
} else {
//
serverTools.forEach(toolName => {
if (!this.personaForm.tools.includes(toolName)) {
this.personaForm.tools.push(toolName);
}
});
}
},
toggleTool(toolName) {
//
if (this.personaForm.tools === null) {
//
//
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
this.toolSelectValue = '1'; //
} else if (Array.isArray(this.personaForm.tools)) {
const index = this.personaForm.tools.indexOf(toolName);
if (index !== -1) {
//
this.personaForm.tools.splice(index, 1);
} else {
//
this.personaForm.tools.push(toolName);
}
} else {
// toolsnull
this.personaForm.tools = [toolName];
this.toolSelectValue = '1';
}
},
removeTool(toolName) {
//
if (this.personaForm.tools === null) {
//
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
this.toolSelectValue = '1'; //
} else if (Array.isArray(this.personaForm.tools)) {
const index = this.personaForm.tools.indexOf(toolName);
if (index !== -1) {
this.personaForm.tools.splice(index, 1);
}
}
},
truncateText(text, maxLength) {
if (!text) return '';
return text.length > maxLength ? text.substring(0, maxLength) + '...' : text;
},
getDialogRules(index) {
const dialogType = index % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
return [
v => !!v || this.tm('validation.dialogRequired', { type: dialogType }),
v => (v && v.trim().length > 0) || this.tm('validation.dialogRequired', { type: dialogType })
];
},
isToolSelected(toolName) {
//
if (this.personaForm.tools === null) {
return true;
}
return Array.isArray(this.personaForm.tools) && this.personaForm.tools.includes(toolName);
},
isServerSelected(server) {
if (!server.tools || server.tools.length === 0) return false;
//
if (this.personaForm.tools === null) {
return true;
}
//
return Array.isArray(this.personaForm.tools) &&
server.tools.every(toolName => this.personaForm.tools.includes(toolName));
}
}
}
</script>
<style scoped>
.tools-selection {
max-height: 300px;
overflow-y: auto;
}
.v-virtual-scroll {
padding-bottom: 16px;
}
</style>
@@ -18,7 +18,7 @@
选择人格
</v-card-title>
<v-card-text class="pa-0" style="max-height: 400px; overflow-y: auto;">
<v-card-text class="pa-2" style="max-height: 400px; overflow-y: auto;">
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
<v-list v-if="!loading && personaList.length > 0" density="compact">
@@ -48,6 +48,9 @@
</v-card-text>
<v-card-actions class="pa-4">
<v-btn variant="text" color="primary" prepend-icon="mdi-plus" @click="openCreatePersona">
创建新人格
</v-btn>
<v-spacer></v-spacer>
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
<v-btn
@@ -59,11 +62,22 @@
</v-card-actions>
</v-card>
</v-dialog>
<!-- 创建人格对话框 -->
<PersonaForm
v-model="showCreateDialog"
:editing-persona="null"
:mcp-servers="mcpServers"
:available-tools="availableTools"
:loading-tools="loadingTools"
@saved="handlePersonaCreated"
@error="handleError" />
</template>
<script setup>
import { ref, watch } from 'vue'
import axios from 'axios'
import PersonaForm from './PersonaForm.vue'
const props = defineProps({
modelValue: {
@@ -82,6 +96,7 @@ const dialog = ref(false)
const personaList = ref([])
const loading = ref(false)
const selectedPersona = ref('')
const showCreateDialog = ref(false)
// modelValue selectedPersona
watch(() => props.modelValue, (newValue) => {
@@ -135,6 +150,21 @@ function cancelSelection() {
selectedPersona.value = props.modelValue || ''
dialog.value = false
}
// Open the "create persona" dialog by setting the flag bound to
// <PersonaForm v-model="showCreateDialog">.
function openCreatePersona() {
showCreateDialog.value = true
}
// Handler for PersonaForm's 'saved' event: logs the message, closes the
// create dialog and waits for the persona list to reload.
async function handlePersonaCreated(message) {
console.log('人格创建成功:', message)
showCreateDialog.value = false
// Refresh the list via loadPersonas() (defined outside this hunk) so the new persona can be picked.
await loadPersonas()
}
// Handler for PersonaForm's 'error' event.
// NOTE(review): the error is only written to the console; nothing is shown to
// the user from this handler. Confirm whether a visible notification is wanted.
function handleError(error) {
console.error('创建人格失败:', error)
}
</script>
<style scoped>
@@ -1,5 +1,5 @@
{
"dashboard": "Dashboard",
"dashboard": "Dashboard",
"platforms": "Platforms",
"providers": "Providers",
"persona": "Persona",
@@ -12,10 +12,12 @@
"console": "Console",
"alkaid": "Alkaid Lab",
"knowledgeBase": "Knowledge Base",
"knowledgeBaseBeta": "Knowledge Base (Beta)",
"about": "About",
"settings": "Settings",
"documentation": "Documentation",
"github": "GitHub",
"drag": "Drag"
}
"drag": "Drag",
"groups": {
"more": "More Features"
}
}
@@ -24,7 +24,9 @@
"title": "Conversation Title",
"platform": "Platform",
"type": "Type",
"sessionId": "ID",
"cid": "Conversation ID",
"umo": "Unified Message Origin",
"sessionId": "Session ID",
"createdAt": "Created At",
"updatedAt": "Updated At",
"actions": "Actions"
@@ -46,36 +46,23 @@
"selectFile": "Select File",
"dropzone": "Drop files here or click to select",
"supportedFormats": "Supported formats: TXT, PDF, Markdown",
"maxSize": "Max file size: 50MB",
"maxSize": "Max file size: 128MB",
"chunkSettings": "Chunk Settings",
"batchSettings": "Batch Settings",
"chunkSize": "Chunk Size",
"chunkSizeHint": "Number of characters per chunk (default: 512)",
"chunkOverlap": "Chunk Overlap",
"chunkOverlapHint": "Overlapping characters between chunks (default: 50)",
"batchSize": "Batch Size",
"batchSizeHint": "Number of chunks to process in each batch (default: 32)",
"tasksLimit": "Concurrent Tasks Limit",
"tasksLimitHint": "Maximum number of concurrent upload tasks (default: 3)",
"maxRetries": "Max Retries",
"maxRetriesHint": "Number of times to retry a failed upload task (default: 3)",
"cancel": "Cancel",
"submit": "Upload",
"fileRequired": "Please select a file to upload"
},
"sessions": {
"title": "Session Configuration",
"subtitle": "Configure which sessions can use this knowledge base",
"empty": "No session configurations",
"add": "Add Configuration",
"scope": "Scope",
"scopeId": "Identifier",
"topK": "Top K Results",
"enableRerank": "Enable Rerank",
"actions": "Actions",
"edit": "Edit",
"delete": "Delete",
"scopeSession": "Session Level",
"scopePlatform": "Platform Level",
"deleteConfirm": "Are you sure you want to delete this configuration?",
"addSuccess": "Configuration added successfully",
"addFailed": "Failed to add configuration",
"deleteSuccess": "Configuration deleted successfully",
"deleteFailed": "Failed to delete configuration"
},
"settings": {
"title": "Knowledge Base Settings",
"basic": "Basic Settings",
@@ -21,7 +21,8 @@
"delete": "Delete",
"preview": "Preview",
"search": "Search Chunks",
"searchPlaceholder": "Enter keywords to search chunks..."
"searchPlaceholder": "Enter keywords to search chunks...",
"showing": "Showing"
},
"edit": {
"title": "Edit Chunk",
@@ -1,7 +1,7 @@
{
"dashboard": "统计",
"platforms": "消息平台",
"providers": "服务提供商",
"dashboard": "数据统计",
"platforms": "机器人",
"providers": "模型提供商",
"persona": "人格设定",
"toolUse": "MCP",
"extension": "插件",
@@ -12,10 +12,12 @@
"console": "控制台",
"alkaid": "Alkaid",
"knowledgeBase": "知识库",
"knowledgeBaseBeta": "知识库 (Beta)",
"about": "关于",
"settings": "设置",
"documentation": "官方文档",
"github": "GitHub",
"drag": "拖拽"
}
"drag": "拖拽",
"groups": {
"more": "更多功能"
}
}
@@ -3,7 +3,7 @@
"subtitle": "管理和查看用户对话历史记录",
"filters": {
"title": "筛选条件",
"platform": "消息平台 ID",
"platform": "机器人 ID",
"type": "类型",
"search": "搜索关键词",
"reset": "重置"
@@ -22,9 +22,11 @@
"table": {
"headers": {
"title": "对话标题",
"platform": "消息平台 ID",
"type": "类型",
"sessionId": "ID (UMO)",
"platform": "机器人 ID",
"type": "消息类型",
"cid": "对话 ID",
"umo": "消息会话来源",
"sessionId": "会话 ID",
"createdAt": "创建时间",
"updatedAt": "更新时间",
"actions": "操作"
@@ -47,32 +47,23 @@
"selectFile": "选择文件",
"dropzone": "拖放文件到这里或点击选择",
"supportedFormats": "支持的格式: TXT, PDF, Markdown",
"maxSize": "最大文件大小: 50MB",
"maxSize": "最大文件大小: 128MB",
"chunkSettings": "分块设置",
"batchSettings": "批处理设置",
"chunkSize": "分块大小",
"chunkSizeHint": "每个文本块的字符数 (默认: 512)",
"chunkOverlap": "分块重叠",
"chunkOverlapHint": "相邻文本块之间的重叠字符数 (默认: 50)",
"batchSize": "批处理大小",
"batchSizeHint": "每批处理的文本块数量 (默认: 32)",
"tasksLimit": "并发任务限制",
"tasksLimitHint": "最大并发上传任务数 (默认: 3)",
"maxRetries": "最大重试次数",
"maxRetriesHint": "上传失败任务的重试次数 (默认: 3)",
"cancel": "取消",
"submit": "上传",
"fileRequired": "请选择要上传的文件"
},
"sessions": {
"title": "使用该知识库的会话",
"subtitle": "以下会话正在使用此知识库",
"empty": "暂无会话使用此知识库",
"refresh": "刷新",
"scope": "范围",
"scopeId": "会话标识",
"topK": "返回结果数",
"enableRerank": "启用重排序",
"actions": "操作",
"scopeSession": "会话级别",
"scopePlatform": "平台级别",
"viewInSessionManagement": "在会话管理中查看",
"goToSessionManagement": "前往会话管理",
"loadFailed": "加载会话列表失败"
},
"retrieval": {
"title": "知识库检索",
"subtitle": "使用稠密检索和稀疏检索测试知识库内容",
@@ -21,7 +21,11 @@
"delete": "删除",
"preview": "预览",
"search": "搜索分块",
"searchPlaceholder": "输入关键词搜索分块内容..."
"searchPlaceholder": "输入关键词搜索分块内容...",
"showing": "显示",
"deleteConfirm": "确定要删除该文本块吗?",
"deleteSuccess": "文本块删除成功",
"deleteFailed": "文本块删除失败"
},
"edit": {
"title": "编辑分块",
@@ -1,9 +1,9 @@
{
"title": "平台适配器管理",
"subtitle": "管理机器人的平台适配器,连接到不同的聊天平台",
"title": "机器人",
"subtitle": "管理平台适配器实例,连接到不同的消息平台",
"adapters": "平台适配器",
"addAdapter": "新增适配器",
"emptyText": "暂无平台适配器,点击 新增适配器 添加",
"addAdapter": "创建机器人",
"emptyText": "暂无平台适配器,点击 创建机器人 添加",
"details": {
"adapterType": "适配器类型",
"token": "Token",
@@ -17,11 +17,11 @@
"dialog": {
"add": "新增",
"edit": "编辑",
"adapter": "平台适配器",
"adapter": "机器人",
"refresh": "刷新",
"cancel": "取消",
"save": "保存",
"addPlatform": "添加平台适配器",
"addPlatform": "创建机器人",
"connectTitle": "接入 {name}",
"viewTutorial": "查看接入教程",
"noTemplates": "暂无平台模板",
@@ -1,10 +1,10 @@
{
"title": "服务提供商管理",
"subtitle": "管理模型服务提供商",
"title": "模型提供商",
"subtitle": "管理模型提供商",
"providers": {
"title": "服务提供商",
"title": "模型提供商",
"settings": "设置",
"addProvider": "新增服务提供商",
"addProvider": "新增模型提供商",
"providerType": "提供商类型",
"tabs": {
"all": "全部",
@@ -15,8 +15,8 @@
"rerank": "重排序(Rerank)"
},
"empty": {
"all": "暂无服务提供商,点击 新增服务提供商 添加",
"typed": "暂无{type}类型的服务提供商,点击 新增服务提供商 添加"
"all": "暂无模型提供商,点击 新增模型提供商 添加",
"typed": "暂无{type}类型的模型提供商,点击 新增模型提供商 添加"
},
"description": {
"openai": "也支持所有兼容 OpenAI API 的模型提供商。",
@@ -25,10 +25,10 @@
}
},
"availability": {
"title": "服务提供商可用性",
"title": "模型提供商可用性",
"subtitle": "通过测试模型对话可用性判断,可能产生API费用",
"refresh": "刷新状态",
"noData": "点击\"刷新状态\"按钮获取服务提供商可用性",
"noData": "点击\"刷新状态\"按钮获取模型提供商可用性",
"available": "可用",
"unavailable": "不可用",
"pending": "检查中...",
@@ -42,7 +42,7 @@
},
"dialogs": {
"addProvider": {
"title": "服务提供商",
"title": "模型提供商",
"tabs": {
"basic": "基本",
"speechToText": "语音转文字",
@@ -55,15 +55,15 @@
"config": {
"addTitle": "新增",
"editTitle": "编辑",
"provider": "服务提供商",
"provider": "模型提供商",
"cancel": "取消",
"save": "保存"
},
"settings": {
"title": "服务提供商设置",
"title": "模型提供商设置",
"sessionSeparation": {
"title": "启用提供商会话隔离",
"description": "不同会话将可独立选择文本生成、TTS、STT 等服务提供商。"
"description": "不同会话将可独立选择文本生成、TTS、STT 等模型提供商。"
},
"close": "关闭"
}
@@ -78,11 +78,11 @@
},
"error": {
"sessionSeparation": "获取会话隔离配置失败",
"fetchStatus": "获取服务提供商状态失败",
"fetchStatus": "获取模型提供商状态失败",
"testError": "测试 {id} 失败: {error}"
},
"confirm": {
"delete": "确定要删除服务提供商 {id} 吗?"
"delete": "确定要删除模型提供商 {id} 吗?"
}
}
}
@@ -23,7 +23,7 @@
"table": {
"headers": {
"sessionStatus": "会话状态",
"sessionInfo": "ID (UMO)",
"sessionInfo": "消息会话来源",
"persona": "人格",
"chatProvider": "聊天模型",
"sttProvider": "语音识别模型",
@@ -311,7 +311,7 @@ commonStore.getStartTime();
<v-icon>mdi-menu</v-icon>
</v-btn>
<div class="logo-container" :class="{ 'mobile-logo': $vuetify.display.xs }" @click="$router.push('/about')">
<div class="logo-container" :class="{ 'mobile-logo': $vuetify.display.xs }" @click="router.push('/about')">
<span class="logo-text">Astr<span class="logo-text-light">Bot</span></span>
<span class="version-text hidden-xs">{{ botCurrVersion }}</span>
</div>
@@ -1,20 +1,38 @@
<script setup>
import { useI18n } from '@/i18n/composables';
import { useCustomizerStore } from '@/stores/customizer';
import { computed } from 'vue';
const props = defineProps({ item: Object, level: Number });
const { t } = useI18n();
const customizer = useCustomizerStore();
// Inline style for a leaf nav item: exposes the left indent as the
// `--indent-padding` CSS custom property. Each nesting level adds 24px,
// and the indent collapses to 0px while the sidebar is in mini mode.
const itemStyle = computed(() => {
  const depth = props.level ?? 0;
  if (customizer.mini_sidebar) {
    return { '--indent-padding': '0px' };
  }
  return { '--indent-padding': `${depth * 24}px` };
});
</script>
<template>
<v-list-item
:to="item.type === 'external' ? '' : item.to"
:href="item.type === 'external' ? item.to : ''"
rounded
class="mb-1"
color="secondary"
:disabled="item.disabled"
:target="item.type === 'external' ? '_blank' : ''"
>
<v-list-group v-if="item.children" :value="item.title" :class="{ 'group-bordered': customizer.mini_sidebar }">
<template v-slot:activator="{ props }">
<v-list-item v-bind="props" rounded class="mb-1" color="secondary" :prepend-icon="item.icon"
:style="{ '--indent-padding': '0px' }">
<v-list-item-title style="font-size: 14px; font-weight: 500; line-height: 1.2; word-break: break-word;">
{{ t(item.title) }}
</v-list-item-title>
</v-list-item>
</template>
<!-- children -->
<template v-for="(child, index) in item.children" :key="index">
<NavItem :item="child" :level="(level || 0) + 1" />
</template>
</v-list-group>
<v-list-item v-else :to="item.type === 'external' ? '' : item.to" :href="item.type === 'external' ? item.to : ''" rounded
class="mb-1" color="secondary" :disabled="item.disabled" :target="item.type === 'external' ? '_blank' : ''" :style="itemStyle">
<template v-slot:prepend>
<v-icon v-if="item.icon" :size="item.iconSize" class="hide-menu" :icon="item.icon"></v-icon>
</template>
@@ -23,15 +41,19 @@ const { t } = useI18n();
{{ item.subCaption }}
</v-list-item-subtitle>
<template v-slot:append v-if="item.chip">
<v-chip
:color="item.chipColor"
class="sidebarchip hide-menu"
:size="item.chipIcon ? 'small' : 'default'"
:variant="item.chipVariant"
:prepend-icon="item.chipIcon"
>
<v-chip :color="item.chipColor" class="sidebarchip hide-menu" :size="item.chipIcon ? 'small' : 'default'"
:variant="item.chipVariant" :prepend-icon="item.chipIcon">
{{ item.chip }}
</v-chip>
</template>
</v-list-item>
</template>
<style>
/* 在折叠(mini)状态下,分组展开时给整个分组(母项+子项)加边框以便区分 */
.group-bordered.v-list-group--open {
border: 2px solid rgba(var(--v-theme-borderLight), 0.35);
border-radius: 8px;
background: rgba(var(--v-theme-borderLight), 0.04);
}
</style>
@@ -11,8 +11,13 @@ const customizer = useCustomizerStore();
const sidebarMenu = shallowRef(sidebarItems);
const showIframe = ref(false);
const starCount = ref(null);
const sidebarWidth = ref(235);
const minSidebarWidth = 200;
const maxSidebarWidth = 300;
const isResizing = ref(false);
// iframe
const iframeStyle = ref({
position: 'fixed',
bottom: '16px',
@@ -29,14 +34,13 @@ const iframeStyle = ref({
boxShadow: '0px 4px 12px rgba(0, 0, 0, 0.1)',
});
//
if (window.innerWidth < 768) {
iframeStyle.value = {
position: 'fixed',
top: '10%',
left: '0%',
width: '100%',
height: '50%',
height: '80%',
minWidth: '300px',
minHeight: '200px',
background: 'white',
@@ -46,7 +50,6 @@ if (window.innerWidth < 768) {
borderRadius: '12px',
boxShadow: '0px 4px 12px rgba(0, 0, 0, 0.1)',
};
//
customizer.Sidebar_drawer = false;
}
@@ -74,12 +77,10 @@ function openIframeLink(url) {
}
}
//
let offsetX = 0;
let offsetY = 0;
let isDragging = false;
//
/**
 * Restrict `value` to the range [min, max].
 *
 * @param {number} value - candidate value
 * @param {number} min - lower bound
 * @param {number} max - upper bound
 * @returns {number} `value` raised to at least `min`, then lowered to at most `max`
 */
function clamp(value, min, max) {
  const atLeastMin = Math.max(value, min);
  return Math.min(atLeastMin, max);
}
@@ -91,7 +92,6 @@ function startDrag(clientX, clientY) {
offsetX = clientX - rect.left;
offsetY = clientY - rect.top;
document.body.style.userSelect = 'none';
//
document.addEventListener('mousemove', onMouseMove);
document.addEventListener('mouseup', onMouseUp);
document.addEventListener('touchmove', onTouchMove, { passive: false });
@@ -149,6 +149,53 @@ function endDrag() {
document.removeEventListener('touchend', onTouchEnd);
}
/**
 * Start a mouse-driven resize of the sidebar.
 *
 * Marks the resize as active, disables text selection and shows a resize
 * cursor on the body, then follows the pointer until the mouse button is
 * released. The resulting width is kept within
 * [minSidebarWidth, maxSidebarWidth].
 *
 * @param {MouseEvent} event - the mousedown event from the resize handle
 */
function startSidebarResize(event) {
  isResizing.value = true;
  document.body.style.userSelect = 'none';
  document.body.style.cursor = 'ew-resize';

  const originX = event.clientX;
  const originWidth = sidebarWidth.value;

  // Apply the horizontal pointer travel to the width recorded at drag start.
  const handleMove = (moveEvent) => {
    if (!isResizing.value) {
      return;
    }
    const proposed = originWidth + (moveEvent.clientX - originX);
    const capped = Math.min(maxSidebarWidth, proposed);
    sidebarWidth.value = Math.max(minSidebarWidth, capped);
  };

  // Restore body styles and detach both listeners when the drag ends.
  const handleRelease = () => {
    isResizing.value = false;
    document.body.style.userSelect = '';
    document.body.style.cursor = '';
    document.removeEventListener('mousemove', handleMove);
    document.removeEventListener('mouseup', handleRelease);
  };

  document.addEventListener('mousemove', handleMove);
  document.addEventListener('mouseup', handleRelease);
}
/**
 * Format a number with comma thousands separators, e.g. 1234567 -> "1,234,567".
 *
 * Only the integer part is grouped; any fractional part is appended
 * unchanged, so 1234.5678 becomes "1,234.5678" rather than "1,234.5,678".
 *
 * @param {number|string} num - value to format
 * @returns {string} the value with its integer digits grouped in threes
 */
function formatNumber(num) {
  const [integerPart, ...fractionParts] = String(num).split('.');
  const grouped = integerPart.replace(/\B(?=(\d{3})+(?!\d))/g, ',');
  return fractionParts.length > 0 ? `${grouped}.${fractionParts.join('.')}` : grouped;
}
/**
 * Fetch the repository star count and store it in `starCount`.
 *
 * Best effort: any network or parsing failure is only logged at debug
 * level and leaves `starCount` untouched. A missing or falsy count is
 * ignored as well.
 */
async function fetchStarCount() {
  try {
    const payload = await fetch('https://cloud.astrbot.app/api/v1/github/repo-info')
      .then((response) => response.json());
    const stars = payload.data?.stargazers_count;
    if (!stars) {
      return;
    }
    starCount.value = stars;
    console.debug('Fetched star count:', starCount.value);
  } catch (error) {
    console.debug('Failed to fetch star count:', error);
  }
}
fetchStarCount();
</script>
<template>
@@ -159,7 +206,7 @@ function endDrag() {
rail-width="80"
app
class="leftSidebar"
width="220"
:width="sidebarWidth"
:rail="customizer.mini_sidebar"
>
<div class="sidebar-container">
@@ -177,9 +224,24 @@ function endDrag() {
</v-btn>
<v-btn style="margin-bottom: 8px;" size="small" variant="plain" @click="openIframeLink('https://github.com/AstrBotDevs/AstrBot')">
{{ t('core.navigation.github') }}
<v-chip
v-if="starCount"
size="x-small"
variant="outlined"
class="ml-2"
style="font-weight: normal;"
>{{ formatNumber(starCount) }}</v-chip>
</v-btn>
</div>
</div>
<div
v-if="!customizer.mini_sidebar && customizer.Sidebar_drawer"
class="sidebar-resize-handle"
@mousedown="startSidebarResize"
:class="{ 'resizing': isResizing }"
>
</div>
</v-navigation-drawer>
<div
@@ -187,14 +249,13 @@ function endDrag() {
id="draggable-iframe"
:style="iframeStyle"
>
<!-- 拖拽头部支持鼠标和触摸 -->
<div :style="dragHeaderStyle" @mousedown="onMouseDown" @touchstart="onTouchStart">
<div style="display: flex; align-items: center;">
<v-icon icon="mdi-cursor-move" />
<span style="margin-left: 8px;">{{ t('core.navigation.drag') }}</span>
</div>
<div style="display: flex; gap: 8px;">
<!-- 跳转按钮 -->
<v-btn
icon
@click.stop="openIframeLink('https://astrbot.app')"
@@ -203,7 +264,6 @@ function endDrag() {
>
<v-icon icon="mdi-open-in-new" />
</v-btn>
<!-- 关闭按钮 -->
<v-btn
icon
@click.stop="toggleIframe"
@@ -214,10 +274,53 @@ function endDrag() {
</v-btn>
</div>
</div>
<!-- iframe 区域 -->
<iframe
src="https://astrbot.app"
style="width: 100%; height: calc(100% - 56px); border: none; border-bottom-left-radius: 12px; border-bottom-right-radius: 12px;"
></iframe>
</div>
</template>
</template>
<style scoped>
.sidebar-resize-handle {
position: absolute;
top: 0;
right: 0;
width: 4px;
height: 100%;
background: transparent;
cursor: ew-resize;
user-select: none;
z-index: 1000;
transition: background-color 0.2s ease;
}
.sidebar-resize-handle:hover,
.sidebar-resize-handle.resizing {
background: rgba(var(--v-theme-primary), 0.3);
}
.sidebar-resize-handle::before {
content: '';
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
width: 2px;
height: 30px;
background: rgba(var(--v-theme-on-surface), 0.3);
border-radius: 1px;
opacity: 0;
transition: opacity 0.2s ease;
}
.sidebar-resize-handle:hover::before,
.sidebar-resize-handle.resizing::before {
opacity: 1;
}
/* 确保侧边栏容器支持相对定位 */
.leftSidebar .v-navigation-drawer__content {
position: relative;
}
</style>
@@ -18,31 +18,26 @@ export interface menu {
// 在组件中使用时需要通过t()函数进行翻译
// 所有键名都使用 core.navigation.* 格式
const sidebarItem: menu[] = [
{
title: 'core.navigation.dashboard',
icon: 'mdi-view-dashboard',
to: '/dashboard/default'
},
{
title: 'core.navigation.platforms',
icon: 'mdi-message-processing',
to: '/platforms',
icon: 'mdi-robot',
to: '/',
},
{
title: 'core.navigation.providers',
icon: 'mdi-creation',
to: '/providers',
},
{
title: 'core.navigation.config',
icon: 'mdi-cog',
to: '/config',
},
{
title: 'core.navigation.toolUse',
icon: 'mdi-function-variant',
to: '/tool-use'
},
{
title: 'core.navigation.persona',
icon: 'mdi-heart',
to: '/persona'
},
{
title: 'core.navigation.extension',
icon: 'mdi-puzzle',
@@ -50,21 +45,8 @@ const sidebarItem: menu[] = [
},
{
title: 'core.navigation.knowledgeBase',
icon: 'mdi-text-box-search',
to: '/alkaid/knowledge-base',
},
{
title: 'core.navigation.knowledgeBaseBeta',
icon: 'mdi-book-open-variant',
to: '/knowledge-base',
chip: 'Beta',
chipColor: 'primary',
chipVariant: 'tonal',
},
{
title: 'core.navigation.config',
icon: 'mdi-cog',
to: '/config',
},
{
title: 'core.navigation.chat',
@@ -72,20 +54,36 @@ const sidebarItem: menu[] = [
to: '/chat'
},
{
title: 'core.navigation.conversation',
icon: 'mdi-database',
to: '/conversation'
},
{
title: 'core.navigation.sessionManagement',
icon: 'mdi-account-group',
to: '/session-management'
},
{
title: 'core.navigation.console',
icon: 'mdi-console',
to: '/console'
},
title: 'core.navigation.groups.more',
icon: 'mdi-dots-horizontal',
children: [
{
title: 'core.navigation.persona',
icon: 'mdi-heart',
to: '/persona'
},
{
title: 'core.navigation.conversation',
icon: 'mdi-database',
to: '/conversation'
},
{
title: 'core.navigation.sessionManagement',
icon: 'mdi-account-group',
to: '/session-management'
},
{
title: 'core.navigation.dashboard',
icon: 'mdi-view-dashboard',
to: '/dashboard/default'
},
{
title: 'core.navigation.console',
icon: 'mdi-console',
to: '/console'
},
]
}
// {
// title: 'Project ATRI',
// icon: 'mdi-grain',
+1 -1
View File
@@ -49,6 +49,6 @@ axios.interceptors.request.use((config) => {
loader.config({
paths: {
vs: 'https://cdn.jsdelivr.net/npm/monaco-editor@0.43.0/min/vs',
vs: 'https://cdn.jsdelivr.net/npm/monaco-editor@0.54.0/min/vs',
},
})
+10 -25
View File
@@ -3,13 +3,13 @@ const MainRoutes = {
meta: {
requiresAuth: true
},
redirect: '/main/dashboard/default',
redirect: '/main/platforms',
component: () => import('@/layouts/full/FullLayout.vue'),
children: [
{
name: 'Dashboard',
name: 'MainPage',
path: '/',
component: () => import('@/views/dashboards/default/DefaultDashboard.vue')
component: () => import('@/views/PlatformPage.vue')
},
{
name: 'Extensions',
@@ -90,6 +90,13 @@ const MainRoutes = {
}
]
},
// 旧版本的知识库路由
{
name: 'KnowledgeBase',
path: '/alkaid/knowledge-base',
component: () => import('@/views/alkaid/KnowledgeBase.vue'),
},
// {
// name: 'Alkaid',
// path: '/alkaid',
@@ -112,28 +119,6 @@ const MainRoutes = {
// }
// ]
// },
{
name: 'Alkaid',
path: '/alkaid',
component: () => import('@/views/AlkaidPage.vue'),
children: [
{
path: 'knowledge-base',
name: 'KnowledgeBase',
component: () => import('@/views/alkaid/KnowledgeBase.vue')
},
{
path: 'long-term-memory',
name: 'LongTermMemory',
component: () => import('@/views/alkaid/LongTermMemory.vue')
},
{
path: 'other',
name: 'OtherFeatures',
component: () => import('@/views/alkaid/Other.vue')
}
]
},
{
name: 'Chat',
path: '/chat',
+12 -639
View File
@@ -11,12 +11,6 @@
<v-select style="min-width: 130px;" v-model="selectedConfigID" :items="configSelectItems" item-title="name"
v-if="!isSystemConfig" item-value="id" label="选择配置文件" hide-details density="compact" rounded="md"
variant="outlined" @update:model-value="onConfigSelect">
<template v-slot:item="{ props: itemProps, item }">
<v-list-item v-bind="itemProps"
:subtitle="item.raw.id === '_%manage%_' ? '管理所有配置文件' : formatUmop(item.raw.umop)"
:class="item.raw.id === '_%manage%_' ? 'text-primary' : ''">
</v-list-item>
</template>
</v-select>
<a style="color: inherit;" href="https://blog.astrbot.app/posts/what-is-changed-in-4.0.0/#%E5%A4%9A%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6" target="_blank"><v-btn icon="mdi-help-circle" size="small" variant="plain"></v-btn></a>
@@ -37,38 +31,10 @@
<div v-if="(selectedConfigID || isSystemConfig) && fetched" style="width: 100%;">
<!-- 可视化编辑 -->
<div :class="$vuetify.display.mobile ? '' : 'd-flex'">
<v-tabs v-model="tab" :direction="$vuetify.display.mobile ? 'horizontal' : 'vertical'"
:align-tabs="$vuetify.display.mobile ? 'left' : 'start'" color="deep-purple-accent-4" class="config-tabs">
<v-tab v-for="(val, key, index) in metadata" :key="index" :value="index"
style="font-weight: 1000; font-size: 15px">
{{ metadata[key]['name'] }}
</v-tab>
</v-tabs>
<v-tabs-window v-model="tab" class="config-tabs-window">
<v-tabs-window-item v-for="(val, key, index) in metadata" v-show="index == tab" :key="index">
<v-container fluid>
<div v-for="(val2, key2, index2) in metadata[key]['metadata']" :key="key2">
<!-- Support both traditional and JSON selector metadata -->
<AstrBotConfigV4 :metadata="{ [key2]: metadata[key]['metadata'][key2] }" :iterable="config_data"
:metadataKey="key2">
</AstrBotConfigV4>
</div>
</v-container>
</v-tabs-window-item>
<div style="margin-left: 16px; padding-bottom: 16px">
<small>{{ tm('help.helpPrefix') }}
<a href="https://astrbot.app/" target="_blank">{{ tm('help.documentation') }}</a>
{{ tm('help.helpMiddle') }}
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=EYGsuUTfe00_iOu9JTXS7_TEpMkXOvwv&jump_from=webapi&authKey=uUEMKCROfsseS+8IzqPjzV3y1tzy4AkykwTib2jNkOFdzezF9s9XknqnIaf3CDft"
target="_blank">{{ tm('help.support') }}</a>{{ tm('help.helpSuffix') }}
</small>
</div>
</v-tabs-window>
</div>
<AstrBotCoreConfigWrapper
:metadata="metadata"
:config_data="config_data"
/>
<v-btn icon="mdi-content-save" size="x-large" style="position: fixed; right: 52px; bottom: 52px;"
color="darkprimary" @click="updateConfig">
@@ -118,7 +84,7 @@
</v-card-title>
<v-card-text>
<small>AstrBot 支持针对不同消息平台实例分别设置配置文件默认会使用 `default` 配置</small>
<small>AstrBot 支持针对不同机器人分别设置配置文件默认会使用 `default` 配置</small>
<div class="mt-6 mb-4">
<v-btn prepend-icon="mdi-plus" @click="startCreateConfig" variant="tonal" color="primary">
新建配置文件
@@ -128,8 +94,6 @@
<!-- Config List -->
<v-list lines="two">
<v-list-item v-for="config in configInfoList" :key="config.id" :title="config.name">
<v-list-item-subtitle>当前应用于: {{ formatUmop(config.umop) }} </v-list-item-subtitle>
<template v-slot:append v-if="config.id !== 'default'">
<div class="d-flex align-center" style="gap: 8px;">
<v-btn icon="mdi-pencil" size="small" variant="text" color="warning"
@@ -147,149 +111,15 @@
<div v-if="showConfigForm">
<h3 class="mb-4">{{ isEditingConfig ? '编辑配置文件' : '新建配置文件' }}</h3>
<div class="mb-4">
<div v-if="conflictMessage" class="text-warning">
<div v-html="conflictMessage" style="font-size: 0.875rem; line-height: 1.4;"></div>
</div>
</div>
<h4>名称</h4>
<v-text-field v-model="configFormData.name" label="填写配置文件名称" variant="outlined" class="mt-4 mb-4"
hide-details></v-text-field>
<h4>应用于</h4>
<v-radio-group class="mt-2" v-model="appliedToRadioValue" hide-details="true">
<v-radio value="0">
<template v-slot:label>
<span>指定消息平台...</span>
</template>
</v-radio>
<v-select v-if="appliedToRadioValue === '0'" v-model="configFormData.umop" :items="platformList" item-title="id" item-value="id"
label="选择已配置的消息平台(可多选)" variant="outlined" hide-details multiple class="ma-2"
@update:model-value="checkPlatformConflictOnForm">
<template v-slot:item="{ props: itemProps, item }">
<v-list-item v-bind="itemProps" :subtitle="item.raw.type"></v-list-item>
</template>
</v-select>
<v-radio value="1" label="自定义规则(实验性)">
</v-radio>
<!-- 自定义规则界面 -->
<div v-if="appliedToRadioValue === '1'" class="ma-2">
<small class="text-medium-emphasis mb-4 d-block">UMO 格式: [platform_id]:[message_type]:[session_id]通配符 * 或留空表示全部使用 /sid 查看某个聊天的 UMO</small>
<!-- 输入方式切换 -->
<v-btn-toggle v-model="customRuleInputMode" mandatory color="primary" variant="outlined" density="compact"
rounded="md" class="mb-4">
<v-btn value="builder" prepend-icon="mdi-tune" size="x-small">
可视化
</v-btn>
<v-btn value="manual" prepend-icon="mdi-code-tags" size="x-small">
手动编辑
</v-btn>
</v-btn-toggle>
<!-- 快速规则构建 -->
<div v-if="customRuleInputMode === 'builder'" class="mb-4">
<div v-for="(rule, index) in customRules" :key="index" class="d-flex align-center mb-2" style="gap: 8px;">
<v-select
v-model="rule.platform"
:items="[{ id: '*', type: '所有平台' }, ...platformList]"
item-title="id"
item-value="id"
label="平台"
variant="outlined"
density="compact"
style="min-width: 120px;"
@update:model-value="updateCustomRule(index)">
<template v-slot:item="{ props: itemProps, item }">
<v-list-item v-bind="itemProps" :subtitle="item.raw.type"></v-list-item>
</template>
</v-select>
<v-select
v-model="rule.messageType"
:items="messageTypeOptions"
item-title="label"
item-value="value"
label="消息类型"
variant="outlined"
density="compact"
style="min-width: 130px;"
@update:model-value="updateCustomRule(index)">
</v-select>
<v-text-field
v-model="rule.sessionId"
label="会话ID"
variant="outlined"
density="compact"
placeholder="* 或留空表示全部"
style="min-width: 120px;"
@update:model-value="updateCustomRule(index)">
</v-text-field>
<v-btn
icon="mdi-delete"
size="small"
variant="text"
color="error"
@click="removeCustomRule(index)"
:disabled="customRules.length === 1">
</v-btn>
</div>
<v-btn
prepend-icon="mdi-plus"
size="small"
variant="tonal"
color="primary"
@click="addCustomRule">
添加规则
</v-btn>
</div>
<!-- 手动输入 -->
<div v-if="customRuleInputMode === 'manual'" class="mb-4">
<v-textarea
v-model="manualRulesText"
label="手动输入规则(每行一个)"
variant="outlined"
rows="4"
placeholder="每行一个规则,例如:&#10;platform1:GroupMessage:*&#10;*:FriendMessage:session123&#10;*:*:*"
@update:model-value="updateManualRules">
</v-textarea>
</div>
<!-- 规则预览 -->
<div class="mb-2">
<small class="text-medium-emphasis">
<strong>预览:</strong>
<span v-if="!configFormData.umop.length" class="text-error">未配置任何规则</span>
<div v-else class="mt-1">
<v-chip
v-for="(rule, index) in configFormData.umop"
:key="index"
size="x-small"
rounded="sm"
class="mr-1">
{{ rule }}
</v-chip>
</div>
<small>这些规则对应的会话将使用此配置文件</small>
</small>
</div>
</div>
</v-radio-group>
<div class="d-flex justify-end mt-4" style="gap: 8px;">
<v-btn variant="text" @click="cancelConfigForm">取消</v-btn>
<v-btn color="primary" @click="saveConfigForm"
:disabled="!configFormData.name || !configFormData.umop.length">
:disabled="!configFormData.name">
{{ isEditingConfig ? '更新' : '创建' }}
</v-btn>
</div>
@@ -308,7 +138,7 @@
<script>
import axios from 'axios';
import AstrBotConfigV4 from '@/components/shared/AstrBotConfigV4.vue';
import AstrBotCoreConfigWrapper from '@/components/config/AstrBotCoreConfigWrapper.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
import { useI18n, useModuleI18n } from '@/i18n/composables';
@@ -316,7 +146,7 @@ import { useI18n, useModuleI18n } from '@/i18n/composables';
export default {
name: 'ConfigPage',
components: {
AstrBotConfigV4,
AstrBotCoreConfigWrapper,
VueMonacoEditor,
WaitingForRestart
},
@@ -359,15 +189,6 @@ export default {
watch: {
config_data_str: function (val) {
this.config_data_has_changed = true;
},
customRuleInputMode: function (newVal) {
if (newVal === 'builder') {
//
this.syncCustomRulesFromManual();
} else if (newVal === 'manual') {
//
this.syncManualRulesText();
}
}
},
data() {
@@ -387,8 +208,6 @@ export default {
save_message: "",
save_message_success: "",
tab: 0, //
//
configType: 'normal', // 'normal' 'system'
@@ -396,32 +215,12 @@ export default {
isSystemConfig: false,
//
appliedToRadioValue: '0',
selectedConfigID: null, //
configInfoList: [],
platformList: [],
configFormData: {
name: '',
umop: [],
},
editingConfigId: null,
conflictMessage: '', //
//
customRuleInputMode: 'builder', // 'builder' 'manual'
customRules: [
{
platform: '*',
messageType: '*',
sessionId: '*'
}
],
manualRulesText: '',
messageTypeOptions: [
{ label: '所有消息类型', value: '*' },
{ label: '群组消息', value: 'GroupMessage' },
{ label: '私聊消息', value: 'FriendMessage' }
],
}
},
mounted() {
@@ -450,13 +249,6 @@ export default {
this.save_message_success = "error";
});
},
getPlatformList() {
axios.get('/api/config/platform/list').then((res) => {
this.platformList = res.data.data.platforms;
}).catch((err) => {
console.error(this.t('status.dataError'), err);
});
},
getConfig(abconf_id) {
this.fetched = false
const params = {};
@@ -532,18 +324,7 @@ export default {
}
},
createNewConfig() {
let umo_parts = [];
if (this.appliedToRadioValue === '0') {
// umo part -
umo_parts = this.configFormData.umop.map(platform => platform + "::");
} else if (this.appliedToRadioValue === '1') {
//
umo_parts = [...this.configFormData.umop]; // 使 umop
}
axios.post('/api/config/abconf/new', {
umo_parts: umo_parts,
name: this.configFormData.name
}).then((res) => {
if (res.data.status === "ok") {
@@ -564,210 +345,9 @@ export default {
this.save_message_success = "error";
});
},
checkPlatformConflict(newRules) {
const conflictConfigs = [];
// "default"
for (const config of this.configInfoList) {
if (config.name === 'default') {
continue; // default
}
if (config.umop && config.umop.length > 0) {
//
const hasConflict = this.hasUmoConflict(newRules, config.umop);
if (hasConflict) {
conflictConfigs.push(config);
}
}
}
return conflictConfigs;
},
hasUmoConflict(newRules, existingRules) {
//
for (const newRule of newRules) {
for (const existingRule of existingRules) {
if (this.isUmoMatch(newRule, existingRule) || this.isUmoMatch(existingRule, newRule)) {
return true;
}
}
}
return false;
},
isUmoMatch(p1, p2) {
// p2 umo p1 umo
// _is_umo_match
//
const p1_normalized = this.normalizeUmoRule(p1);
const p2_normalized = this.normalizeUmoRule(p2);
const p1_parts = p1_normalized.split(":");
const p2_parts = p2_normalized.split(":");
if (p1_parts.length !== 3 || p2_parts.length !== 3) {
return false; //
}
//
return p1_parts.every((p, index) => {
const t = p2_parts[index];
return p === "" || p === "*" || p === t;
});
},
normalizeUmoRule(rule) {
//
if (typeof rule !== 'string') {
return "*:*:*";
}
const parts = rule.split(":");
if (parts.length === 2 && parts[1] === "") {
// "platform::" -> "platform:*:*"
return `${parts[0] || "*"}:*:*`;
} else if (parts.length === 3) {
//
return parts.map(part => part === "" ? "*" : part).join(":");
} else if (parts.length === 1) {
// "platform" -> "platform:*:*"
return `${parts[0] || "*"}:*:*`;
}
//
return "*:*:*";
},
getDetailedConflictInfo(newRules) {
const conflictDetails = [];
//
const sortedConfigs = [...this.configInfoList]
.filter(config => config.name !== 'default')
.sort((a, b) => {
//
return a.id.localeCompare(b.id);
});
for (const config of sortedConfigs) {
if (!config.umop || config.umop.length === 0) continue;
const conflictingRules = [];
for (const newRule of newRules) {
for (const existingRule of config.umop) {
if (this.isUmoMatch(newRule, existingRule) || this.isUmoMatch(existingRule, newRule)) {
conflictingRules.push({
newRule: newRule,
existingRule: existingRule,
matchType: this.getMatchType(newRule, existingRule)
});
}
}
}
if (conflictingRules.length > 0) {
conflictDetails.push({
config: config,
conflicts: conflictingRules
});
}
}
return conflictDetails;
},
getMatchType(rule1, rule2) {
const r1_normalized = this.normalizeUmoRule(rule1);
const r2_normalized = this.normalizeUmoRule(rule2);
const isR1MatchR2 = this.isUmoMatch(rule1, rule2);
const isR2MatchR1 = this.isUmoMatch(rule2, rule1);
if (isR1MatchR2 && isR2MatchR1) {
return 'exact'; //
} else if (isR1MatchR2) {
return 'new_covers_existing'; //
} else if (isR2MatchR1) {
return 'existing_covers_new'; //
}
return 'overlap'; //
},
formatConflictMessage(conflictDetails) {
if (conflictDetails.length === 0) return '';
let message = '⚠️ <strong>规则冲突警告:</strong><br><br>';
//
const sortedDetails = [...conflictDetails].sort((a, b) =>
a.config.id.localeCompare(b.config.id)
);
sortedDetails.forEach((detail, index) => {
const configName = detail.config.name || detail.config.id;
message += `<strong>${index + 1}. 与配置文件 "${configName}" 冲突:</strong><br>`;
detail.conflicts.forEach(conflict => {
const newRuleFormatted = this.formatRuleForDisplay(conflict.newRule);
const existingRuleFormatted = this.formatRuleForDisplay(conflict.existingRule);
switch (conflict.matchType) {
case 'exact':
message += `规则完全相同: <code>${newRuleFormatted}</code><br>`;
message += `<span style="color: orange;">"${configName}" 将覆盖当前配置</span><br>`;
break;
case 'new_covers_existing':
message += `当前规则 <code>${newRuleFormatted}</code> 包含现有规则 <code>${existingRuleFormatted}</code><br>`;
message += `<span style="color: red;">"${configName}" 的规则将优先匹配</span><br>`;
break;
case 'existing_covers_new':
message += `现有规则 <code>${existingRuleFormatted}</code> 包含当前规则 <code>${newRuleFormatted}</code><br>`;
message += `<span style="color: red;">"${configName}" 的规则将优先匹配</span><br>`;
break;
case 'overlap':
message += `规则重叠: <code>${newRuleFormatted}</code> ↔ <code>${existingRuleFormatted}</code><br>`;
message += `<span style="color: orange;">"${configName}" 在匹配范围内优先</span><br>`;
break;
}
});
if (index < sortedDetails.length - 1) {
message += '<br>';
}
});
message += '<br><small><strong>💡 说明:</strong> 您仍可创建此配置文件。AstrBot 按配置文件创建顺序匹配规则,先创建的配置文件优先级更高。当多个配置文件的规则匹配同一个消息会话来源时,优先级最高的配置文件会生效(default 配置文件除外)。</small>';
return message;
},
formatRuleForDisplay(rule) {
const parts = this.normalizeUmoRule(rule).split(':');
const platform = parts[0] === '*' || parts[0] === '' ? '任意平台' : parts[0];
const messageType = parts[1] === '*' || parts[1] === '' ? '任意消息' : this.getMessageTypeLabel(parts[1]);
const sessionId = parts[2] === '*' || parts[2] === '' ? '任意会话' : parts[2];
return `${platform}:${messageType}:${sessionId}`;
},
getMessageTypeLabel(messageType) {
const typeMap = {
'GroupMessage': '群组消息',
'FriendMessage': '私聊消息',
};
return typeMap[messageType] || messageType;
},
onConfigSelect(value) {
if (value === '_%manage%_') {
this.configManageDialog = true;
this.getPlatformList();
//
this.$nextTick(() => {
this.selectedConfigID = this.selectedConfigInfo.id || 'default';
@@ -781,26 +361,17 @@ export default {
this.isEditingConfig = false;
this.configFormData = {
name: '',
umop: [],
};
this.editingConfigId = null;
this.conflictMessage = '';
this.resetCustomRules();
},
startEditConfig(config) {
this.appliedToRadioValue = "1";
this.showConfigForm = true;
this.isEditingConfig = true;
this.editingConfigId = config.id;
this.parseExistingCustomRules(config.umop || []);
this.configFormData = {
name: config.name || '',
umop: [...(config.umop || [])],
};
this.conflictMessage = '';
},
cancelConfigForm() {
this.showConfigForm = false;
@@ -808,14 +379,11 @@ export default {
this.editingConfigId = null;
this.configFormData = {
name: '',
umop: [],
};
this.conflictMessage = '';
this.resetCustomRules();
},
saveConfigForm() {
if (!this.configFormData.name || !this.configFormData.umop.length) {
this.save_message = "请填写配置名称和选择应用平台";
if (!this.configFormData.name) {
this.save_message = "请填写配置名称";
this.save_message_snack = true;
this.save_message_success = "error";
return;
@@ -827,106 +395,6 @@ export default {
this.createNewConfig();
}
},
//
addCustomRule() {
this.customRules.push({
platform: '*',
messageType: '*',
sessionId: '*'
});
this.updateCustomRulesFromBuilder();
},
removeCustomRule(index) {
if (this.customRules.length > 1) {
this.customRules.splice(index, 1);
this.updateCustomRulesFromBuilder();
}
},
// Change handler for a builder row. `index` is accepted but not used:
// the whole rule list is rebuilt from `customRules` on every change.
updateCustomRule(index) {
this.updateCustomRulesFromBuilder();
},
updateCustomRulesFromBuilder() {
// umop
const rules = this.customRules.map(rule => {
const platform = rule.platform === '*' ? '' : rule.platform;
const messageType = rule.messageType === '*' ? '' : rule.messageType;
const sessionId = rule.sessionId === '*' ? '' : (rule.sessionId || '');
return `${platform}:${messageType}:${sessionId}`;
});
this.configFormData.umop = rules;
this.syncManualRulesText();
//
this.checkPlatformConflictOnForm();
},
updateManualRules() {
// umop
const rules = this.manualRulesText
.split('\n')
.map(rule => rule.trim())
.filter(rule => rule);
this.configFormData.umop = rules;
this.syncCustomRulesFromManual();
//
this.checkPlatformConflictOnForm();
},
// Rebuild the manual-edit text from the current form rules
// (configFormData.umop), one rule per line.
syncManualRulesText() {
// Join with newlines so each rule occupies its own line in the textarea.
this.manualRulesText = this.configFormData.umop.join('\n');
},
syncCustomRulesFromManual() {
//
this.customRules = this.configFormData.umop.map(rule => {
const parts = rule.split(':');
return {
platform: parts[0] || '*',
messageType: parts[1] || '*',
sessionId: parts[2] || '*'
};
});
},
resetCustomRules() {
this.customRuleInputMode = 'builder'; //
this.customRules = [
{
platform: '*',
messageType: '*',
sessionId: '*'
}
];
this.manualRulesText = '';
if (this.appliedToRadioValue === '1') {
this.updateCustomRulesFromBuilder();
}
},
parseExistingCustomRules(umop) {
//
if (!umop || umop.length === 0) {
this.resetCustomRules();
return;
}
this.customRules = umop.map(rule => {
const parts = rule.split(':');
return {
platform: parts[0] || '*',
messageType: parts[1] || '*',
sessionId: parts[2] || '*'
};
});
this.syncManualRulesText();
},
confirmDeleteConfig(config) {
if (confirm(`确定要删除配置文件 "${config.name}" 吗?此操作不可恢复。`)) {
this.deleteConfig(config.id);
@@ -940,6 +408,7 @@ export default {
this.save_message = res.data.message;
this.save_message_snack = true;
this.save_message_success = "success";
this.cancelConfigForm();
//
this.getConfigInfoList("default");
} else {
@@ -954,52 +423,10 @@ export default {
this.save_message_success = "error";
});
},
checkPlatformConflictOnForm() {
if (!this.configFormData.umop || this.configFormData.umop.length === 0) {
this.conflictMessage = '';
return;
}
//
let rulesToCheck = [];
if (this.appliedToRadioValue === '0') {
// UMO
rulesToCheck = this.configFormData.umop.map(platform => `${platform}:*:*`);
} else {
// 使
rulesToCheck = [...this.configFormData.umop];
}
//
let conflictDetails = this.getDetailedConflictInfo(rulesToCheck);
//
if (this.isEditingConfig && this.editingConfigId) {
conflictDetails = conflictDetails.filter(detail => detail.config.id !== this.editingConfigId);
}
if (conflictDetails.length > 0) {
this.conflictMessage = this.formatConflictMessage(conflictDetails);
} else {
this.conflictMessage = '';
}
},
updateConfigInfo() {
let umo_parts = [];
if (this.appliedToRadioValue === '0') {
// umo part -
umo_parts = this.configFormData.umop.map(platform => platform + "::");
} else if (this.appliedToRadioValue === '1') {
//
umo_parts = [...this.configFormData.umop]; // 使 umop
}
axios.post('/api/config/abconf/update', {
id: this.editingConfigId,
name: this.configFormData.name,
umo_parts: umo_parts
name: this.configFormData.name
}).then((res) => {
if (res.data.status === "ok") {
this.save_message = res.data.message;
@@ -1019,38 +446,8 @@ export default {
this.save_message_success = "error";
});
},
formatUmop(umop) {
if (!umop) {
return
}
let ret = ""
for (let i = 0; i < umop.length; i++) {
const parts = umop[i].split(":");
if (parts.length === 3) {
// platform:messageType:sessionId
const platform = parts[0] || "*";
const messageType = parts[1] || "*";
const sessionId = parts[2] || "*";
if (platform === "*" && messageType === "*" && sessionId === "*") {
return "所有平台";
}
ret += `${platform}:${messageType}:${sessionId},`;
} else {
//
let platformPart = umop[i].split(":")[0];
if (platformPart === "") {
return "所有平台";
} else {
ret += platformPart + ",";
}
}
}
ret = ret.slice(0, -1);
return ret;
},
onConfigTypeToggle() {
this.isSystemConfig = this.configType === 'system';
this.tab = 0; //
this.fetched = false; //
if (this.isSystemConfig) {
@@ -1069,7 +466,6 @@ export default {
// configType
this.configType = this.isSystemConfig ? 'system' : 'normal';
this.tab = 0; //
this.fetched = false; //
if (this.isSystemConfig) {
@@ -1128,31 +524,12 @@ export default {
}
@media (min-width: 768px) {
.config-tabs {
display: flex;
margin: 16px 16px 0 0;
}
.config-panel {
width: 750px;
}
.config-tabs-window {
flex: 1;
}
.config-tabs .v-tab {
justify-content: flex-start !important;
text-align: left;
min-height: 48px;
}
}
@media (max-width: 767px) {
.config-tabs {
width: 100%;
}
.v-container {
padding: 4px;
}
@@ -1160,9 +537,5 @@ export default {
.config-panel {
width: 100%;
}
.config-tabs-window {
margin-top: 16px;
}
}
</style>
+3 -3
View File
@@ -433,14 +433,14 @@ export default {
tableHeaders() {
return [
{ title: this.tm('table.headers.title'), key: 'title', sortable: true },
{ title: '会话 ID', key: 'cid', sortable: true, width: '100px' },
{ title: this.tm('table.headers.cid'), key: 'cid', sortable: true, width: '100px' },
{
title: this.tm('table.headers.sessionId'),
title: this.tm('table.headers.umo'),
align: 'center',
children: [
{ title: this.tm('table.headers.platform'), key: 'platform', sortable: true, width: '120px' },
{ title: this.tm('table.headers.type'), key: 'messageType', sortable: true, width: '100px' },
{ title: '用户 ID', key: 'sessionId', sortable: true, width: '100px' },
{ title: this.tm('table.headers.sessionId'), key: 'sessionId', sortable: true, width: '100px' },
],
},
{ title: this.tm('table.headers.createdAt'), key: 'created_at', sortable: true, width: '180px' },
+2 -2
View File
@@ -749,9 +749,9 @@ onMounted(async () => {
</v-row>
<v-row>
<v-col cols="12" md="6" lg="4" v-for="extension in filteredPlugins" :key="extension.name"
<v-col cols="12" md="6" lg="6" v-for="extension in filteredPlugins" :key="extension.name"
class="pb-4">
<ExtensionCard :extension="extension" class="h-120 rounded-lg"
<ExtensionCard :extension="extension" class="rounded-lg"
@configure="openExtensionConfig(extension.name)" @uninstall="uninstallExtension(extension.name)"
@update="updateExtension(extension.name)" @reload="reloadPlugin(extension.name)"
@toggle-activation="extension.activated ? pluginOff(extension) : pluginOn(extension)"
+13 -494
View File
@@ -90,200 +90,11 @@
</v-container>
<!-- 创建/编辑人格对话框 -->
<v-dialog v-model="showPersonaDialog" max-width="800px" persistent>
<v-card>
<v-card-title class="text-h2">
{{ editingPersona ? tm('dialog.edit.title') : tm('dialog.create.title') }}
</v-card-title>
<v-card-text>
<v-form ref="personaForm" v-model="formValid">
<v-text-field v-model="personaForm.persona_id" :label="tm('form.personaId')"
:rules="personaIdRules" :disabled="editingPersona" variant="outlined" density="comfortable"
class="mb-4" />
<v-textarea v-model="personaForm.system_prompt" :label="tm('form.systemPrompt')"
:rules="systemPromptRules" variant="outlined" rows="6" class="mb-4" />
<v-expansion-panels v-model="expandedPanels" multiple>
<!-- 工具选择面板 -->
<v-expansion-panel value="tools">
<v-expansion-panel-title>
<v-icon class="mr-2">mdi-tools</v-icon>
{{ tm('form.tools') }}
<v-chip v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
size="small" color="primary" variant="tonal" class="ml-2">
{{ personaForm.tools.length }}
</v-chip>
</v-expansion-panel-title>
<v-expansion-panel-text>
<div class="mb-3">
<p class="text-body-2 text-medium-emphasis">
{{ tm('form.toolsHelp') }}
</p>
</div>
<v-radio-group class="mt-2" v-model="toolSelectValue" hide-details="true">
<v-radio label="默认使用全部函数工具" value="0"></v-radio>
<v-radio label="选择指定函数工具" value="1">
</v-radio>
</v-radio-group>
<div v-if="toolSelectValue === '1'" class="mt-3 ml-8">
<!-- 工具搜索 -->
<v-text-field v-model="toolSearch" :label="tm('form.searchTools')"
prepend-inner-icon="mdi-magnify" variant="outlined" density="compact"
hide-details clearable class="mb-3" />
<!-- MCP 服务器 -->
<div v-if="mcpServers.length > 0" class="mb-4">
<h4 class="text-subtitle-2 mb-2">{{ tm('form.mcpServersQuickSelect') }}</h4>
<div class="d-flex flex-wrap ga-2">
<v-chip v-for="server in mcpServers" :key="server.name"
:color="isServerSelected(server) ? 'primary' : 'default'"
:variant="isServerSelected(server) ? 'flat' : 'outlined'"
size="small" clickable @click="toggleMcpServer(server)"
:disabled="!server.tools || server.tools.length === 0">
<v-icon start size="small">mdi-server</v-icon>
{{ server.name }}
<v-chip-text v-if="server.tools" class="ml-1">
({{ server.tools.length }})
</v-chip-text>
</v-chip>
</div>
</div>
<!-- 工具选择列表 -->
<div v-if="filteredTools.length > 0" class="tools-selection">
<v-virtual-scroll :items="filteredTools" height="300" item-height="48">
<template v-slot:default="{ item }">
<v-list-item :key="item.name" density="comfortable"
@click="toggleTool(item.name)">
<template v-slot:prepend>
<v-checkbox-btn :model-value="isToolSelected(item.name)"
@click.stop="toggleTool(item.name)" />
</template>
<v-list-item-title>
{{ item.name }}
<v-chip v-if="item.mcp_server_name" size="x-small"
color="secondary" variant="tonal" class="ml-2">
{{ item.mcp_server_name }}
</v-chip>
</v-list-item-title>
<v-list-item-subtitle v-if="item.description">
{{ truncateText(item.description, 100) }}
</v-list-item-subtitle>
</v-list-item>
</template>
</v-virtual-scroll>
</div>
<div v-else-if="!loadingTools && availableTools.length === 0"
class="text-center pa-4">
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-tools</v-icon>
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsAvailable')
}}
</p>
</div>
<div v-else-if="!loadingTools && filteredTools.length === 0"
class="text-center pa-4">
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-magnify</v-icon>
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsFound') }}
</p>
</div>
<!-- 加载状态 -->
<div v-if="loadingTools" class="text-center pa-4">
<v-progress-circular indeterminate color="primary" />
<p class="text-body-2 text-medium-emphasis mt-2">{{ tm('form.loadingTools')
}}
</p>
</div>
<!-- 已选择的工具 -->
<div class="mt-4">
<h4 class="text-subtitle-2 mb-2">
{{ tm('form.selectedTools') }}
<span v-if="personaForm.tools === null" class="text-success">
({{ tm('form.allSelected') }})
</span>
<span v-else-if="Array.isArray(personaForm.tools)">
({{ personaForm.tools.length }})
</span>
</h4>
<div v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
class="d-flex flex-wrap ga-1" style="max-height: 100px; overflow-y: auto;">
<v-chip v-for="toolName in personaForm.tools" :key="toolName"
size="small" color="primary" variant="tonal" closable
@click:close="removeTool(toolName)">
{{ toolName }}
</v-chip>
</div>
<div v-else class="text-body-2 text-medium-emphasis">
{{ tm('form.noToolsSelected') }}
</div>
</div>
</div>
</v-expansion-panel-text>
</v-expansion-panel>
<!-- 预设对话面板 -->
<v-expansion-panel value="dialogs">
<v-expansion-panel-title>
<v-icon class="mr-2">mdi-chat</v-icon>
{{ tm('form.presetDialogs') }}
<v-chip v-if="personaForm.begin_dialogs.length > 0" size="small" color="primary"
variant="tonal" class="ml-2">
{{ personaForm.begin_dialogs.length / 2 }}
</v-chip>
</v-expansion-panel-title>
<v-expansion-panel-text>
<div class="mb-3">
<p class="text-body-2 text-medium-emphasis">
{{ tm('form.presetDialogsHelp') }}
</p>
</div>
<div v-for="(dialog, index) in personaForm.begin_dialogs" :key="index" class="mb-3">
<v-textarea v-model="personaForm.begin_dialogs[index]"
:label="index % 2 === 0 ? tm('form.userMessage') : tm('form.assistantMessage')"
:rules="getDialogRules(index)" variant="outlined" rows="2"
density="comfortable">
<template v-slot:append>
<v-btn icon="mdi-delete" variant="text" size="small" color="error"
@click="removeDialog(index)" />
</template>
</v-textarea>
</div>
<v-btn variant="outlined" prepend-icon="mdi-plus" @click="addDialogPair" block>
{{ tm('buttons.addDialogPair') }}
</v-btn>
</v-expansion-panel-text>
</v-expansion-panel>
</v-expansion-panels>
</v-form>
</v-card-text>
<v-card-actions>
<v-spacer />
<v-btn color="grey" variant="text" @click="closePersonaDialog">
{{ tm('buttons.cancel') }}
</v-btn>
<v-btn color="primary" variant="flat" @click="savePersona" :loading="saving" :disabled="!formValid">
{{ tm('buttons.save') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<PersonaForm
v-model="showPersonaDialog"
:editing-persona="editingPersona"
@saved="handlePersonaSaved"
@error="showError" />
<!-- 查看人格详情对话框 -->
<v-dialog v-model="showViewDialog" max-width="700px">
@@ -352,9 +163,13 @@
<script>
import axios from 'axios';
import { useI18n, useModuleI18n } from '@/i18n/composables';
import PersonaForm from '@/components/shared/PersonaForm.vue';
export default {
name: 'PersonaPage',
components: {
PersonaForm
},
setup() {
const { t } = useI18n();
const { tm } = useModuleI18n('features/persona');
@@ -362,76 +177,20 @@ export default {
},
data() {
return {
toolSelectValue: '0', //
personas: [],
loading: false,
saving: false,
showPersonaDialog: false,
showViewDialog: false,
editingPersona: null,
viewingPersona: null,
expandedPanels: [],
formValid: false,
personaForm: {
persona_id: '',
system_prompt: '',
begin_dialogs: [],
tools: []
},
showMessage: false,
message: '',
messageType: 'success',
personaIdRules: [
v => !!v || this.tm('validation.required'),
v => (v && v.length >= 0) || this.tm('validation.minLength', { min: 2 }),
],
systemPromptRules: [
v => !!v || this.tm('validation.required'),
v => (v && v.length >= 10) || this.tm('validation.minLength', { min: 10 })
],
mcpServers: [],
availableTools: [],
loadingTools: false,
toolSearch: ''
}
},
computed: {
filteredTools() {
if (!this.toolSearch) {
return this.availableTools;
}
const search = this.toolSearch.toLowerCase();
return this.availableTools.filter(tool =>
tool.name.toLowerCase().includes(search) ||
(tool.description && tool.description.toLowerCase().includes(search)) ||
(tool.mcp_server_name && tool.mcp_server_name.toLowerCase().includes(search))
);
}
},
watch: {
toolSearch() {
//
},
toolSelectValue(newValue) {
if (newValue === '0') {
//
this.personaForm.tools = null;
} else if (newValue === '1') {
// null
if (this.personaForm.tools === null) {
this.personaForm.tools = [];
}
}
messageType: 'success'
}
},
mounted() {
this.loadPersonas();
this.loadMcpServers();
this.loadTools();
},
methods: {
@@ -450,58 +209,13 @@ export default {
this.loading = false;
},
async loadMcpServers() {
try {
const response = await axios.get('/api/tools/mcp/servers');
if (response.data.status === 'ok') {
this.mcpServers = response.data.data;
} else {
this.showError(response.data.message || this.tm('messages.loadError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.loadError'));
}
},
async loadTools() {
this.loadingTools = true;
try {
const response = await axios.get('/api/tools/list');
if (response.data.status === 'ok') {
this.availableTools = response.data.data;
} else {
this.showError(response.data.message || this.tm('messages.loadError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.loadError'));
}
this.loadingTools = false;
},
openCreateDialog() {
this.editingPersona = null;
this.personaForm = {
persona_id: '',
system_prompt: '',
begin_dialogs: [],
tools: []
};
this.toolSelectValue = '0';
this.expandedPanels = [];
this.showPersonaDialog = true;
},
editPersona(persona) {
this.editingPersona = persona;
this.personaForm = {
persona_id: persona.persona_id,
system_prompt: persona.system_prompt,
begin_dialogs: [...(persona.begin_dialogs || [])],
tools: persona.tools === null ? null : [...(persona.tools || [])]
};
// tools toolSelectValue
this.toolSelectValue = persona.tools === null ? '0' : '1';
this.expandedPanels = [];
this.showPersonaDialog = true;
},
@@ -510,48 +224,9 @@ export default {
this.showViewDialog = true;
},
closePersonaDialog() {
this.showPersonaDialog = false;
this.editingPersona = null;
this.personaForm = {
persona_id: '',
system_prompt: '',
begin_dialogs: [],
tools: []
};
this.toolSelectValue = '1'; //
},
async savePersona() {
if (!this.formValid) return;
//
if (this.personaForm.begin_dialogs.length > 0) {
for (let i = 0; i < this.personaForm.begin_dialogs.length; i++) {
if (!this.personaForm.begin_dialogs[i] || this.personaForm.begin_dialogs[i].trim() === '') {
const dialogType = i % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
this.showError(this.tm('validation.dialogRequired', { type: dialogType }));
return;
}
}
}
this.saving = true;
try {
const url = this.editingPersona ? '/api/persona/update' : '/api/persona/create';
const response = await axios.post(url, this.personaForm);
if (response.data.status === 'ok') {
this.showSuccess(response.data.message || this.tm('messages.saveSuccess'));
this.closePersonaDialog();
await this.loadPersonas();
} else {
this.showError(response.data.message || this.tm('messages.saveError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.saveError'));
}
this.saving = false;
handlePersonaSaved(message) {
this.showSuccess(message);
this.loadPersonas();
},
async deletePersona(persona) {
@@ -575,124 +250,6 @@ export default {
}
},
addDialogPair() {
this.personaForm.begin_dialogs.push('', '');
//
if (!this.expandedPanels.includes('dialogs')) {
this.expandedPanels.push('dialogs');
}
},
removeDialog(index) {
//
if (index % 2 === 0 && index + 1 < this.personaForm.begin_dialogs.length) {
this.personaForm.begin_dialogs.splice(index, 2);
}
//
else if (index % 2 === 1 && index - 1 >= 0) {
this.personaForm.begin_dialogs.splice(index - 1, 2);
}
},
toggleMcpServer(server) {
if (!server.tools || server.tools.length === 0) return;
//
if (this.personaForm.tools === null) {
//
this.personaForm.tools = this.availableTools.map(tool => tool.name)
.filter(toolName => !server.tools.includes(toolName));
this.toolSelectValue = '1'; //
return;
}
// tools
if (!Array.isArray(this.personaForm.tools)) {
this.personaForm.tools = [];
this.toolSelectValue = '1';
}
//
const serverTools = server.tools;
const allSelected = serverTools.every(toolName => this.personaForm.tools.includes(toolName));
if (allSelected) {
//
this.personaForm.tools = this.personaForm.tools.filter(
toolName => !serverTools.includes(toolName)
);
} else {
//
serverTools.forEach(toolName => {
if (!this.personaForm.tools.includes(toolName)) {
this.personaForm.tools.push(toolName);
}
});
}
},
toggleTool(toolName) {
//
if (this.personaForm.tools === null) {
//
//
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
this.toolSelectValue = '1'; //
} else if (Array.isArray(this.personaForm.tools)) {
const index = this.personaForm.tools.indexOf(toolName);
if (index !== -1) {
//
this.personaForm.tools.splice(index, 1);
} else {
//
this.personaForm.tools.push(toolName);
}
} else {
// toolsnull
this.personaForm.tools = [toolName];
this.toolSelectValue = '1';
}
},
toggleAllTools() {
//
if (this.isAllToolsSelected()) {
this.personaForm.tools = [];
} else {
// null
this.personaForm.tools = null;
}
},
/**
 * Deselect every tool by replacing the selection with an empty list.
 */
clearAllTools() {
// An empty array (as opposed to null) means no tool is selected.
this.personaForm.tools = [];
},
/**
 * Report whether the persona is set to use all tools.
 *
 * @returns {boolean} true only when `personaForm.tools` is exactly `null`.
 */
isAllToolsSelected() {
// `tools === null` is the marker for "all tools selected".
return this.personaForm.tools === null;
},
isNoToolsSelected() {
//
return Array.isArray(this.personaForm.tools) && this.personaForm.tools.length === 0;
},
removeTool(toolName) {
//
if (this.personaForm.tools === null) {
//
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
this.toolSelectValue = '1'; //
} else if (Array.isArray(this.personaForm.tools)) {
const index = this.personaForm.tools.indexOf(toolName);
if (index !== -1) {
this.personaForm.tools.splice(index, 1);
}
}
},
truncateText(text, maxLength) {
if (!text) return '';
return text.length > maxLength ? text.substring(0, maxLength) + '...' : text;
@@ -713,35 +270,6 @@ export default {
this.message = message;
this.messageType = 'error';
this.showMessage = true;
},
getDialogRules(index) {
const dialogType = index % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
return [
v => !!v || this.tm('validation.dialogRequired', { type: dialogType }),
v => (v && v.trim().length > 0) || this.tm('validation.dialogRequired', { type: dialogType })
];
},
isToolSelected(toolName) {
//
if (this.personaForm.tools === null) {
return true;
}
return Array.isArray(this.personaForm.tools) && this.personaForm.tools.includes(toolName);
},
isServerSelected(server) {
if (!server.tools || server.tools.length === 0) return false;
//
if (this.personaForm.tools === null) {
return true;
}
//
return Array.isArray(this.personaForm.tools) &&
server.tools.every(toolName => this.personaForm.tools.includes(toolName));
}
}
}
@@ -791,13 +319,4 @@ export default {
white-space: pre-wrap;
word-break: break-word;
}
.tools-selection {
max-height: 300px;
overflow-y: auto;
}
.v-virtual-scroll {
padding-bottom: 16px;
}
</style>
+19 -227
View File
@@ -3,14 +3,14 @@
<v-container fluid class="pa-0">
<v-row class="d-flex justify-space-between align-center px-4 py-3 pb-8">
<div>
<h1 class="text-h1 font-weight-bold mb-2">
<v-icon color="black" class="me-2">mdi-connection</v-icon>{{ tm('title') }}
<h1 class="text-h1 font-weight-bold mb-2 d-flex align-center">
<v-icon color="black" class="me-2">mdi-robot</v-icon>{{ tm('title') }}
</h1>
<p class="text-subtitle-1 text-medium-emphasis mb-4">
{{ tm('subtitle') }}
</p>
</div>
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddPlatformDialog = true"
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="updatingMode = false; showAddPlatformDialog = true"
rounded="xl" size="x-large">
{{ tm('addAdapter') }}
</v-btn>
@@ -46,7 +46,6 @@
</v-btn>
</v-card-title>
<v-divider></v-divider>
<v-expand-transition>
<v-card-text class="pa-0" v-if="showConsole">
@@ -57,91 +56,15 @@
</v-container>
<!-- 添加平台适配器对话框 -->
<AddNewPlatform v-model:show="showAddPlatformDialog" :metadata="metadata"
@select-template="selectPlatformTemplate" />
<!-- 配置对话框 -->
<v-dialog v-model="showPlatformCfg" persistent width="900px" max-width="90%">
<v-card
:title="updatingMode ? tm('dialog.edit') : tm('dialog.add') + ` ${newSelectedPlatformName} ` + tm('dialog.adapter')">
<v-card-text class="py-4">
<v-row>
<v-col cols="12">
<AstrBotConfig :iterable="newSelectedPlatformConfig" :metadata="metadata['platform_group']?.metadata"
metadataKey="platform" />
</v-col>
</v-row>
<v-row class="mt-2">
<v-col cols="12" class="text-center">
<v-btn color="info" variant="outlined" @click="openTutorial">
<v-icon start>mdi-book-open-variant</v-icon>
{{ tm('dialog.viewTutorial') }}
</v-btn>
</v-col>
</v-row>
</v-card-text>
<v-divider></v-divider>
<v-card-actions class="pa-4">
<v-spacer></v-spacer>
<v-btn variant="text" @click="showPlatformCfg = false" :disabled="loading">
{{ tm('dialog.cancel') }}
</v-btn>
<v-btn color="primary" @click="newPlatform" :loading="loading">
{{ tm('dialog.save') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<AddNewPlatform v-model:show="showAddPlatformDialog" :metadata="metadata" :config_data="config_data" ref="addPlatformDialog"
:updating-mode="updatingMode" :updating-platform-config="updatingPlatformConfig" @update="getConfig"
@show-toast="showToast" @refresh-config="getConfig"/>
<!-- 消息提示 -->
<v-snackbar :timeout="3000" elevation="24" :color="save_message_success" v-model="save_message_snack"
location="top">
{{ save_message }}
</v-snackbar>
<!-- ID冲突确认对话框 -->
<v-dialog v-model="showIdConflictDialog" max-width="450" persistent>
<v-card>
<v-card-title class="text-h6 bg-warning d-flex align-center">
<v-icon start class="me-2">mdi-alert-circle-outline</v-icon>
{{ tm('dialog.idConflict.title') }}
</v-card-title>
<v-card-text class="py-4 text-body-1 text-medium-emphasis">
{{ tm('dialog.idConflict.message', { id: conflictId }) }}
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="grey" variant="text" @click="handleIdConflictConfirm(false)">{{ tm('dialog.idConflict.confirm')
}}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<!-- 安全警告对话框 -->
<v-dialog v-model="showOneBotEmptyTokenWarnDialog" max-width="600" persistent>
<v-card>
<v-card-title>
{{ tm('dialog.securityWarning.title') }}
</v-card-title>
<v-card-text class="py-4">
<p>{{ tm('dialog.securityWarning.aiocqhttpTokenMissing') }}</p>
<span><a
href="https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html#%E9%99%84%E5%BD%95-%E5%A2%9E%E5%BC%BA%E8%BF%9E%E6%8E%A5%E5%AE%89%E5%85%A8%E6%80%A7"
target="_blank">{{ tm('dialog.securityWarning.learnMore') }}</a></span>
</v-card-text>
<v-card-actions class="px-4 pb-4">
<v-spacer></v-spacer>
<v-btn color="error" @click="handleOneBotEmptyTokenWarningDismiss(true)">
无视警告并继续创建
</v-btn>
<v-btn color="primary" @click="handleOneBotEmptyTokenWarningDismiss(false)">
重新修改
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</div>
</template>
@@ -191,30 +114,17 @@ export default {
config_data: {},
fetched: false,
metadata: {},
showPlatformCfg: false,
showAddPlatformDialog: false,
newSelectedPlatformName: '',
newSelectedPlatformConfig: {},
updatingPlatformConfig: {},
updatingMode: false,
loading: false,
save_message_snack: false,
save_message: "",
save_message_success: "success",
showConsole: false,
// ID
showIdConflictDialog: false,
conflictId: '',
idConflictResolve: null,
// OneBot Empty Token Warning #2639
showOneBotEmptyTokenWarnDialog: false,
oneBotEmptyTokenWarningResolve: null,
store: useCommonStore()
}
},
@@ -251,11 +161,6 @@ export default {
return getPlatformIcon(platform_id);
},
/**
 * Open the tutorial page for the adapter currently being configured.
 * The link is looked up via `getTutorialLink` from the `type` field of
 * `newSelectedPlatformConfig` and opened with target `_blank`.
 */
openTutorial() {
const tutorialUrl = getTutorialLink(this.newSelectedPlatformConfig.type);
window.open(tutorialUrl, '_blank');
},
getConfig() {
axios.get('/api/config/get').then((res) => {
this.config_data = res.data.data.config;
@@ -266,134 +171,13 @@ export default {
});
},
//
selectPlatformTemplate(name) {
this.newSelectedPlatformName = name;
this.showPlatformCfg = true;
this.updatingMode = false;
this.newSelectedPlatformConfig = JSON.parse(JSON.stringify(
this.metadata['platform_group']?.metadata?.platform?.config_template[name] || {}
));
},
addFromDefaultConfigTmpl(index) {
this.newSelectedPlatformName = index[0];
this.showPlatformCfg = true;
this.updatingMode = false;
this.newSelectedPlatformConfig = JSON.parse(JSON.stringify(
this.metadata['platform_group']?.metadata?.platform?.config_template[index[0]] || {}
));
},
editPlatform(platform) {
this.newSelectedPlatformName = platform.id;
this.newSelectedPlatformConfig = JSON.parse(JSON.stringify(platform));
this.updatingPlatformConfig = JSON.parse(JSON.stringify(platform));
this.updatingMode = true;
this.showPlatformCfg = true;
},
/**
 * Save handler of the adapter configuration dialog.
 *
 * Sets `loading`, then either updates the existing adapter (update mode) or
 * delegates to `savePlatform()` (create mode). In update mode, an `aiocqhttp`
 * adapter whose `ws_reverse_token` is missing or blank first goes through the
 * empty-token warning dialog and is only updated if the user chooses to
 * continue.
 */
newPlatform() {
this.loading = true;
if (this.updatingMode) {
if (this.newSelectedPlatformConfig.type === 'aiocqhttp') {
const token = this.newSelectedPlatformConfig.ws_reverse_token;
if (!token || token.trim() === '') {
// Defer the update until the warning dialog's promise resolves.
this.showOneBotEmptyTokenWarning().then((continueWithWarning) => {
if (continueWithWarning) {
this.updatePlatform();
}
});
// Return now so the unconditional updatePlatform() below is skipped.
return;
}
}
this.updatePlatform();
} else {
this.savePlatform();
}
},
updatePlatform() {
axios.post('/api/config/platform/update', {
id: this.newSelectedPlatformName,
config: this.newSelectedPlatformConfig
}).then((res) => {
this.loading = false;
this.showPlatformCfg = false;
this.getConfig();
this.showSuccess(res.data.message || this.messages.updateSuccess);
}).catch((err) => {
this.loading = false;
this.showError(err.response?.data?.message || err.message);
this.showAddPlatformDialog = true;
this.$nextTick(() => {
this.$refs.addPlatformDialog.toggleShowConfigSection();
});
this.updatingMode = false;
},
/**
 * Create a new adapter from `newSelectedPlatformConfig`.
 *
 * Two confirmations can abort the save before anything is posted: an ID
 * collision with an adapter already in `config_data.platform`, and, for
 * `aiocqhttp` adapters, a missing or blank `ws_reverse_token`. On success the
 * dialog is closed, the configuration is re-fetched and a success message is
 * shown; on failure the error message is shown and the dialog stays open.
 */
async savePlatform() {
// Check whether an adapter with the same ID already exists.
const existingPlatform = this.config_data.platform?.find(p => p.id === this.newSelectedPlatformConfig.id);
if (existingPlatform) {
const confirmed = await this.confirmIdConflict(this.newSelectedPlatformConfig.id);
if (!confirmed) {
this.loading = false;
return; // conflict not confirmed: abort without saving
}
}
// aiocqhttp adapters get an extra warning when the reverse-WebSocket token is empty.
if (this.newSelectedPlatformConfig.type === 'aiocqhttp') {
const token = this.newSelectedPlatformConfig.ws_reverse_token;
if (!token || token.trim() === '') {
const continueWithWarning = await this.showOneBotEmptyTokenWarning();
if (!continueWithWarning) {
// `loading` is not reset here; handleOneBotEmptyTokenWarningDismiss(false) clears it.
return;
}
}
}
try {
const res = await axios.post('/api/config/platform/new', this.newSelectedPlatformConfig);
this.loading = false;
this.showPlatformCfg = false;
this.getConfig();
this.showSuccess(res.data.message || this.messages.addSuccess);
} catch (err) {
this.loading = false;
this.showError(err.response?.data?.message || err.message);
}
},
confirmIdConflict(id) {
this.conflictId = id;
this.showIdConflictDialog = true;
return new Promise((resolve) => {
this.idConflictResolve = resolve;
});
},
handleIdConflictConfirm(confirmed) {
if (this.idConflictResolve) {
this.idConflictResolve(confirmed);
}
this.showIdConflictDialog = false;
},
showOneBotEmptyTokenWarning() {
this.showOneBotEmptyTokenWarnDialog = true;
return new Promise((resolve) => {
this.oneBotEmptyTokenWarningResolve = resolve;
});
},
/**
 * Close the empty-token warning dialog and deliver the user's choice to the
 * pending `showOneBotEmptyTokenWarning()` promise.
 *
 * @param {boolean} continueWithWarning true to proceed despite the warning,
 *   false to go back and edit the configuration.
 */
handleOneBotEmptyTokenWarningDismiss(continueWithWarning) {
this.showOneBotEmptyTokenWarnDialog = false;
if (this.oneBotEmptyTokenWarningResolve) {
this.oneBotEmptyTokenWarningResolve(continueWithWarning);
// Drop the resolver once used.
this.oneBotEmptyTokenWarningResolve = null;
}
if (!continueWithWarning) {
// Nothing will be saved, so clear the loading state here.
this.loading = false;
}
},
deletePlatform(platform) {
@@ -422,6 +206,14 @@ export default {
});
},
showToast({ message, type }) {
if (type === 'success') {
this.showSuccess(message);
} else if (type === 'error') {
this.showError(message);
}
},
showSuccess(message) {
this.save_message = message;
this.save_message_success = "success";
+1 -5
View File
@@ -103,8 +103,6 @@
</v-btn>
</v-card-title>
<v-divider></v-divider>
<v-expand-transition>
<v-card-text class="pa-0" v-if="showStatus">
<v-card-text class="px-4 py-3">
@@ -158,8 +156,6 @@
</v-btn>
</v-card-title>
<v-divider></v-divider>
<v-expand-transition>
<v-card-text class="pa-0" v-if="showConsole">
<ConsoleDisplayer style="background-color: #1e1e1e; height: 300px; border-radius: 0"></ConsoleDisplayer>
@@ -234,7 +230,7 @@
确认保存
</v-card-title>
<v-card-text class="py-4 text-body-1 text-medium-emphasis">
您没有填写 API Key确定要保存吗这可能会导致该服务提供商无法正常工作
您没有填写 API Key确定要保存吗这可能会导致该模型无法正常工作
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
@@ -60,10 +60,10 @@
<p>使用 /sid 指令可查看会话 ID</p>
<p>会话信息</p>
<ul>
<li>平台: {{ item.platform }}</li>
<li v-if="item.user_name">用户: {{ item.user_name }}</li>
<li>机器人 ID: {{ item.platform }}</li>
<li v-if="item.message_type">消息类型: {{ item.message_type }}</li>
<li v-if="item.session_raw_name">会话 ID: {{ item.session_raw_name }}</li>
<li v-if="item.user_name">用户: {{ item.user_name }}</li>
</ul>
</div>
</v-tooltip>
+12 -6
View File
@@ -1,6 +1,11 @@
<template>
<div class="flex-grow-1" style="display: flex; flex-direction: column; height: 100%;">
<div style="flex-grow: 1; width: 100%; border: 1px solid #eee; border-radius: 8px; padding: 16px">
<v-banner lines="one">
<template v-slot:text>
建议您更换使用新版知识库功能
</template>
</v-banner>
<!-- knowledge card -->
<div v-if="!installed" class="d-flex align-center justify-center flex-column"
style="flex-grow: 1; width: 100%; height: 100%;">
@@ -105,9 +110,9 @@
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="error" variant="text" @click="showCreateDialog = false">{{ tm('createDialog.cancel')
}}</v-btn>
}}</v-btn>
<v-btn color="primary" variant="text" @click="submitCreateForm">{{ tm('createDialog.create')
}}</v-btn>
}}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -132,7 +137,7 @@
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="primary" variant="text" @click="showEmojiPicker = false">{{ tm('emojiPicker.close')
}}</v-btn>
}}</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
@@ -159,8 +164,8 @@
<v-chip v-if="currentKB.rerank_provider_id" color="tertiary" variant="tonal" size="small"
rounded="sm">
<v-icon start size="small">mdi-sort-variant</v-icon>
重排序模型: {{ rerankProviderConfigs.
find(provider => provider.id === currentKB.rerank_provider_id)?.rerank_model || '未设置' }}
重排序模型: {{rerankProviderConfigs.
find(provider => provider.id === currentKB.rerank_provider_id)?.rerank_model || '未设置'}}
</v-chip>
<small style="margin-left: 8px;">💡 使用方式: 在聊天页中输入 "/kb use {{ currentKB.collection_name }}"</small>
</div>
@@ -411,7 +416,7 @@
<v-spacer></v-spacer>
<v-btn color="grey-darken-1" variant="text" @click="showDeleteDialog = false">{{
tm('deleteDialog.cancel')
}}</v-btn>
}}</v-btn>
<v-btn color="error" variant="text" @click="deleteKnowledgeBase" :loading="deleting">{{
tm('deleteDialog.delete') }}</v-btn>
</v-card-actions>
@@ -603,6 +608,7 @@ export default {
.then(response => {
if (response.data.status !== 'ok' || response.data.data.length === 0) {
this.showSnackbar(this.tm('messages.pluginNotAvailable'), 'error');
this.installed = false;
return
}
if (!response.data.data[0].activated) {
@@ -82,10 +82,10 @@
<v-card-title class="d-flex align-center pa-4">
<span>{{ t('chunks.title') }}</span>
<v-chip class="ml-2" size="small" variant="tonal">
{{ chunks.length }} {{ t('chunks.title') }}
{{ totalChunks }} {{ t('chunks.title') }}
</v-chip>
<v-spacer />
<v-text-field
<!-- <v-text-field
v-model="searchQuery"
prepend-inner-icon="mdi-magnify"
:placeholder="t('chunks.searchPlaceholder')"
@@ -94,7 +94,7 @@
hide-details
clearable
style="max-width: 300px"
/>
/> -->
</v-card-title>
<v-divider />
@@ -104,7 +104,8 @@
:headers="headers"
:items="filteredChunks"
:loading="loadingChunks"
:items-per-page="10"
:items-per-page="pageSize"
hide-default-footer
>
<template #item.chunk_index="{ item }">
<v-chip size="small" variant="tonal" color="primary">
@@ -132,19 +133,13 @@
color="info"
@click="viewChunk(item)"
/>
<v-btn
icon="mdi-pencil"
variant="text"
size="small"
color="primary"
@click="editChunk(item)"
/>
<!-- 删除 -->
<v-btn
icon="mdi-delete"
variant="text"
size="small"
color="error"
@click="confirmDeleteChunk(item)"
@click="deleteChunk(item)"
/>
</template>
@@ -155,6 +150,31 @@
</div>
</template>
</v-data-table>
<!-- 自定义分页器 -->
<div v-if="!searchQuery && totalChunks > 0" class="pa-4 d-flex align-center justify-space-between">
<div class="text-caption text-medium-emphasis">
{{ t('chunks.showing') }} {{ (page - 1) * pageSize + 1 }} - {{ Math.min(page * pageSize, totalChunks) }} / {{ totalChunks }}
</div>
<div class="d-flex align-center gap-2">
<v-select
v-model="pageSize"
:items="[10, 25, 50, 100]"
density="compact"
variant="outlined"
hide-details
style="width: 100px"
@update:model-value="handlePageSizeChange"
/>
<v-pagination
v-model="page"
:length="Math.ceil(totalChunks / pageSize)"
:total-visible="5"
@update:model-value="handlePageChange"
/>
</div>
</div>
</v-card-text>
</v-card>
</div>
@@ -191,7 +211,7 @@
<v-icon>mdi-key</v-icon>
</template>
<v-list-item-title>{{ t('view.vecDocId') }}</v-list-item-title>
<v-list-item-subtitle>{{ selectedChunk?.vec_doc_id || '-' }}</v-list-item-subtitle>
<v-list-item-subtitle>{{ selectedChunk?.chunk_id || '-' }}</v-list-item-subtitle>
</v-list-item>
</v-list>
@@ -212,71 +232,6 @@
</v-card>
</v-dialog>
<!-- 编辑分块对话框 -->
<v-dialog v-model="showEditDialog" max-width="800px" persistent scrollable>
<v-card>
<v-card-title class="pa-4">
<span>{{ t('edit.title') }}</span>
<v-spacer />
<v-btn icon="mdi-close" variant="text" @click="closeEditDialog" />
</v-card-title>
<v-divider />
<v-card-text class="pa-6">
<v-textarea
v-model="editForm.content"
:label="t('edit.content')"
variant="outlined"
rows="15"
auto-grow
/>
</v-card-text>
<v-divider />
<v-card-actions class="pa-4">
<v-spacer />
<v-btn variant="text" @click="closeEditDialog">
{{ t('edit.cancel') }}
</v-btn>
<v-btn
color="primary"
variant="elevated"
@click="saveChunk"
:loading="saving"
>
{{ t('edit.save') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<!-- 删除确认对话框 -->
<v-dialog v-model="showDeleteDialog" max-width="450px">
<v-card>
<v-card-title class="pa-4 text-h6">{{ t('delete.title') }}</v-card-title>
<v-divider />
<v-card-text class="pa-6">
<p>{{ t('delete.confirmText') }}</p>
<v-alert type="warning" variant="tonal" density="compact" class="mt-4">
{{ t('delete.warning') }}
</v-alert>
</v-card-text>
<v-divider />
<v-card-actions class="pa-4">
<v-spacer />
<v-btn variant="text" @click="showDeleteDialog = false">
{{ t('delete.cancel') }}
</v-btn>
<v-btn
color="error"
variant="elevated"
@click="deleteChunk"
:loading="deleting"
>
{{ t('delete.confirm') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<!-- 消息提示 -->
<v-snackbar v-model="snackbar.show" :color="snackbar.color">
{{ snackbar.text }}
@@ -299,16 +254,16 @@ const docId = ref(route.params.docId as string)
//
const loading = ref(true)
const loadingChunks = ref(false)
const saving = ref(false)
const deleting = ref(false)
const document = ref<any>({})
const chunks = ref<any[]>([])
const searchQuery = ref('')
const showViewDialog = ref(false)
const showEditDialog = ref(false)
const showDeleteDialog = ref(false)
const selectedChunk = ref<any>(null)
const deleteTarget = ref<any>(null)
//
const page = ref(1)
const pageSize = ref(10)
const totalChunks = ref(0)
const snackbar = ref({
show: false,
@@ -322,17 +277,12 @@ const showSnackbar = (text: string, color: string = 'success') => {
snackbar.value.show = true
}
//
const editForm = ref({
content: ''
})
//
const headers = [
{ title: t('chunks.index'), key: 'chunk_index', width: 100 },
{ title: t('chunks.content'), key: 'content', sortable: false },
{ title: t('chunks.charCount'), key: 'char_count', width: 150 },
{ title: t('chunks.actions'), key: 'actions', sortable: false, align: 'end', width: 150 }
{ title: t('chunks.actions'), key: 'actions', sortable: false, width: 150 }
]
//
@@ -349,7 +299,7 @@ const loadDocument = async () => {
loading.value = true
try {
const response = await axios.get('/api/kb/document/get', {
params: { doc_id: docId.value }
params: { doc_id: docId.value, kb_id: kbId.value }
})
if (response.data.status === 'ok') {
document.value = response.data.data
@@ -367,10 +317,16 @@ const loadChunks = async () => {
loadingChunks.value = true
try {
const response = await axios.get('/api/kb/chunk/list', {
params: { doc_id: docId.value }
params: {
doc_id: docId.value,
kb_id: kbId.value,
page: page.value,
page_size: pageSize.value
}
})
if (response.data.status === 'ok') {
chunks.value = response.data.data.items || []
totalChunks.value = response.data.data.total || 0
}
} catch (error) {
console.error('Failed to load chunks:', error)
@@ -380,81 +336,42 @@ const loadChunks = async () => {
}
}
//
/**
 * Handler bound to the pagination control's `update:model-value` event.
 * Stores the newly selected page and reloads the chunk list for it.
 *
 * @param newPage - page number emitted by the pagination control
 */
const handlePageChange = (newPage: number) => {
// page must be written before loadChunks() runs: the request reads page.value
page.value = newPage
loadChunks()
}
/**
 * Handler bound to the page-size select's `update:model-value` event.
 * Applies the new page size, jumps back to the first page and reloads
 * the chunk list.
 *
 * @param newPageSize - number of chunks to show per page
 */
const handlePageSizeChange = (newPageSize: number) => {
  // Go back to page 1: the old page number may not exist with the new size.
  page.value = 1
  pageSize.value = newPageSize
  loadChunks()
}
//
/**
 * Open the read-only chunk dialog for the given chunk.
 *
 * @param chunk - the table row whose details should be displayed
 */
const viewChunk = (chunk: any) => {
  showViewDialog.value = true
  selectedChunk.value = chunk
}
//
const editChunk = (chunk: any) => {
selectedChunk.value = chunk
editForm.value.content = chunk.content
showEditDialog.value = true
}
//
const closeEditDialog = () => {
showEditDialog.value = false
selectedChunk.value = null
editForm.value.content = ''
}
//
const saveChunk = async () => {
if (!selectedChunk.value) return
saving.value = true
try {
const response = await axios.post('/api/kb/chunk/update', {
chunk_id: selectedChunk.value.chunk_id,
content: editForm.value.content
})
if (response.data.status === 'ok') {
showSnackbar(t('edit.saveSuccess'))
closeEditDialog()
await loadChunks()
} else {
showSnackbar(response.data.message || t('edit.saveFailed'), 'error')
}
} catch (error) {
console.error('Failed to save chunk:', error)
showSnackbar(t('edit.saveFailed'), 'error')
} finally {
saving.value = false
}
}
//
const confirmDeleteChunk = (chunk: any) => {
deleteTarget.value = chunk
showDeleteDialog.value = true
}
//
const deleteChunk = async () => {
if (!deleteTarget.value) return
deleting.value = true
const deleteChunk = async (chunk: any) => {
if (!confirm(t('chunks.deleteConfirm'))) return
try {
const response = await axios.post('/api/kb/chunk/delete', {
chunk_id: deleteTarget.value.chunk_id
chunk_id: chunk.chunk_id,
doc_id: docId.value,
kb_id: kbId.value
})
if (response.data.status === 'ok') {
showSnackbar(t('delete.deleteSuccess'))
showDeleteDialog.value = false
await loadChunks()
await loadDocument()
showSnackbar(t('chunks.deleteSuccess'))
loadChunks()
} else {
showSnackbar(response.data.message || t('delete.deleteFailed'), 'error')
showSnackbar(t('chunks.deleteFailed'), 'error')
}
} catch (error) {
console.error('Failed to delete chunk:', error)
showSnackbar(t('delete.deleteFailed'), 'error')
} finally {
deleting.value = false
showSnackbar(t('chunks.deleteFailed'), 'error')
}
}
@@ -582,6 +499,10 @@ onMounted(() => {
font-family: 'Consolas', 'Monaco', monospace;
}
.gap-2 {
gap: 8px;
}
/* 响应式设计 */
@media (max-width: 768px) {
.document-detail-page {
@@ -40,10 +40,6 @@
<v-icon start>mdi-magnify</v-icon>
{{ t('tabs.retrieval') }}
</v-tab>
<v-tab value="sessions">
<v-icon start>mdi-account-multiple</v-icon>
{{ t('tabs.sessions') }}
</v-tab>
<v-tab value="settings">
<v-icon start>mdi-cog</v-icon>
{{ t('tabs.settings') }}
@@ -51,7 +47,7 @@
</v-tabs>
<!-- 标签页内容 -->
<v-window v-model="activeTab">
<v-window v-model="activeTab" style="padding: 8px;">
<!-- 概览 -->
<v-window-item value="overview">
<v-row>
@@ -163,12 +159,7 @@
<!-- 知识库检索 -->
<v-window-item value="retrieval">
<RetrievalTab :kb-id="kbId" />
</v-window-item>
<!-- 使用会话 -->
<v-window-item value="sessions">
<SessionsTab :kb-id="kbId" />
<RetrievalTab :kb-id="kbId" :kb-name="kb.kb_name"/>
</v-window-item>
<!-- 设置 -->
@@ -192,7 +183,6 @@ import axios from 'axios'
import { useModuleI18n } from '@/i18n/composables'
import DocumentsTab from './components/DocumentsTab.vue'
import RetrievalTab from './components/RetrievalTab.vue'
import SessionsTab from './components/SessionsTab.vue'
import SettingsTab from './components/SettingsTab.vue'
const { tm: t } = useModuleI18n('features/knowledge-base/detail')
@@ -255,7 +245,6 @@ onMounted(() => {
<style scoped>
.kb-detail-page {
padding: 24px;
max-width: 1400px;
margin: 0 auto;
animation: fadeIn 0.3s ease;
@@ -329,12 +318,11 @@ onMounted(() => {
padding: 24px;
text-align: center;
border-radius: 12px;
background: rgba(var(--v-theme-surface-variant), 0.3);
background: rgba(var(--v-theme-surface-variant), 0.1);
transition: all 0.3s ease;
}
.stat-box:hover {
transform: translateY(-4px);
background: rgba(var(--v-theme-surface-variant), 0.5);
}
@@ -342,21 +330,15 @@ onMounted(() => {
font-size: 2rem;
font-weight: 600;
margin-top: 8px;
color: rgb(var(--v-theme-on-surface));
}
.stat-label {
font-size: 0.875rem;
color: rgb(var(--v-theme-on-surface-variant));
margin-top: 4px;
}
/* 响应式设计 */
@media (max-width: 768px) {
.kb-detail-page {
padding: 16px;
}
.kb-title {
flex-direction: column;
align-items: flex-start;
+29 -181
View File
@@ -6,32 +6,16 @@
<h1 class="text-h4 mb-2">{{ t('list.title') }}</h1>
<p class="text-subtitle-1 text-medium-emphasis">{{ t('list.subtitle') }}</p>
</div>
<v-btn
icon="mdi-information-outline"
variant="text"
size="small"
color="grey"
href="https://astrbot.app/use/knowledge-base.html"
target="_blank"
/>
<v-btn icon="mdi-information-outline" variant="text" size="small" color="grey"
href="https://astrbot.app/use/knowledge-base.html" target="_blank" />
</div>
<!-- 操作按钮栏 -->
<div class="action-bar mb-6">
<v-btn
prepend-icon="mdi-plus"
color="primary"
variant="elevated"
@click="showCreateDialog = true"
>
<v-btn prepend-icon="mdi-plus" color="primary" variant="elevated" @click="showCreateDialog = true">
{{ t('list.create') }}
</v-btn>
<v-btn
prepend-icon="mdi-refresh"
variant="tonal"
@click="loadKnowledgeBases"
:loading="loading"
>
<v-btn prepend-icon="mdi-refresh" variant="tonal" @click="loadKnowledgeBases" :loading="loading">
{{ t('list.refresh') }}
</v-btn>
</div>
@@ -43,14 +27,8 @@
</div>
<div v-else-if="kbList.length > 0" class="kb-grid">
<v-card
v-for="kb in kbList"
:key="kb.kb_id"
class="kb-card"
elevation="2"
hover
@click="navigateToDetail(kb.kb_id)"
>
<v-card v-for="kb in kbList" :key="kb.kb_id" class="kb-card" elevation="2" hover
@click="navigateToDetail(kb.kb_id)">
<div class="kb-card-content">
<div class="kb-emoji">{{ kb.emoji || '📚' }}</div>
<h3 class="kb-name">{{ kb.kb_name }}</h3>
@@ -68,20 +46,8 @@
</div>
<div class="kb-actions">
<v-btn
icon="mdi-pencil"
size="small"
variant="text"
color="info"
@click.stop="editKB(kb)"
/>
<v-btn
icon="mdi-delete"
size="small"
variant="text"
color="error"
@click.stop="confirmDelete(kb)"
/>
<v-btn icon="mdi-pencil" size="small" variant="text" color="info" @click.stop="editKB(kb)" />
<v-btn icon="mdi-delete" size="small" variant="text" color="error" @click.stop="confirmDelete(kb)" />
</div>
</div>
</v-card>
@@ -91,14 +57,8 @@
<div v-else class="empty-state">
<v-icon size="100" color="grey-lighten-2">mdi-book-open-variant</v-icon>
<h2 class="mt-4">{{ t('list.empty') }}</h2>
<v-btn
class="mt-6"
prepend-icon="mdi-plus"
color="primary"
variant="elevated"
size="large"
@click="showCreateDialog = true"
>
<v-btn class="mt-6" prepend-icon="mdi-plus" color="primary" variant="elevated" size="large"
@click="showCreateDialog = true">
{{ t('list.create') }}
</v-btn>
</div>
@@ -125,35 +85,16 @@
<!-- 表单 -->
<v-form ref="formRef" @submit.prevent="submitForm">
<v-text-field
v-model="formData.kb_name"
:label="t('create.nameLabel')"
:placeholder="t('create.namePlaceholder')"
variant="outlined"
:rules="[v => !!v || t('create.nameRequired')]"
required
class="mb-4"
/>
<v-text-field v-model="formData.kb_name" :label="t('create.nameLabel')"
:placeholder="t('create.namePlaceholder')" variant="outlined"
:rules="[v => !!v || t('create.nameRequired')]" required class="mb-4" />
<v-textarea
v-model="formData.description"
:label="t('create.descriptionLabel')"
:placeholder="t('create.descriptionPlaceholder')"
variant="outlined"
rows="3"
class="mb-4"
/>
<v-textarea v-model="formData.description" :label="t('create.descriptionLabel')"
:placeholder="t('create.descriptionPlaceholder')" variant="outlined" rows="3" class="mb-4" />
<v-select
v-model="formData.embedding_provider_id"
:items="embeddingProviders"
:item-title="item => item.embedding_model || item.id"
:item-value="'id'"
:label="t('create.embeddingModelLabel')"
variant="outlined"
class="mb-4"
@update:model-value="handleEmbeddingProviderChange"
>
<v-select v-model="formData.embedding_provider_id" :items="embeddingProviders"
:item-title="item => item.embedding_model || item.id" :item-value="'id'"
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null">
<template #item="{ props, item }">
<v-list-item v-bind="props">
<template #subtitle>
@@ -166,20 +107,9 @@
</template>
</v-select>
<v-alert type="warning" variant="tonal" density="compact" class="mb-4" v-if="editingKB && showEmbeddingWarning">
<strong>注意:</strong> 修改嵌入模型会导致现有的向量数据失效,建议重新上传文档不同的嵌入模型生成的向量不兼容,可能导致检索结果不准确
</v-alert>
<v-select
v-model="formData.rerank_provider_id"
:items="rerankProviders"
:item-title="item => item.rerank_model || item.id"
:item-value="'id'"
:label="t('create.rerankModelLabel')"
variant="outlined"
clearable
class="mb-2"
>
<v-select v-model="formData.rerank_provider_id" :items="rerankProviders"
:item-title="item => item.rerank_model || item.id" :item-value="'id'"
:label="t('create.rerankModelLabel')" variant="outlined" clearable class="mb-2">
<template #item="{ props, item }">
<v-list-item v-bind="props">
<template #subtitle>
@@ -202,12 +132,7 @@
<v-btn variant="text" @click="closeCreateDialog">
{{ t('create.cancel') }}
</v-btn>
<v-btn
color="primary"
variant="elevated"
@click="submitForm"
:loading="saving"
>
<v-btn color="primary" variant="elevated" @click="submitForm" :loading="saving">
{{ editingKB ? t('edit.submit') : t('create.submit') }}
</v-btn>
</v-card-actions>
@@ -223,12 +148,7 @@
<div v-for="category in emojiCategories" :key="category.key" class="mb-4">
<p class="text-subtitle-2 mb-2">{{ t(`emoji.categories.${category.key}`) }}</p>
<div class="emoji-grid">
<div
v-for="emoji in category.emojis"
:key="emoji"
class="emoji-item"
@click="selectEmoji(emoji)"
>
<div v-for="emoji in category.emojis" :key="emoji" class="emoji-item" @click="selectEmoji(emoji)">
{{ emoji }}
</div>
</div>
@@ -261,12 +181,7 @@
<v-btn variant="text" @click="cancelDelete">
{{ t('delete.cancel') }}
</v-btn>
<v-btn
color="error"
variant="elevated"
@click="deleteKB"
:loading="deleting"
>
<v-btn color="error" variant="elevated" @click="deleteKB" :loading="deleting">
{{ t('delete.confirm') }}
</v-btn>
</v-card-actions>
@@ -278,38 +193,10 @@
{{ snackbar.text }}
</v-snackbar>
<!-- Embedding Provider 修改确认对话框 -->
<v-dialog v-model="embeddingChangeDialog" max-width="500px" persistent>
<v-card>
<v-card-title class="bg-warning text-white">
<v-icon class="mr-2">mdi-alert</v-icon>
确认修改嵌入模型
</v-card-title>
<v-card-text class="pa-6">
<v-alert type="warning" variant="tonal" class="mb-4">
<strong>警告:</strong> 修改嵌入模型将导致以下影响:
</v-alert>
<ul class="text-body-2">
<li>现有的向量数据将失效</li>
<li>检索功能可能无法正常工作</li>
<li>建议删除现有文档后重新上传</li>
<li>不同嵌入模型生成的向量不兼容</li>
</ul>
<div class="mt-4 text-body-2">
您确定要将嵌入模型从 <strong>{{ originalEmbeddingProvider }}</strong> 修改为 <strong>{{ pendingEmbeddingProvider }}</strong> ?
</div>
</v-card-text>
<v-card-actions class="pa-4">
<v-spacer />
<v-btn variant="text" @click="cancelEmbeddingChange">
取消
</v-btn>
<v-btn color="warning" variant="elevated" @click="confirmEmbeddingChange">
确认修改
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<div class="position-absolute" style="bottom: 0px; right: 16px;">
<small @click="router.push('/alkaid/knowledge-base')"><a style="text-decoration: underline; cursor: pointer;">切换到旧版知识库</a></small>
</div>
</div>
</template>
@@ -439,39 +326,6 @@ const editKB = (kb: any) => {
showCreateDialog.value = true
}
// embedding provider
const handleEmbeddingProviderChange = (newValue: string | null) => {
// embedding provider
if (newValue && originalEmbeddingProvider.value && newValue !== originalEmbeddingProvider.value) {
//
showEmbeddingWarning.value = true
pendingEmbeddingProvider.value = newValue
embeddingChangeDialog.value = true
} else {
showEmbeddingWarning.value = false
}
}
// embedding provider
const confirmEmbeddingChange = () => {
if (pendingEmbeddingProvider.value) {
formData.value.embedding_provider_id = pendingEmbeddingProvider.value
// ,
originalEmbeddingProvider.value = pendingEmbeddingProvider.value
}
embeddingChangeDialog.value = false
showEmbeddingWarning.value = true
}
// embedding provider
const cancelEmbeddingChange = () => {
//
formData.value.embedding_provider_id = originalEmbeddingProvider.value
embeddingChangeDialog.value = false
showEmbeddingWarning.value = false
pendingEmbeddingProvider.value = null
}
//
const confirmDelete = (kb: any) => {
deleteTarget.value = kb
@@ -640,13 +494,7 @@ onMounted(() => {
.kb-emoji {
font-size: 56px;
margin-bottom: 16px;
animation: float 3s ease-in-out infinite;
}
@keyframes float {
0%, 100% { transform: translateY(0); }
50% { transform: translateY(-10px); }
margin-bottom: 8px;
}
.kb-name {
@@ -2,41 +2,35 @@
<div class="documents-tab">
<!-- 操作栏 -->
<div class="action-bar mb-4">
<v-btn
prepend-icon="mdi-upload"
color="primary"
variant="elevated"
@click="showUploadDialog = true"
>
<v-btn prepend-icon="mdi-upload" color="primary" variant="elevated" @click="showUploadDialog = true">
{{ t('documents.upload') }}
</v-btn>
<v-text-field
v-model="searchQuery"
prepend-inner-icon="mdi-magnify"
:placeholder="'搜索文档...'"
variant="outlined"
density="compact"
hide-details
clearable
style="max-width: 300px"
/>
<v-text-field v-model="searchQuery" prepend-inner-icon="mdi-magnify" :placeholder="'搜索文档...'" variant="outlined"
density="compact" hide-details clearable style="max-width: 300px" />
</div>
<!-- 文档列表 -->
<v-card elevation="2">
<v-data-table
:headers="headers"
:items="documents"
:loading="loading"
:search="searchQuery"
:items-per-page="10"
>
<v-data-table :headers="headers" :items="documents" :loading="loading" :search="searchQuery" :items-per-page="10">
<template #item.doc_name="{ item }">
<div class="d-flex align-center gap-2">
<v-icon :color="getFileColor(item.file_type)">
<v-icon :color="getFileColor(item.file_type)" class="mr-2">
{{ getFileIcon(item.file_type) }}
</v-icon>
<span class="font-weight-medium">{{ item.doc_name }}</span>
<div class="flex-grow-1" style="padding: 4px 0px;">
<span class="font-weight-medium">{{ item.doc_name }}</span>
<!-- 上传进度 -->
<div v-if="item.uploading" class="mt-1">
<div class="text-caption text-medium-emphasis mb-1">
{{ getStageText(item.uploadProgress?.stage || 'waiting') }}
<span v-if="item.uploadProgress?.current">
({{ item.uploadProgress.current }} / {{ item.uploadProgress.total }})
</span>
</div>
<v-progress-linear :model-value="getUploadPercentage(item)" color="primary" height="4" rounded
striped />
</div>
</div>
</div>
</template>
@@ -49,20 +43,8 @@
</template>
<template #item.actions="{ item }">
<v-btn
icon="mdi-eye"
variant="text"
size="small"
color="info"
@click="viewDocument(item)"
/>
<v-btn
icon="mdi-delete"
variant="text"
size="small"
color="error"
@click="confirmDelete(item)"
/>
<v-btn icon="mdi-eye" variant="text" size="small" color="info" @click="viewDocument(item)" />
<v-btn icon="mdi-delete" variant="text" size="small" color="error" @click="confirmDelete(item)" />
</template>
<template #no-data>
@@ -75,9 +57,9 @@
</v-card>
<!-- 上传对话框 -->
<v-dialog v-model="showUploadDialog" max-width="600px" persistent>
<v-dialog v-model="showUploadDialog" max-width="600px" persistent @after-enter="initUploadSettings">
<v-card>
<v-card-title class="pa-4">
<v-card-title class="pa-4 d-flex align-center">
<span class="text-h5">{{ t('upload.title') }}</span>
<v-spacer />
<v-btn icon="mdi-close" variant="text" @click="closeUploadDialog" />
@@ -87,84 +69,88 @@
<v-card-text class="pa-6">
<!-- 文件选择 -->
<div
class="upload-dropzone"
:class="{ 'dragover': isDragging }"
@drop.prevent="handleDrop"
@dragover.prevent="isDragging = true"
@dragleave="isDragging = false"
@click="$refs.fileInput.click()"
>
<div class="upload-dropzone" :class="{ 'dragover': isDragging }" @drop.prevent="handleDrop"
@dragover.prevent="isDragging = true" @dragleave="isDragging = false" @click="fileInput?.click()">
<v-icon size="64" color="primary">mdi-cloud-upload</v-icon>
<p class="mt-4 text-h6">{{ t('upload.dropzone') }}</p>
<p class="text-caption text-medium-emphasis mt-2">{{ t('upload.supportedFormats') }}</p>
<p class="text-caption text-medium-emphasis">{{ t('upload.maxSize') }}</p>
<input
ref="fileInput"
type="file"
hidden
accept=".txt,.md,.pdf"
@change="handleFileSelect"
/>
<p class="text-caption text-medium-emphasis">最多可上传 10 个文件</p>
<input ref="fileInput" type="file" multiple hidden accept=".txt,.md,.pdf" @change="handleFileSelect" />
</div>
<div v-if="selectedFile" class="mt-4 pa-4 rounded bg-surface-variant">
<div class="d-flex align-center justify-space-between">
<div class="d-flex align-center gap-2">
<v-icon>{{ getFileIcon(selectedFile.name) }}</v-icon>
<div>
<div class="font-weight-medium">{{ selectedFile.name }}</div>
<div class="text-caption">{{ formatFileSize(selectedFile.size) }}</div>
<div v-if="selectedFiles.length > 0" class="mt-4">
<div class="d-flex align-center justify-space-between mb-2">
<span class="text-subtitle-2">已选择 {{ selectedFiles.length }} 个文件</span>
<v-btn variant="text" size="small" @click="selectedFiles = []">清空</v-btn>
</div>
<div class="files-list">
<div v-for="(file, index) in selectedFiles" :key="index"
class="file-item pa-3 mb-2 rounded bg-surface-variant">
<div class="d-flex align-center justify-space-between">
<div class="d-flex align-center gap-2">
<v-icon>{{ getFileIcon(file.name) }}</v-icon>
<div>
<div class="font-weight-medium">{{ file.name }}</div>
<div class="text-caption">{{ formatFileSize(file.size) }}</div>
</div>
</div>
<v-btn icon="mdi-close" variant="text" size="small" @click="removeFile(index)" />
</div>
</div>
<v-btn icon="mdi-close" variant="text" size="small" @click="selectedFile = null" />
</div>
</div>
<!-- 分块设置 -->
<v-expansion-panels class="mt-4">
<v-expansion-panel>
<v-expansion-panel-title>
<v-icon start>mdi-cog</v-icon>
{{ t('upload.chunkSettings') }}
</v-expansion-panel-title>
<v-expansion-panel-text>
<v-text-field
v-model.number="uploadSettings.chunk_size"
:label="t('upload.chunkSize')"
:hint="t('upload.chunkSizeHint')"
type="number"
variant="outlined"
density="compact"
class="mb-2"
/>
<v-text-field
v-model.number="uploadSettings.chunk_overlap"
:label="t('upload.chunkOverlap')"
:hint="t('upload.chunkOverlapHint')"
type="number"
variant="outlined"
density="compact"
/>
</v-expansion-panel-text>
</v-expansion-panel>
</v-expansion-panels>
<div class="mt-6">
<div class="d-flex align-center mb-4">
<h3 class="text-h6">{{ t('upload.chunkSettings') }}</h3>
</div>
<v-row>
<v-col cols="12" sm="6">
<v-text-field v-model.number="uploadSettings.chunk_size" :label="t('upload.chunkSize')"
:hint="t('upload.chunkSizeHint')" persistent-hint type="number" variant="outlined" density="compact"
:placeholder="props.kb?.chunk_size?.toString() || '512'" />
</v-col>
<v-col cols="12" sm="6">
<v-text-field v-model.number="uploadSettings.chunk_overlap" :label="t('upload.chunkOverlap')"
:hint="t('upload.chunkOverlapHint')" persistent-hint type="number" variant="outlined"
density="compact" :placeholder="props.kb?.chunk_overlap?.toString() || '50'" />
</v-col>
</v-row>
</div>
<div class="mt-2">
<h3 class="text-h6 mb-4">{{ t('upload.batchSettings') }}</h3>
<v-row>
<v-col cols="12" sm="4">
<v-text-field v-model.number="uploadSettings.batch_size" :label="t('upload.batchSize')" hint="每批处理的文本数量"
persistent-hint type="number" variant="outlined" density="compact" />
</v-col>
<v-col cols="12" sm="4">
<v-text-field v-model.number="uploadSettings.tasks_limit" :label="t('upload.tasksLimit')"
hint="并发任务数量限制" persistent-hint type="number" variant="outlined" density="compact" />
</v-col>
<v-col cols="12" sm="4">
<v-text-field v-model.number="uploadSettings.max_retries" :label="t('upload.maxRetries')"
hint="失败时的最大重试次数" persistent-hint type="number" variant="outlined" density="compact" />
</v-col>
</v-row>
</div>
</v-card-text>
<v-divider />
<v-card-actions class="pa-4">
<v-spacer />
<v-btn variant="text" @click="closeUploadDialog">
<v-btn variant="text" @click="closeUploadDialog" :disabled="uploading">
{{ t('upload.cancel') }}
</v-btn>
<v-btn
color="primary"
variant="elevated"
@click="uploadDocument"
:loading="uploading"
:disabled="!selectedFile"
>
<v-btn color="primary" variant="elevated" @click="uploadDocument" :loading="uploading"
:disabled="selectedFiles.length === 0">
{{ t('upload.submit') }}
</v-btn>
</v-card-actions>
@@ -201,7 +187,7 @@
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { ref, onMounted, onUnmounted } from 'vue'
import { useRouter } from 'vue-router'
import axios from 'axios'
import { useModuleI18n } from '@/i18n/composables'
@@ -224,9 +210,14 @@ const documents = ref<any[]>([])
const searchQuery = ref('')
const showUploadDialog = ref(false)
const showDeleteDialog = ref(false)
const selectedFile = ref<File | null>(null)
const selectedFiles = ref<File[]>([])
const deleteTarget = ref<any>(null)
const isDragging = ref(false)
const fileInput = ref<HTMLInputElement | null>(null)
// -
const uploadingTasks = ref<Map<string, any>>(new Map())
const progressPollingInterval = ref<number | null>(null)
const snackbar = ref({
show: false,
@@ -242,10 +233,24 @@ const showSnackbar = (text: string, color: string = 'success') => {
//
const uploadSettings = ref({
chunk_size: null,
chunk_overlap: null
chunk_size: null as number | null,
chunk_overlap: null as number | null,
batch_size: 32,
tasks_limit: 3,
max_retries: 3
})
//
/**
 * Reset the upload form to its defaults.
 *
 * Chunk size and overlap are taken from the knowledge base passed in via
 * props (falling back to `null` when absent or falsy); the batch settings
 * are reset to their fixed defaults.
 */
const initUploadSettings = () => {
  const kbConfig = props.kb
  const chunkSize = kbConfig?.chunk_size || null
  const chunkOverlap = kbConfig?.chunk_overlap || null

  uploadSettings.value = {
    chunk_size: chunkSize,
    chunk_overlap: chunkOverlap,
    batch_size: 32,
    tasks_limit: 3,
    max_retries: 3
  }
}
//
const headers = [
{ title: t('documents.name'), key: 'doc_name', sortable: true },
@@ -253,7 +258,7 @@ const headers = [
{ title: t('documents.size'), key: 'file_size', sortable: true },
{ title: t('documents.chunks'), key: 'chunk_count', sortable: true },
{ title: t('documents.createdAt'), key: 'created_at', sortable: true },
{ title: t('documents.actions'), key: 'actions', sortable: false, align: 'end' }
{ title: t('documents.actions'), key: 'actions', sortable: false, align: 'end' as const }
]
//
@@ -277,30 +282,53 @@ const loadDocuments = async () => {
//
const handleFileSelect = (event: Event) => {
const target = event.target as HTMLInputElement
if (target.files && target.files[0]) {
selectedFile.value = target.files[0]
if (target.files && target.files.length > 0) {
const newFiles = Array.from(target.files)
addFiles(newFiles)
}
}
//
/**
 * Append files to the pending upload selection.
 *
 * The whole batch is rejected (with a warning snackbar) when adding it would
 * bring the selection above 10 files; nothing is added in that case.
 *
 * @param files - files picked via the file input or dropped on the dropzone
 */
const addFiles = (files: File[]) => {
  const combinedCount = selectedFiles.value.length + files.length
  if (combinedCount <= 10) {
    selectedFiles.value.push(...files)
  } else {
    showSnackbar('最多只能选择 10 个文件', 'warning')
  }
}
//
/**
 * Remove one file from the pending upload selection.
 *
 * @param index - position of the file inside `selectedFiles`
 */
const removeFile = (index: number) => {
  const pending = selectedFiles.value
  pending.splice(index, 1)
}
//
const handleDrop = (event: DragEvent) => {
isDragging.value = false
if (event.dataTransfer?.files && event.dataTransfer.files[0]) {
selectedFile.value = event.dataTransfer.files[0]
if (event.dataTransfer?.files && event.dataTransfer.files.length > 0) {
const newFiles = Array.from(event.dataTransfer.files)
addFiles(newFiles)
}
}
//
const uploadDocument = async () => {
if (!selectedFile.value) {
if (selectedFiles.value.length === 0) {
showSnackbar(t('upload.fileRequired'), 'warning')
return
}
uploading.value = true
try {
const formData = new FormData()
formData.append('file', selectedFile.value)
//
selectedFiles.value.forEach((file, index) => {
formData.append(`file${index}`, file)
})
formData.append('kb_id', props.kbId)
if (uploadSettings.value.chunk_size) {
formData.append('chunk_size', uploadSettings.value.chunk_size.toString())
@@ -308,16 +336,47 @@ const uploadDocument = async () => {
if (uploadSettings.value.chunk_overlap) {
formData.append('chunk_overlap', uploadSettings.value.chunk_overlap.toString())
}
formData.append('batch_size', uploadSettings.value.batch_size.toString())
formData.append('tasks_limit', uploadSettings.value.tasks_limit.toString())
formData.append('max_retries', uploadSettings.value.max_retries.toString())
const response = await axios.post('/api/kb/document/upload', formData, {
headers: { 'Content-Type': 'multipart/form-data' }
})
if (response.data.status === 'ok') {
showSnackbar(t('documents.uploadSuccess'))
const result = response.data.data
const taskId = result.task_id
showSnackbar(`正在后台上传 ${result.file_count} 个文件...`, 'info')
//
const uploadingDocs = selectedFiles.value.map((file, index) => ({
doc_id: `uploading_${taskId}_${index}`,
doc_name: file.name,
file_type: file.name.split('.').pop() || '',
file_size: file.size,
chunk_count: 0,
created_at: new Date().toISOString(),
uploading: true,
taskId: taskId,
uploadProgress: {
stage: 'waiting',
current: 0,
total: 100
}
}))
//
documents.value = [...uploadingDocs, ...documents.value]
//
closeUploadDialog()
await loadDocuments()
emit('refresh')
//
if (taskId) {
startProgressPolling(taskId)
}
} else {
showSnackbar(response.data.message || t('documents.uploadFailed'), 'error')
}
@@ -329,11 +388,118 @@ const uploadDocument = async () => {
}
}
//
/**
 * Poll the upload-progress endpoint for `taskId` every 500 ms and mirror the
 * result onto the placeholder rows in `documents` that carry this task id.
 *
 * - `processing`: updates the progress of the placeholder whose index (the
 *   last `_`-separated part of its `doc_id`) equals `progress.file_index`.
 * - `completed`: stops polling, removes the placeholders, reloads the
 *   document list, emits `refresh` and shows a summary snackbar.
 * - `failed`, or a non-ok response: stops polling and removes the placeholders.
 * - Request errors are logged and polling continues.
 *
 * NOTE(review): only one poller is kept. Starting a new one stops the
 * previous one, and nothing in this function cleans up the earlier task's
 * placeholder rows. Confirm whether concurrent uploads need to be supported.
 *
 * @param taskId - id of the background upload task returned by the upload API
 */
const startProgressPolling = (taskId: string) => {
  // Replace any poller that is still running.
  if (progressPollingInterval.value) {
    stopProgressPolling()
  }

  // setInterval does not wait for the async callback, so a slow response
  // would otherwise let several requests run at once and the terminal
  // states be handled more than once.
  let requestInFlight = false

  const timerId: number = window.setInterval(async () => {
    if (requestInFlight) return
    requestInFlight = true

    try {
      const response = await axios.get('/api/kb/document/upload/progress', {
        params: { task_id: taskId }
      })

      // The poller may have been stopped or replaced while the request was
      // pending; a stale response must not touch the UI.
      if (progressPollingInterval.value !== timerId) return

      if (response.data.status !== 'ok') {
        // The backend no longer reports this task: give up and clean up.
        stopProgressPolling()
        documents.value = documents.value.filter(doc => doc.taskId !== taskId)
        return
      }

      const data = response.data.data
      const status = data.status

      if (status === 'processing' && data.progress) {
        const progress = data.progress
        const fileIndex = progress.file_index || 0

        // Update only the placeholder row of the file currently being processed.
        documents.value = documents.value.map(doc => {
          if (doc.taskId !== taskId) return doc

          const docIndex = parseInt(doc.doc_id.split('_').pop() || '0', 10)
          if (docIndex !== fileIndex) return doc

          return {
            ...doc,
            uploadProgress: {
              stage: progress.stage || 'waiting',
              current: progress.current || 0,
              total: progress.total || 100
            }
          }
        })
      } else if (status === 'completed') {
        stopProgressPolling()

        const result = data.result
        const successCount = result?.success_count || 0
        const failedCount = result?.failed_count || 0

        // Drop the placeholders, then fetch the real document rows.
        documents.value = documents.value.filter(doc => doc.taskId !== taskId)
        await loadDocuments()
        emit('refresh')

        if (failedCount === 0) {
          showSnackbar(`成功上传 ${successCount} 个文档`)
        } else {
          showSnackbar(`上传完成: ${successCount} 个成功, ${failedCount} 个失败`, 'warning')
        }
      } else if (status === 'failed') {
        stopProgressPolling()
        documents.value = documents.value.filter(doc => doc.taskId !== taskId)
        showSnackbar(`上传失败: ${data.error || '未知错误'}`, 'error')
      }
    } catch (error) {
      // Treated as transient: keep polling.
      console.error('Failed to fetch progress:', error)
    } finally {
      requestInFlight = false
    }
  }, 500)

  progressPollingInterval.value = timerId
}
//
/**
 * Cancel the progress-polling timer, if one is active, and forget its id.
 */
const stopProgressPolling = () => {
  const activeTimer = progressPollingInterval.value
  if (!activeTimer) return

  clearInterval(activeTimer)
  progressPollingInterval.value = null
}
//
/**
 * Convert a row's `uploadProgress` into a percentage for the progress bar.
 *
 * @param item - table row; may carry `uploadProgress: { current, total }`
 * @returns `current / total * 100`, or 0 when there is no progress object
 *          or `total` is missing or zero
 */
const getUploadPercentage = (item: any) => {
  const progress = item.uploadProgress
  // A missing progress object or a falsy total would divide by zero or give NaN.
  if (!progress || !progress.total) return 0
  return (progress.current / progress.total) * 100
}
//
/**
 * Map an upload stage identifier to the label shown above the progress bar.
 *
 * @param stage - stage identifier taken from the upload progress data
 * @returns the label for a known stage, otherwise `stage` itself
 */
const getStageText = (stage: string) => {
  const stageMap: Record<string, string> = {
    'waiting': '等待中...',
    'parsing': '解析文档...',
    'chunking': '文本分块...',
    'embedding': '生成向量...'
  }
  // Look at own keys only: a plain `stageMap[stage]` would resolve inherited
  // names such as "constructor" or "toString" to functions, not labels.
  const isKnownStage = Object.prototype.hasOwnProperty.call(stageMap, stage)
  return (isKnownStage && stageMap[stage]) || stage
}
//
const closeUploadDialog = () => {
showUploadDialog.value = false
selectedFile.value = null
uploadSettings.value = { chunk_size: null, chunk_overlap: null }
selectedFiles.value = []
initUploadSettings()
}
//
@@ -357,7 +523,8 @@ const deleteDocument = async () => {
deleting.value = true
try {
const response = await axios.post('/api/kb/document/delete', {
doc_id: deleteTarget.value.doc_id
doc_id: deleteTarget.value.doc_id,
kb_id: props.kbId
})
if (response.data.status === 'ok') {
@@ -419,6 +586,10 @@ const formatDate = (dateStr: string) => {
onMounted(() => {
loadDocuments()
})
onUnmounted(() => {
stopProgressPolling()
})
</script>
<style scoped>
@@ -427,8 +598,13 @@ onMounted(() => {
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.action-bar {
@@ -456,13 +632,26 @@ onMounted(() => {
transform: scale(1.02);
}
.files-list {
max-height: 300px;
overflow-y: auto;
}
.file-item {
transition: all 0.2s ease;
}
.file-item:hover {
background: rgba(var(--v-theme-surface-variant), 0.8) !important;
}
@media (max-width: 768px) {
.action-bar {
flex-direction: column;
align-items: stretch;
}
.action-bar > * {
.action-bar>* {
width: 100%;
}
}
@@ -1,66 +1,57 @@
<template>
<div class="retrieval-tab">
<v-card elevation="2">
<v-card-title class="pa-4">{{ t('retrieval.title') }}</v-card-title>
<v-card-subtitle class="px-4 pb-4">
<v-card-title class="pa-4 pb-0">{{ t('retrieval.title') }}</v-card-title>
<v-card-subtitle class="pb-4 pt-2">
{{ t('retrieval.subtitle') }}
</v-card-subtitle>
<v-divider />
<v-progress-linear v-if="loading" indeterminate color="primary" height="2" />
<v-card-text class="pa-6">
<!-- 查询输入区域 -->
<v-row>
<v-row class="mb-4">
<v-col cols="12" md="8">
<v-textarea
v-model="query"
:label="t('retrieval.query')"
:placeholder="t('retrieval.queryPlaceholder')"
variant="outlined"
rows="3"
auto-grow
clearable
/>
<v-textarea v-model="query" :label="t('retrieval.query')" :placeholder="t('retrieval.queryPlaceholder')"
variant="outlined" rows="3" auto-grow clearable />
<!-- debug -->
<div v-if="debugVisualize" class="mt-2">
<v-card variant="outlined">
<v-img :src="`data:image/png;base64,${debugVisualize}`" :alt="t('retrieval.tsneVisualization')" cover>
<template v-slot:placeholder>
<div class="d-flex align-center justify-center fill-height">
<v-progress-circular indeterminate color="primary" />
</div>
</template>
</v-img>
</v-card>
</div>
</v-col>
<v-col cols="12" md="4">
<v-card variant="outlined" class="pa-4">
<h4 class="text-subtitle-2 mb-3">{{ t('retrieval.settings') }}</h4>
<v-text-field
v-model.number="topK"
:label="t('retrieval.topK')"
:hint="t('retrieval.topKHint')"
type="number"
variant="outlined"
density="compact"
persistent-hint
class="mb-3"
/>
<v-text-field v-model.number="topK" :label="t('retrieval.topK')" :hint="t('retrieval.topKHint')"
type="number" variant="outlined" density="compact" persistent-hint class="mb-3" />
<v-checkbox
v-model="enableRerank"
:label="t('retrieval.enableRerank')"
:hint="t('retrieval.enableRerankHint')"
color="primary"
density="compact"
persistent-hint
/>
<v-alert v-if="enableRerank" type="info" variant="tonal" class="mt-2" density="compact">
如果没有配置重排序模型提供商将跳过重排序步骤
</v-alert>
<v-switch v-model="debugMode" :label="t('retrieval.debugMode')" color="primary" density="compact"
hide-details>
<template v-slot:label>
<span class="text-caption">
<v-icon size="small" class="mr-1">mdi-bug</v-icon>
Debug (t-SNE)
</span>
</template>
</v-switch>
</v-card>
</v-col>
</v-row>
<div class="d-flex justify-end mb-4">
<v-btn
prepend-icon="mdi-magnify"
color="primary"
variant="elevated"
@click="performRetrieval"
:loading="loading"
:disabled="!query || query.trim() === ''"
>
<v-btn prepend-icon="mdi-magnify" color="primary" variant="elevated" @click="performRetrieval"
:loading="loading" :disabled="!query || query.trim() === ''">
{{ loading ? t('retrieval.searching') : t('retrieval.search') }}
</v-btn>
</div>
@@ -71,28 +62,33 @@
<div class="d-flex align-center mb-4">
<h3 class="text-h6">{{ t('retrieval.results') }}</h3>
<v-chip class="ml-3" color="primary" variant="tonal">
<v-chip class="ml-3" color="primary" variant="tonal" size="small">
{{ results.length }} {{ t('retrieval.results') }}
</v-chip>
</div>
<!-- 结果列表 -->
<div v-if="results.length > 0" class="results-list">
<v-card
v-for="(result, index) in results"
:key="result.chunk_id"
variant="outlined"
class="mb-4"
>
<v-card-title class="d-flex align-center pa-4">
<v-chip size="small" color="primary" class="mr-2">
<v-card v-for="(result, index) in results" :key="result.chunk_id" variant="outlined" class="mb-4">
<v-card-title class="d-flex align-center pa-2">
<v-chip size="x-small" color="primary" class="mr-2">
#{{ index + 1 }}
</v-chip>
<span class="text-subtitle-1">
{{ t('retrieval.chunk', { index: result.chunk_index }) }}
</span>
<div class="ml-4">
<v-chip size="x-small" variant="tonal" class="mr-2">
<v-icon start size="small">mdi-file-document</v-icon>
{{ result.doc_name }}
</v-chip>
<v-chip size="x-small" variant="tonal">
<v-icon start size="small">mdi-text</v-icon>
{{ t('retrieval.charCount', { count: result.char_count }) }}
</v-chip>
</div>
<v-spacer />
<v-chip size="small" :color="getScoreColor(result.score)">
<v-chip size="x-small" :color="getScoreColor(result.score)">
{{ t('retrieval.score') }}: {{ result.score.toFixed(4) }}
</v-chip>
</v-card-title>
@@ -100,17 +96,6 @@
<v-divider />
<v-card-text class="pa-4">
<div class="mb-3">
<v-chip size="small" variant="tonal" class="mr-2">
<v-icon start size="small">mdi-file-document</v-icon>
{{ result.doc_name }}
</v-chip>
<v-chip size="small" variant="tonal">
<v-icon start size="small">mdi-text</v-icon>
{{ t('retrieval.charCount', { count: result.char_count }) }}
</v-chip>
</div>
<div class="content-box">
{{ result.content }}
</div>
@@ -143,16 +128,18 @@ import { useModuleI18n } from '@/i18n/composables'
const { tm: t } = useModuleI18n('features/knowledge-base/detail')
const props = defineProps<{
kbId: string
kbId: string,
kbName: string,
}>()
// Reactive state
const loading = ref(false)
const query = ref('')
const topK = ref(5)
const enableRerank = ref(false)
const debugMode = ref(false)
const results = ref<any[]>([])
const hasSearched = ref(false)
const debugVisualize = ref<string | null>(null)
const snackbar = ref({
show: false,
@@ -175,18 +162,24 @@ const performRetrieval = async () => {
loading.value = true
hasSearched.value = false
debugVisualize.value = null
try {
const response = await axios.post('/api/kb/retrieve', {
query: query.value,
kb_ids: [props.kbId],
kb_names: [props.kbName],
top_k: topK.value,
enable_rerank: enableRerank.value
debug: debugMode.value
})
if (response.data.status === 'ok') {
results.value = response.data.data.results || []
hasSearched.value = true
if (debugMode.value && response.data.data.visualization) {
debugVisualize.value = response.data.data.visualization
}
showSnackbar(t('retrieval.searchSuccess', { count: results.value.length }))
} else {
showSnackbar(response.data.message || t('retrieval.searchFailed'), 'error')
@@ -214,8 +207,13 @@ const getScoreColor = (score: number) => {
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.results-section {
@@ -227,6 +225,7 @@ const getScoreColor = (score: number) => {
opacity: 0;
transform: translateY(20px);
}
to {
opacity: 1;
transform: translateY(0);
@@ -234,7 +233,7 @@ const getScoreColor = (score: number) => {
}
.content-box {
background: rgba(var(--v-theme-surface-variant), 0.3);
background: rgba(var(--v-theme-surface-variant), 0.1);
border-radius: 8px;
padding: 16px;
white-space: pre-wrap;
@@ -242,5 +241,8 @@ const getScoreColor = (score: number) => {
font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
font-size: 0.9rem;
line-height: 1.6;
height: 120px;
overflow-y: auto;
font-size: 13px;
}
</style>
@@ -1,172 +0,0 @@
<template>
<div class="sessions-tab">
<v-card elevation="2">
<v-card-title class="d-flex align-center pa-4">
<span>{{ t('sessions.title') }}</span>
<v-spacer />
<v-btn
prepend-icon="mdi-refresh"
variant="tonal"
size="small"
@click="loadSessions"
:loading="loading"
>
{{ t('sessions.refresh') }}
</v-btn>
</v-card-title>
<v-card-subtitle class="px-4 pb-4">
{{ t('sessions.subtitle') }}
</v-card-subtitle>
<v-divider />
<v-card-text class="pa-0">
<v-data-table
:headers="headers"
:items="sessions"
:loading="loading"
>
<template #item.scope="{ item }">
<v-chip :color="item.scope === 'session' ? 'primary' : 'secondary'" size="small" variant="tonal">
{{ item.scope === 'session' ? t('sessions.scopeSession') : t('sessions.scopePlatform') }}
</v-chip>
</template>
<template #item.enable_rerank="{ item }">
<v-icon :color="item.enable_rerank ? 'success' : 'grey'">
{{ item.enable_rerank ? 'mdi-check-circle' : 'mdi-close-circle' }}
</v-icon>
</template>
<template #item.actions="{ item }">
<v-btn
icon="mdi-open-in-new"
variant="text"
size="small"
color="primary"
@click="goToSessionManagement(item)"
:title="t('sessions.viewInSessionManagement')"
/>
</template>
<template #no-data>
<div class="text-center py-8">
<v-icon size="64" color="grey-lighten-2">mdi-account-multiple-outline</v-icon>
<p class="mt-4 text-medium-emphasis">{{ t('sessions.empty') }}</p>
<v-btn
class="mt-4"
prepend-icon="mdi-cog"
variant="tonal"
color="primary"
@click="goToSessionManagement()"
>
{{ t('sessions.goToSessionManagement') }}
</v-btn>
</div>
</template>
</v-data-table>
</v-card-text>
</v-card>
<!-- 消息提示 -->
<v-snackbar v-model="snackbar.show" :color="snackbar.color">
{{ snackbar.text }}
</v-snackbar>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted } from 'vue'
import { useRouter } from 'vue-router'
import axios from 'axios'
import { useModuleI18n } from '@/i18n/composables'
// `tm` from useModuleI18n for the 'features/knowledge-base/detail' module, aliased to `t`
// so the template and script can call t('sessions.…').
const { tm: t } = useModuleI18n('features/knowledge-base/detail')
// Router instance; used by goToSessionManagement to navigate to the 'SessionManagement' route.
const router = useRouter()
// kbId is sent as the `kb_id` query parameter when loading sessions.
const props = defineProps<{
kbId: string
}>()
// Reactive state
// loading: true while the session list request is in flight (bound to the refresh button and the table).
const loading = ref(false)
// sessions: rows rendered by the data table.
const sessions = ref<any[]>([])
// snackbar: visibility, message and color of the notification bar at the bottom of the template.
const snackbar = ref({
show: false,
text: '',
color: 'success'
})
/**
 * Show a snackbar notification.
 *
 * @param text  Message to display.
 * @param color Snackbar color; defaults to 'success'.
 */
const showSnackbar = (text: string, color: string = 'success') => {
  // Same three writes as before, in the same order: text, color, then show.
  Object.assign(snackbar.value, { text, color, show: true })
}
// Column definitions for the sessions data table. The `scope`, `enable_rerank` and
// `actions` keys have custom cell slots in the template; the rest render as plain values.
const headers = [
{ title: t('sessions.scope'), key: 'scope' },
{ title: t('sessions.scopeId'), key: 'scope_id' },
{ title: t('sessions.topK'), key: 'top_k' },
{ title: t('sessions.enableRerank'), key: 'enable_rerank' },
{ title: t('sessions.actions'), key: 'actions', sortable: false, align: 'end' }
]
/**
 * Load the session configurations that use this knowledge base.
 *
 * Sends GET /api/kb/session/config/list_by_kb with `kb_id` taken from props and
 * stores the returned list in `sessions`. On an API-level or transport-level
 * failure the details are logged and an error snackbar is shown. `loading` is
 * always reset when the request settles.
 */
const loadSessions = async () => {
  loading.value = true
  try {
    const response = await axios.get('/api/kb/session/config/list_by_kb', {
      params: { kb_id: props.kbId }
    })
    if (response.data.status === 'ok') {
      // Fall back to an empty list so the data table never receives undefined.
      sessions.value = response.data.data?.sessions || []
    } else {
      console.error('[SessionsTab] API返回错误:', response.data.message)
      showSnackbar(response.data.message || t('sessions.loadFailed'), 'error')
    }
  } catch (error: unknown) {
    console.error('[SessionsTab] 请求失败:', error)
    // Narrow before touching properties: the thrown value is not guaranteed to be
    // an Axios error (or even an object), and the snackbar below must still run.
    if (axios.isAxiosError(error)) {
      if (error.response) {
        console.error('[SessionsTab] 错误响应状态:', error.response.status)
        console.error('[SessionsTab] 错误响应数据:', error.response.data)
      } else if (error.request) {
        console.error('[SessionsTab] 请求未收到响应:', error.request)
      } else {
        console.error('[SessionsTab] 请求配置错误:', error.message)
      }
    } else if (error instanceof Error) {
      console.error('[SessionsTab] 请求配置错误:', error.message)
    }
    showSnackbar(t('sessions.loadFailed'), 'error')
  } finally {
    loading.value = false
  }
}
/**
 * Navigate to the session management page.
 *
 * @param session Row the action was triggered from, if any. It is accepted so the
 *                row button and the empty-state button can share this handler, but
 *                it is not used for navigation.
 */
const goToSessionManagement = (session?: any): void => {
  // The navigation promise is intentionally not awaited.
  void router.push({ name: 'SessionManagement' })
}
onMounted(() => {
loadSessions()
})
</script>
<style scoped>
.sessions-tab {
animation: fadeIn 0.3s ease;
}
@keyframes fadeIn {
from { opacity: 0; }
to { opacity: 1; }
}
</style>
@@ -34,7 +34,7 @@
<h3 class="text-h6 mb-4 mt-6">{{ t('settings.retrieval') }}</h3>
<v-row>
<v-col cols="12" md="4">
<v-col cols="12" md="6">
<v-text-field
v-model.number="formData.top_k_dense"
:label="t('settings.topKDense')"
@@ -43,7 +43,7 @@
density="comfortable"
/>
</v-col>
<v-col cols="12" md="4">
<v-col cols="12" md="6">
<v-text-field
v-model.number="formData.top_k_sparse"
:label="t('settings.topKSparse')"
@@ -52,7 +52,7 @@
density="comfortable"
/>
</v-col>
<v-col cols="12" md="4">
<!-- <v-col cols="12" md="4">
<v-text-field
v-model.number="formData.top_m_final"
:label="t('settings.topMFinal')"
@@ -60,23 +60,7 @@
variant="outlined"
density="comfortable"
/>
</v-col>
</v-row>
<v-row>
<v-col cols="12">
<v-checkbox
v-model="formData.enable_rerank"
:label="t('settings.enableRerank')"
:hint="rerankProviders.length === 0 ? '未检测到可用的重排序模型提供商' : '使用重排序模型提高检索质量'"
:disabled="rerankProviders.length === 0"
color="primary"
persistent-hint
/>
<v-alert v-if="formData.enable_rerank && rerankProviders.length === 0" type="warning" variant="tonal" class="mt-2" density="compact">
当前没有可用的重排序模型提供商,请先在提供商管理中添加支持 rerank 的模型
</v-alert>
</v-col>
</v-col> -->
</v-row>
<!-- 模型设置 -->
@@ -93,6 +77,7 @@
variant="outlined"
density="comfortable"
@update:model-value="handleEmbeddingProviderChange"
:disabled="true"
/>
</v-col>
<v-col cols="12" md="6">
@@ -216,8 +201,6 @@ const formData = ref({
chunk_overlap: 50,
top_k_dense: 50,
top_k_sparse: 50,
top_m_final: 5,
enable_rerank: false,
embedding_provider_id: '',
rerank_provider_id: ''
})
@@ -230,8 +213,7 @@ watch(() => props.kb, (kb) => {
chunk_overlap: kb.chunk_overlap || 50,
top_k_dense: kb.top_k_dense || 50,
top_k_sparse: kb.top_k_sparse || 50,
top_m_final: kb.top_m_final || 5,
enable_rerank: kb.enable_rerank === true,
// top_m_final: kb.top_m_final || 5,
embedding_provider_id: kb.embedding_provider_id || '',
rerank_provider_id: kb.rerank_provider_id || ''
}
@@ -299,8 +281,7 @@ const saveSettings = async () => {
chunk_overlap: formData.value.chunk_overlap,
top_k_dense: formData.value.top_k_dense,
top_k_sparse: formData.value.top_k_sparse,
top_m_final: formData.value.top_m_final,
enable_rerank: formData.value.enable_rerank,
// top_m_final: formData.value.top_m_final,
rerank_provider_id: formData.value.rerank_provider_id
})
+1 -1
View File
@@ -129,7 +129,7 @@ class ProviderCommands:
)
return
i = 1
ret = "下面列出了此服务提供商可用模型:"
ret = "下面列出了此模型提供商可用模型:"
for model in models:
ret += f"\n{i}. {model}"
i += 1
+14 -7
View File
@@ -11,19 +11,26 @@ class SIDCommand:
self.context = context
async def sid(self, event: AstrMessageEvent):
    """Reply with the identifiers describing where the message came from.

    The reply contains the unified message origin (UMO), the sender's user ID
    (UID) and the three parts of the session read from ``event.session``
    (platform ID, message type, session ID). When the ``unique_session``
    platform setting is enabled and the event has a group ID, a note with
    that group ID is appended.

    Args:
        event: The incoming message event; the reply is attached to it via
            ``event.set_result`` with text-to-image rendering disabled.
    """
    sid = event.unified_msg_origin
    user_id = str(event.get_sender_id())
    umo_platform = event.session.platform_id
    umo_msg_type = event.session.message_type.value
    umo_session_id = event.session.session_id
    # Every quoted value is wrapped in a balanced 「…」 pair.
    ret = (
        f"UMO: 「{sid}」 此值可用于设置白名单。\n"
        f"UID: 「{user_id}」 此值可用于设置管理员。\n"
        f"消息会话来源信息:\n"
        f"  机器人 ID: 「{umo_platform}」\n"
        f"  消息类型: 「{umo_msg_type}」\n"
        f"  会话 ID: 「{umo_session_id}」\n"
        f"消息来源可用于配置机器人的配置文件路由。"
    )

    if (
        self.context.get_config()["platform_settings"]["unique_session"]
        and event.get_group_id()
    ):
        # Appended exactly once; the group ID can be whitelisted to allow the whole group.
        ret += f"\n\n当前处于独立会话模式, 此群 ID: {event.get_group_id()}, 也可将此 ID 加入白名单来放行整个群聊。"

    event.set_result(MessageEventResult().message(ret).use_t2i(False))
+8 -2
View File
@@ -4,7 +4,7 @@ import random
import astrbot.api.star as star
from astrbot.api.event import AstrMessageEvent
from astrbot.api.platform import MessageType
from astrbot.api.provider import ProviderRequest
from astrbot.api.provider import ProviderRequest, Provider
from astrbot.api.message_components import Plain, Image
from astrbot import logger
from collections import defaultdict
@@ -32,6 +32,7 @@ class LongTermMemory:
image_caption = (
True
if cfg["provider_settings"]["default_image_caption_provider_id"]
and cfg["provider_ltm_settings"]["image_caption"]
else False
)
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
@@ -73,6 +74,8 @@ class LongTermMemory:
provider = self.context.get_provider_by_id(image_caption_provider_id)
if not provider:
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
if not isinstance(provider, Provider):
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")
response = await provider.text_chat(
prompt=image_caption_prompt,
session_id=uuid.uuid4().hex,
@@ -122,8 +125,11 @@ class LongTermMemory:
elif isinstance(comp, Image):
if cfg["image_caption"]:
try:
url = comp.url if comp.url else comp.file
if not url:
raise Exception("图片 URL 为空")
caption = await self.get_image_caption(
comp.url if comp.url else comp.file,
url,
cfg["image_caption_provider_id"],
cfg["image_caption_prompt"],
)
+2 -1
View File
@@ -1,3 +1,4 @@
import copy
import astrbot.api.star as star
import builtins
import datetime
@@ -41,7 +42,7 @@ class ProcessLLMRequest:
if persona:
if prompt := persona["prompt"]:
req.system_prompt += prompt
if begin_dialogs := persona["_begin_dialogs_processed"]:
if begin_dialogs := copy.deepcopy(persona["_begin_dialogs_processed"]):
req.contexts[:0] = begin_dialogs
# tools select
+1
View File
@@ -54,6 +54,7 @@ dependencies = [
"pypdf>=6.1.1",
"aiofiles>=25.1.0",
"rank-bm25>=0.2.2",
"jieba>=0.42.1",
]
[project.scripts]