Merge pull request #114 from lxfight/lwl-dev/knowledge-base
refactor: 知识库优化
This commit is contained in:
@@ -13,11 +13,18 @@ jobs:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v5
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 'latest'
|
||||
|
||||
- name: npm install, build
|
||||
run: |
|
||||
cd dashboard
|
||||
npm install
|
||||
npm run build
|
||||
npm install pnpm -g
|
||||
pnpm install
|
||||
pnpm i --save-dev @types/markdown-it
|
||||
pnpm run build
|
||||
|
||||
- name: Inject Commit SHA
|
||||
id: get_sha
|
||||
|
||||
+17
-5
@@ -6,8 +6,20 @@ ci:
|
||||
autoupdate_schedule: weekly
|
||||
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.2
|
||||
hooks:
|
||||
- id: ruff
|
||||
- id: ruff-format
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.14.1
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff-check
|
||||
types_or: [ python, pyi ]
|
||||
args: [ --fix ]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
types_or: [ python, pyi ]
|
||||
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.21.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args: [--py310-plus]
|
||||
|
||||
@@ -4,14 +4,25 @@
|
||||
|
||||
<div align="center">
|
||||
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<br>
|
||||
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
<div>
|
||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp" alt="Featured|HelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<div>
|
||||
<img src="https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/Soulter/AstrBot/releases/latest">
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||

|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
|
||||
</div>
|
||||
|
||||
<br>
|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
@@ -118,7 +129,7 @@ uv run main.py
|
||||
| QQ(OneBot) | ✔ |
|
||||
| Telegram | ✔ |
|
||||
| 企微应用 | ✔ |
|
||||
| 企智机器人 | ✔ |
|
||||
| 企微智能机器人 | ✔ |
|
||||
| 微信客服 | ✔ |
|
||||
| 微信公众号 | ✔ |
|
||||
| 飞书 | ✔ |
|
||||
|
||||
@@ -209,9 +209,38 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
continue
|
||||
|
||||
valid_params = {} # 参数过滤:只传递函数实际需要的参数
|
||||
|
||||
# 获取实际的 handler 函数
|
||||
if func_tool.handler:
|
||||
logger.debug(
|
||||
f"工具 {func_tool_name} 期望的参数: {func_tool.parameters}"
|
||||
)
|
||||
if func_tool.parameters and func_tool.parameters.get("properties"):
|
||||
expected_params = set(func_tool.parameters["properties"].keys())
|
||||
|
||||
valid_params = {
|
||||
k: v
|
||||
for k, v in func_tool_args.items()
|
||||
if k in expected_params
|
||||
}
|
||||
|
||||
# 记录被忽略的参数
|
||||
ignored_params = set(func_tool_args.keys()) - set(
|
||||
valid_params.keys()
|
||||
)
|
||||
if ignored_params:
|
||||
logger.warning(
|
||||
f"工具 {func_tool_name} 忽略非期望参数: {ignored_params}"
|
||||
)
|
||||
else:
|
||||
# 如果没有 handler(如 MCP 工具),使用所有参数
|
||||
valid_params = func_tool_args
|
||||
logger.warning(f"工具 {func_tool_name} 没有 handler,使用所有参数")
|
||||
|
||||
try:
|
||||
await self.agent_hooks.on_tool_start(
|
||||
self.run_context, func_tool, func_tool_args
|
||||
self.run_context, func_tool, valid_params
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_tool_start hook: {e}", exc_info=True)
|
||||
@@ -219,7 +248,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
executor = self.tool_executor.execute(
|
||||
tool=func_tool,
|
||||
run_context=self.run_context,
|
||||
**func_tool_args,
|
||||
**valid_params, # 只传递有效的参数
|
||||
)
|
||||
|
||||
_final_resp: CallToolResult | None = None
|
||||
|
||||
@@ -5,6 +5,7 @@ from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
from astrbot.core.config.astrbot_config import ASTRBOT_CONFIG_PATH
|
||||
from astrbot.core.config.default import DEFAULT_CONFIG
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_config_path
|
||||
from typing import TypeVar, TypedDict
|
||||
|
||||
@@ -15,14 +16,12 @@ class ConfInfo(TypedDict):
|
||||
"""Configuration information for a specific session or platform."""
|
||||
|
||||
id: str # UUID of the configuration or "default"
|
||||
umop: list[str] # Unified Message Origin Pattern
|
||||
name: str
|
||||
path: str # File name to the configuration file
|
||||
|
||||
|
||||
DEFAULT_CONFIG_CONF_INFO = ConfInfo(
|
||||
id="default",
|
||||
umop=["::"],
|
||||
name="default",
|
||||
path=ASTRBOT_CONFIG_PATH,
|
||||
)
|
||||
@@ -31,8 +30,14 @@ DEFAULT_CONFIG_CONF_INFO = ConfInfo(
|
||||
class AstrBotConfigManager:
|
||||
"""A class to manage the system configuration of AstrBot, aka ACM"""
|
||||
|
||||
def __init__(self, default_config: AstrBotConfig, sp: SharedPreferences):
|
||||
def __init__(
|
||||
self,
|
||||
default_config: AstrBotConfig,
|
||||
ucr: UmopConfigRouter,
|
||||
sp: SharedPreferences,
|
||||
):
|
||||
self.sp = sp
|
||||
self.ucr = ucr
|
||||
self.confs: dict[str, AstrBotConfig] = {}
|
||||
"""uuid / "default" -> AstrBotConfig"""
|
||||
self.confs["default"] = default_config
|
||||
@@ -63,24 +68,15 @@ class AstrBotConfigManager:
|
||||
)
|
||||
continue
|
||||
|
||||
def _is_umo_match(self, p1: str, p2: str) -> bool:
|
||||
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
|
||||
p1_ls = p1.split(":")
|
||||
p2_ls = p2.split(":")
|
||||
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
|
||||
def _load_conf_mapping(self, umo: str | MessageSession) -> ConfInfo:
|
||||
"""获取指定 umo 的配置文件 uuid, 如果不存在则返回默认配置(返回 "default")
|
||||
|
||||
Returns:
|
||||
ConfInfo: 包含配置文件的 uuid, 路径和名称等信息, 是一个 dict 类型
|
||||
"""
|
||||
# uuid -> { "umop": list, "path": str, "name": str }
|
||||
# uuid -> { "path": str, "name": str }
|
||||
abconf_data = self._get_abconf_data()
|
||||
|
||||
if isinstance(umo, MessageSession):
|
||||
umo = str(umo)
|
||||
else:
|
||||
@@ -89,10 +85,13 @@ class AstrBotConfigManager:
|
||||
except Exception:
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
for uuid_, meta in abconf_data.items():
|
||||
for pattern in meta["umop"]:
|
||||
if self._is_umo_match(pattern, umo):
|
||||
return ConfInfo(**meta, id=uuid_)
|
||||
conf_id = self.ucr.get_conf_id_for_umop(umo)
|
||||
if conf_id:
|
||||
meta = abconf_data.get(conf_id)
|
||||
if meta and isinstance(meta, dict):
|
||||
# the bind relation between umo and conf is defined in ucr now, so we remove "umop" here
|
||||
meta.pop("umop", None)
|
||||
return ConfInfo(**meta, id=conf_id)
|
||||
|
||||
return DEFAULT_CONFIG_CONF_INFO
|
||||
|
||||
@@ -100,23 +99,14 @@ class AstrBotConfigManager:
|
||||
self,
|
||||
abconf_path: str,
|
||||
abconf_id: str,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
abconf_name: str | None = None,
|
||||
) -> None:
|
||||
"""保存配置文件的映射关系"""
|
||||
for part in umo_parts:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data = self.sp.get(
|
||||
"abconf_mapping", {}, scope="global", scope_id="global"
|
||||
)
|
||||
random_word = abconf_name or uuid.uuid4().hex[:8]
|
||||
abconf_data[abconf_id] = {
|
||||
"umop": umo_parts,
|
||||
"path": abconf_path,
|
||||
"name": random_word,
|
||||
}
|
||||
@@ -153,29 +143,26 @@ class AstrBotConfigManager:
|
||||
def get_conf_list(self) -> list[ConfInfo]:
|
||||
"""获取所有配置文件的元数据列表"""
|
||||
conf_list = []
|
||||
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
|
||||
abconf_mapping = self._get_abconf_data()
|
||||
for uuid_, meta in abconf_mapping.items():
|
||||
if not isinstance(meta, dict):
|
||||
continue
|
||||
meta.pop("umop", None)
|
||||
conf_list.append(ConfInfo(**meta, id=uuid_))
|
||||
conf_list.append(DEFAULT_CONFIG_CONF_INFO)
|
||||
return conf_list
|
||||
|
||||
def create_conf(
|
||||
self,
|
||||
umo_parts: list[str] | list[MessageSession],
|
||||
config: dict = DEFAULT_CONFIG,
|
||||
name: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
|
||||
|
||||
umo_parts 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
|
||||
"""
|
||||
conf_uuid = str(uuid.uuid4())
|
||||
conf_file_name = f"abconf_{conf_uuid}.json"
|
||||
conf_path = os.path.join(get_astrbot_config_path(), conf_file_name)
|
||||
conf = AstrBotConfig(config_path=conf_path, default_config=config)
|
||||
conf.save_config()
|
||||
self._save_conf_mapping(conf_file_name, conf_uuid, umo_parts, abconf_name=name)
|
||||
self._save_conf_mapping(conf_file_name, conf_uuid, abconf_name=name)
|
||||
self.confs[conf_uuid] = conf
|
||||
return conf_uuid
|
||||
|
||||
@@ -228,15 +215,12 @@ class AstrBotConfigManager:
|
||||
logger.info(f"成功删除配置文件 {conf_id}")
|
||||
return True
|
||||
|
||||
def update_conf_info(
|
||||
self, conf_id: str, name: str | None = None, umo_parts: list[str] | None = None
|
||||
) -> bool:
|
||||
def update_conf_info(self, conf_id: str, name: str | None = None) -> bool:
|
||||
"""更新配置文件信息
|
||||
|
||||
Args:
|
||||
conf_id: 配置文件的 UUID
|
||||
name: 新的配置文件名称 (可选)
|
||||
umo_parts: 新的 UMO 部分列表 (可选)
|
||||
|
||||
Returns:
|
||||
bool: 更新是否成功
|
||||
@@ -255,18 +239,6 @@ class AstrBotConfigManager:
|
||||
if name is not None:
|
||||
abconf_data[conf_id]["name"] = name
|
||||
|
||||
# 更新 UMO 部分
|
||||
if umo_parts is not None:
|
||||
# 验证 UMO 部分格式
|
||||
for part in umo_parts:
|
||||
if isinstance(part, MessageSession):
|
||||
part = str(part)
|
||||
elif not isinstance(part, str):
|
||||
raise ValueError(
|
||||
"umo_parts must be a list of strings or MessageSession instances"
|
||||
)
|
||||
abconf_data[conf_id]["umop"] = umo_parts
|
||||
|
||||
# 保存更新
|
||||
self.sp.put("abconf_mapping", abconf_data, scope="global", scope_id="global")
|
||||
self.abconf_data = abconf_data
|
||||
|
||||
@@ -134,27 +134,11 @@ DEFAULT_CONFIG = {
|
||||
"persona": [], # deprecated
|
||||
"timezone": "Asia/Shanghai",
|
||||
"callback_api_base": "",
|
||||
"default_kb_collection": "", # 默认知识库名称
|
||||
"default_kb_collection": "", # 默认知识库名称, 已经过时
|
||||
"plugin_set": ["*"], # "*" 表示使用所有可用的插件, 空列表表示不使用任何插件
|
||||
"knowledge_base": {
|
||||
"enabled": False, # 默认禁用,用户需要主动启用
|
||||
"embedding_provider_id": "", # 嵌入模型提供商 ID (为空时自动选择第一个)
|
||||
"rerank_provider_id": "", # 重排序模型提供商 ID (为空时自动选择第一个)
|
||||
"storage": {
|
||||
"files_path": "data/knowledge_base", # 文件存储路径
|
||||
"vector_db_path": "data/knowledge_base/vectors", # 向量数据库路径
|
||||
},
|
||||
"chunking": {
|
||||
"chunk_size": 512, # 文档块大小(字符数)
|
||||
"chunk_overlap": 50, # 文档块重叠大小(字符数)
|
||||
},
|
||||
"retrieval": {
|
||||
"top_k_dense": 50, # 密集检索返回结果数
|
||||
"top_k_sparse": 50, # 稀疏检索返回结果数
|
||||
"top_m_final": 5, # 最终融合后返回的结果数
|
||||
"enable_rerank": True, # 是否启用重排序
|
||||
},
|
||||
},
|
||||
"kb_names": [], # 默认知识库名称列表
|
||||
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
|
||||
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
|
||||
}
|
||||
|
||||
|
||||
@@ -181,10 +165,11 @@ CONFIG_METADATA_2 = {
|
||||
"enable": False,
|
||||
"appid": "",
|
||||
"secret": "",
|
||||
"is_sandbox": False,
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6196,
|
||||
},
|
||||
"QQ 个人号(aiocqhttp)": {
|
||||
"QQ 个人号(OneBot v11)": {
|
||||
"id": "default",
|
||||
"type": "aiocqhttp",
|
||||
"enable": False,
|
||||
@@ -192,7 +177,7 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "",
|
||||
},
|
||||
"微信个人号(WeChatPadPro)": {
|
||||
"WeChatPadPro": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
"enable": False,
|
||||
@@ -287,6 +272,14 @@ CONFIG_METADATA_2 = {
|
||||
"misskey_default_visibility": "public",
|
||||
"misskey_local_only": False,
|
||||
"misskey_enable_chat": True,
|
||||
# download / security options
|
||||
"misskey_allow_insecure_downloads": False,
|
||||
"misskey_download_timeout": 15,
|
||||
"misskey_download_chunk_size": 65536,
|
||||
"misskey_max_download_bytes": None,
|
||||
"misskey_enable_file_upload": True,
|
||||
"misskey_upload_concurrency": 3,
|
||||
"misskey_upload_folder": "",
|
||||
},
|
||||
"Slack": {
|
||||
"id": "slack",
|
||||
@@ -311,8 +304,26 @@ CONFIG_METADATA_2 = {
|
||||
"satori_heartbeat_interval": 10,
|
||||
"satori_reconnect_delay": 5,
|
||||
},
|
||||
# "WebChat": {
|
||||
# "id": "webchat",
|
||||
# "type": "webchat",
|
||||
# "enable": False,
|
||||
# "webchat_link_path": "",
|
||||
# "webchat_present_type": "fullscreen",
|
||||
# },
|
||||
},
|
||||
"items": {
|
||||
# "webchat_link_path": {
|
||||
# "description": "链接路径",
|
||||
# "_special": "webchat_link_path",
|
||||
# "type": "string",
|
||||
# },
|
||||
# "webchat_present_type": {
|
||||
# "_special": "webchat_present_type",
|
||||
# "description": "展现形式",
|
||||
# "type": "string",
|
||||
# "options": ["fullscreen", "embedded"],
|
||||
# },
|
||||
"satori_api_base_url": {
|
||||
"description": "Satori API 终结点",
|
||||
"type": "string",
|
||||
@@ -415,6 +426,41 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人将会监听和响应私信聊天消息",
|
||||
},
|
||||
"misskey_enable_file_upload": {
|
||||
"description": "启用文件上传到 Misskey",
|
||||
"type": "bool",
|
||||
"hint": "启用后,适配器会尝试将消息链中的文件上传到 Misskey。URL 文件会先尝试服务器端上传,异步上传失败时会回退到下载后本地上传。",
|
||||
},
|
||||
"misskey_allow_insecure_downloads": {
|
||||
"description": "允许不安全下载(禁用 SSL 验证)",
|
||||
"type": "bool",
|
||||
"hint": "当远端服务器存在证书问题导致无法正常下载时,自动禁用 SSL 验证作为回退方案。适用于某些图床的证书配置问题。启用有安全风险,仅在必要时使用。",
|
||||
},
|
||||
"misskey_download_timeout": {
|
||||
"description": "远端下载超时时间(秒)",
|
||||
"type": "int",
|
||||
"hint": "下载远程文件时的超时时间(秒),用于异步上传回退到本地上传的场景。",
|
||||
},
|
||||
"misskey_download_chunk_size": {
|
||||
"description": "流式下载分块大小(字节)",
|
||||
"type": "int",
|
||||
"hint": "流式下载和计算 MD5 时使用的每次读取字节数,过小会增加开销,过大会占用内存。",
|
||||
},
|
||||
"misskey_max_download_bytes": {
|
||||
"description": "最大允许下载字节数(超出则中止)",
|
||||
"type": "int",
|
||||
"hint": "如果希望限制下载文件的最大大小以防止 OOM,请填写最大字节数;留空或 null 表示不限制。",
|
||||
},
|
||||
"misskey_upload_concurrency": {
|
||||
"description": "并发上传限制",
|
||||
"type": "int",
|
||||
"hint": "同时进行的文件上传任务上限(整数,默认 3)。",
|
||||
},
|
||||
"misskey_upload_folder": {
|
||||
"description": "上传到网盘的目标文件夹 ID",
|
||||
"type": "string",
|
||||
"hint": "可选:填写 Misskey 网盘中目标文件夹的 ID,上传的文件将放置到该文件夹内。留空则使用账号网盘根目录。",
|
||||
},
|
||||
"telegram_command_register": {
|
||||
"description": "Telegram 命令注册",
|
||||
"type": "bool",
|
||||
@@ -466,19 +512,18 @@ CONFIG_METADATA_2 = {
|
||||
"hint": "启用后,机器人可以接收到频道的私聊消息。",
|
||||
},
|
||||
"ws_reverse_host": {
|
||||
"description": "反向 Websocket 主机地址(AstrBot 为服务器端)",
|
||||
"description": "反向 Websocket 主机",
|
||||
"type": "string",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。",
|
||||
"hint": "AstrBot 将作为服务器端。",
|
||||
},
|
||||
"ws_reverse_port": {
|
||||
"description": "反向 Websocket 端口",
|
||||
"type": "int",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
|
||||
},
|
||||
"ws_reverse_token": {
|
||||
"description": "反向 Websocket Token",
|
||||
"type": "string",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
|
||||
"hint": "反向 Websocket Token。未设置则不启用 Token 验证。",
|
||||
},
|
||||
"wecom_ai_bot_name": {
|
||||
"description": "企业微信智能机器人的名字",
|
||||
@@ -2019,6 +2064,9 @@ CONFIG_METADATA_2 = {
|
||||
"default_kb_collection": {
|
||||
"type": "string",
|
||||
},
|
||||
"kb_names": {"type": "list", "items": {"type": "string"}},
|
||||
"kb_fusion_top_k": {"type": "int", "default": 20},
|
||||
"kb_final_top_k": {"type": "int", "default": 5},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -2097,10 +2145,22 @@ CONFIG_METADATA_3 = {
|
||||
"description": "知识库",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"default_kb_collection": {
|
||||
"description": "默认使用的知识库",
|
||||
"type": "string",
|
||||
"kb_names": {
|
||||
"description": "知识库列表",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"_special": "select_knowledgebase",
|
||||
"hint": "支持多选",
|
||||
},
|
||||
"kb_fusion_top_k": {
|
||||
"description": "融合检索结果数",
|
||||
"type": "int",
|
||||
"hint": "多个知识库检索结果融合后的返回结果数量",
|
||||
},
|
||||
"kb_final_top_k": {
|
||||
"description": "最终返回结果数",
|
||||
"type": "int",
|
||||
"hint": "从知识库中检索到的结果数量,越大可能获得越多相关信息,但也可能引入噪音。建议根据实际需求调整",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -2194,7 +2254,7 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
"hint": "例子: 如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
|
||||
"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
|
||||
},
|
||||
"provider_settings.prompt_prefix": {
|
||||
"description": "用户提示词",
|
||||
|
||||
@@ -17,7 +17,6 @@ import os
|
||||
from .event_bus import EventBus
|
||||
from . import astrbot_config, html_renderer
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
@@ -26,15 +25,17 @@ from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger, sp
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.conversation_mgr import ConversationManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star_handler import star_map
|
||||
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
|
||||
class AstrBotCoreLifecycle:
|
||||
@@ -85,11 +86,21 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
await html_renderer.initialize()
|
||||
|
||||
# 初始化 UMOP 配置路由器
|
||||
self.umop_config_router = UmopConfigRouter(sp=sp)
|
||||
|
||||
# 初始化 AstrBot 配置管理器
|
||||
self.astrbot_config_mgr = AstrBotConfigManager(
|
||||
default_config=self.astrbot_config, sp=sp
|
||||
default_config=self.astrbot_config, ucr=self.umop_config_router, sp=sp
|
||||
)
|
||||
|
||||
# 4.5 to 4.6 migration for umop_config_router
|
||||
try:
|
||||
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 初始化事件队列
|
||||
self.event_queue = Queue()
|
||||
|
||||
@@ -112,9 +123,7 @@ class AstrBotCoreLifecycle:
|
||||
self.platform_message_history_manager = PlatformMessageHistoryManager(self.db)
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.kb_manager = KnowledgeBaseManager(
|
||||
self.astrbot_config, self.db, self.provider_manager
|
||||
)
|
||||
self.kb_manager = KnowledgeBaseManager(self.provider_manager)
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
@@ -141,10 +150,6 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
await self.kb_manager.initialize()
|
||||
|
||||
# 注册知识库会话生命周期钩子(零侵入级联清理)
|
||||
if self.kb_manager.is_initialized:
|
||||
self.kb_manager.register_session_lifecycle_hooks(self.conversation_manager)
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
self.pipeline_scheduler_mapping = await self.load_pipeline_scheduler()
|
||||
|
||||
@@ -160,7 +165,7 @@ class AstrBotCoreLifecycle:
|
||||
self.start_time = int(time.time())
|
||||
|
||||
# 初始化当前任务列表
|
||||
self.curr_tasks: List[asyncio.Task] = []
|
||||
self.curr_tasks: list[asyncio.Task] = []
|
||||
|
||||
# 根据配置实例化各个平台适配器
|
||||
await self.platform_manager.initialize()
|
||||
@@ -267,7 +272,7 @@ class AstrBotCoreLifecycle:
|
||||
target=self.astrbot_updator._reboot, name="restart", daemon=True
|
||||
).start()
|
||||
|
||||
def load_platform(self) -> List[asyncio.Task]:
|
||||
def load_platform(self) -> list[asyncio.Task]:
|
||||
"""加载平台实例并返回所有平台实例的异步任务列表"""
|
||||
tasks = []
|
||||
platform_insts = self.platform_manager.get_insts()
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
|
||||
|
||||
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
|
||||
abconf_data = acm.abconf_data
|
||||
|
||||
if not isinstance(abconf_data, dict):
|
||||
# should be unreachable
|
||||
logger.warning(
|
||||
f"migrate_45_to_46: abconf_data is not a dict (type={type(abconf_data)}). Value: {abconf_data!r}"
|
||||
)
|
||||
return
|
||||
|
||||
# 如果任何一项带有 umop,则说明需要迁移
|
||||
need_migration = False
|
||||
for conf_id, conf_info in abconf_data.items():
|
||||
if isinstance(conf_info, dict) and "umop" in conf_info:
|
||||
need_migration = True
|
||||
break
|
||||
|
||||
if not need_migration:
|
||||
return
|
||||
|
||||
logger.info("Starting migration from version 4.5 to 4.6")
|
||||
|
||||
# extract umo->conf_id mapping
|
||||
umo_to_conf_id = {}
|
||||
for conf_id, conf_info in abconf_data.items():
|
||||
if isinstance(conf_info, dict) and "umop" in conf_info:
|
||||
umop_ls = conf_info.pop("umop")
|
||||
if not isinstance(umop_ls, list):
|
||||
continue
|
||||
for umo in umop_ls:
|
||||
if isinstance(umo, str) and umo not in umo_to_conf_id:
|
||||
umo_to_conf_id[umo] = conf_id
|
||||
|
||||
# update the abconf data
|
||||
await sp.global_put("abconf_mapping", abconf_data)
|
||||
# update the umop config router
|
||||
await ucr.update_routing_data(umo_to_conf_id)
|
||||
|
||||
logger.info("Migration from version 45 to 46 completed successfully")
|
||||
@@ -16,14 +16,42 @@ class BaseVecDB:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert(self, content: str, metadata: dict = None, id: str = None) -> int:
|
||||
async def insert(
|
||||
self, content: str, metadata: dict | None = None, id: str | None = None
|
||||
) -> int:
|
||||
"""
|
||||
插入一条文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(self, query: str, top_k: int = 5) -> list[Result]:
|
||||
async def insert_batch(
|
||||
self,
|
||||
contents: list[str],
|
||||
metadatas: list[dict] | None = None,
|
||||
ids: list[str] | None = None,
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> int:
|
||||
"""
|
||||
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
|
||||
Args:
|
||||
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
fetch_k: int = 20,
|
||||
rerank: bool = False,
|
||||
metadata_filters: dict | None = None,
|
||||
) -> list[Result]:
|
||||
"""
|
||||
搜索最相似的文档。
|
||||
Args:
|
||||
@@ -44,3 +72,6 @@ class BaseVecDB:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def close(self): ...
|
||||
|
||||
@@ -1,59 +1,219 @@
|
||||
import aiosqlite
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from sqlalchemy import Text, Column
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import Field, SQLModel, select, col, func, text, MetaData
|
||||
|
||||
|
||||
class BaseDocModel(SQLModel, table=False):
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
class Document(BaseDocModel, table=True):
|
||||
"""SQLModel for documents table."""
|
||||
|
||||
__tablename__ = "documents" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
doc_id: str = Field(nullable=False)
|
||||
text: str = Field(nullable=False)
|
||||
metadata_: str | None = Field(default=None, sa_column=Column("metadata", Text))
|
||||
created_at: datetime | None = Field(default=None)
|
||||
updated_at: datetime | None = Field(default=None)
|
||||
|
||||
|
||||
class DocumentStorage:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.connection = None
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.engine: AsyncEngine | None = None
|
||||
self.async_session_maker: sessionmaker | None = None
|
||||
self.sqlite_init_path = os.path.join(
|
||||
os.path.dirname(__file__), "sqlite_init.sql"
|
||||
)
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
|
||||
if not os.path.exists(self.db_path):
|
||||
await self.connect()
|
||||
async with self.connection.cursor() as cursor:
|
||||
with open(self.sqlite_init_path, "r", encoding="utf-8") as f:
|
||||
sql_script = f.read()
|
||||
await cursor.executescript(sql_script)
|
||||
await self.connection.commit()
|
||||
else:
|
||||
await self.connect()
|
||||
await self.connect()
|
||||
async with self.engine.begin() as conn: # type: ignore
|
||||
# Create tables using SQLModel
|
||||
await conn.run_sync(BaseDocModel.metadata.create_all)
|
||||
|
||||
try:
|
||||
await conn.execute(
|
||||
text(
|
||||
"ALTER TABLE documents ADD COLUMN kb_doc_id TEXT "
|
||||
"GENERATED ALWAYS AS (json_extract(metadata, '$.kb_doc_id')) STORED"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"ALTER TABLE documents ADD COLUMN user_id TEXT "
|
||||
"GENERATED ALWAYS AS (json_extract(metadata, '$.user_id')) STORED"
|
||||
)
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_documents_kb_doc_id ON documents(kb_doc_id)"
|
||||
)
|
||||
)
|
||||
await conn.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_documents_user_id ON documents(user_id)"
|
||||
)
|
||||
)
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
await conn.commit()
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the SQLite database."""
|
||||
self.connection = await aiosqlite.connect(self.db_path)
|
||||
if self.engine is None:
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
future=True,
|
||||
)
|
||||
self.async_session_maker = sessionmaker(
|
||||
self.engine, # type: ignore
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
) # type: ignore
|
||||
|
||||
async def get_documents(self, metadata_filters: dict, ids: list = None):
|
||||
@asynccontextmanager
|
||||
async def get_session(self):
|
||||
"""Context manager for database sessions."""
|
||||
async with self.async_session_maker() as session: # type: ignore
|
||||
yield session
|
||||
|
||||
async def get_documents(
|
||||
self,
|
||||
metadata_filters: dict,
|
||||
ids: list | None = None,
|
||||
offset: int | None = 0,
|
||||
limit: int | None = 100,
|
||||
) -> list[dict]:
|
||||
"""Retrieve documents by metadata filters and ids.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
ids (list | None): Optional list of document IDs to filter.
|
||||
offset (int | None): Offset for pagination.
|
||||
limit (int | None): Limit for pagination.
|
||||
|
||||
Returns:
|
||||
list: The list of document IDs(primary key, not doc_id) that match the filters.
|
||||
list: The list of documents that match the filters.
|
||||
"""
|
||||
# metadata filter -> SQL WHERE clause
|
||||
where_clauses = []
|
||||
values = []
|
||||
for key, val in metadata_filters.items():
|
||||
where_clauses.append(f"json_extract(metadata, '$.{key}') = ?")
|
||||
values.append(val)
|
||||
if ids is not None and len(ids) > 0:
|
||||
ids = [str(i) for i in ids if i != -1]
|
||||
where_clauses.append("id IN ({})".format(",".join("?" * len(ids))))
|
||||
values.extend(ids)
|
||||
where_sql = " AND ".join(where_clauses) or "1=1"
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
result = []
|
||||
async with self.connection.cursor() as cursor:
|
||||
sql = "SELECT * FROM documents WHERE " + where_sql
|
||||
await cursor.execute(sql, values)
|
||||
for row in await cursor.fetchall():
|
||||
result.append(await self.tuple_to_dict(row))
|
||||
return result
|
||||
async with self.get_session() as session:
|
||||
query = select(Document)
|
||||
|
||||
for key, val in metadata_filters.items():
|
||||
query = query.where(
|
||||
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
|
||||
).params(**{f"filter_{key}": val})
|
||||
|
||||
if ids is not None and len(ids) > 0:
|
||||
valid_ids = [int(i) for i in ids if i != -1]
|
||||
if valid_ids:
|
||||
query = query.where(col(Document.id).in_(valid_ids))
|
||||
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
if offset is not None:
|
||||
query = query.offset(offset)
|
||||
|
||||
result = await session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
return [self._document_to_dict(doc) for doc in documents]
|
||||
|
||||
async def insert_document(self, doc_id: str, text: str, metadata: dict) -> int:
|
||||
"""Insert a single document and return its integer ID.
|
||||
|
||||
Args:
|
||||
doc_id (str): The document ID (UUID string).
|
||||
text (str): The document text.
|
||||
metadata (dict): The document metadata.
|
||||
|
||||
Returns:
|
||||
int: The integer ID of the inserted document.
|
||||
"""
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
document = Document(
|
||||
doc_id=doc_id,
|
||||
text=text,
|
||||
metadata_=json.dumps(metadata),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
session.add(document)
|
||||
await session.flush() # Flush to get the ID
|
||||
return document.id # type: ignore
|
||||
|
||||
async def insert_documents_batch(
|
||||
self, doc_ids: list[str], texts: list[str], metadatas: list[dict]
|
||||
) -> list[int]:
|
||||
"""Batch insert documents and return their integer IDs.
|
||||
|
||||
Args:
|
||||
doc_ids (list[str]): List of document IDs (UUID strings).
|
||||
texts (list[str]): List of document texts.
|
||||
metadatas (list[dict]): List of document metadata.
|
||||
|
||||
Returns:
|
||||
list[int]: List of integer IDs of the inserted documents.
|
||||
"""
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
import json
|
||||
|
||||
documents = []
|
||||
for doc_id, text, metadata in zip(doc_ids, texts, metadatas):
|
||||
document = Document(
|
||||
doc_id=doc_id,
|
||||
text=text,
|
||||
metadata_=json.dumps(metadata),
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
documents.append(document)
|
||||
session.add(document)
|
||||
|
||||
await session.flush() # Flush to get all IDs
|
||||
return [doc.id for doc in documents] # type: ignore
|
||||
|
||||
async def delete_document_by_doc_id(self, doc_id: str):
|
||||
"""Delete a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id of the document to delete.
|
||||
"""
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
query = select(Document).where(col(Document.doc_id) == doc_id)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if document:
|
||||
await session.delete(document)
|
||||
|
||||
async def get_document_by_doc_id(self, doc_id: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
@@ -62,28 +222,85 @@ class DocumentStorage:
|
||||
doc_id (str): The doc_id of the document to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: The document data.
|
||||
dict: The document data or None if not found.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT * FROM documents WHERE doc_id = ?", (doc_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row:
|
||||
return await self.tuple_to_dict(row)
|
||||
else:
|
||||
return None
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
query = select(Document).where(col(Document.doc_id) == doc_id)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if document:
|
||||
return self._document_to_dict(document)
|
||||
return None
|
||||
|
||||
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
|
||||
"""Retrieve a document by its doc_id.
|
||||
"""Update a document by its doc_id.
|
||||
|
||||
Args:
|
||||
doc_id (str): The doc_id.
|
||||
new_text (str): The new text to update the document with.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"UPDATE documents SET text = ? WHERE doc_id = ?", (new_text, doc_id)
|
||||
)
|
||||
await self.connection.commit()
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
query = select(Document).where(col(Document.doc_id) == doc_id)
|
||||
result = await session.execute(query)
|
||||
document = result.scalar_one_or_none()
|
||||
|
||||
if document:
|
||||
document.text = new_text
|
||||
document.updated_at = datetime.now()
|
||||
session.add(document)
|
||||
|
||||
async def delete_documents(self, metadata_filters: dict):
|
||||
"""Delete documents by their metadata filters.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict): The metadata filters to apply.
|
||||
"""
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
async with session.begin():
|
||||
query = select(Document)
|
||||
|
||||
for key, val in metadata_filters.items():
|
||||
query = query.where(
|
||||
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
|
||||
).params(**{f"filter_{key}": val})
|
||||
|
||||
result = await session.execute(query)
|
||||
documents = result.scalars().all()
|
||||
|
||||
for doc in documents:
|
||||
await session.delete(doc)
|
||||
|
||||
async def count_documents(self, metadata_filters: dict | None = None) -> int:
|
||||
"""Count documents in the database.
|
||||
|
||||
Args:
|
||||
metadata_filters (dict | None): Metadata filters to apply.
|
||||
|
||||
Returns:
|
||||
int: The count of documents.
|
||||
"""
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
query = select(func.count(col(Document.id)))
|
||||
|
||||
if metadata_filters:
|
||||
for key, val in metadata_filters.items():
|
||||
query = query.where(
|
||||
text(f"json_extract(metadata, '$.{key}') = :filter_{key}")
|
||||
).params(**{f"filter_{key}": val})
|
||||
|
||||
result = await session.execute(query)
|
||||
count = result.scalar_one_or_none()
|
||||
return count if count is not None else 0
|
||||
|
||||
async def get_user_ids(self) -> list[str]:
|
||||
"""Retrieve all user IDs from the documents table.
|
||||
@@ -91,11 +308,38 @@ class DocumentStorage:
|
||||
Returns:
|
||||
list: A list of user IDs.
|
||||
"""
|
||||
async with self.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT DISTINCT user_id FROM documents")
|
||||
rows = await cursor.fetchall()
|
||||
assert self.engine is not None, "Database connection is not initialized."
|
||||
|
||||
async with self.get_session() as session:
|
||||
query = text(
|
||||
"SELECT DISTINCT user_id FROM documents WHERE user_id IS NOT NULL"
|
||||
)
|
||||
result = await session.execute(query)
|
||||
rows = result.fetchall()
|
||||
return [row[0] for row in rows]
|
||||
|
||||
def _document_to_dict(self, document: Document) -> dict:
|
||||
"""Convert a Document model to a dictionary.
|
||||
|
||||
Args:
|
||||
document (Document): The document to convert.
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
"""
|
||||
return {
|
||||
"id": document.id,
|
||||
"doc_id": document.doc_id,
|
||||
"text": document.text,
|
||||
"metadata": document.metadata_,
|
||||
"created_at": document.created_at.isoformat()
|
||||
if isinstance(document.created_at, datetime)
|
||||
else document.created_at,
|
||||
"updated_at": document.updated_at.isoformat()
|
||||
if isinstance(document.updated_at, datetime)
|
||||
else document.updated_at,
|
||||
}
|
||||
|
||||
async def tuple_to_dict(self, row):
|
||||
"""Convert a tuple to a dictionary.
|
||||
|
||||
@@ -104,6 +348,8 @@ class DocumentStorage:
|
||||
|
||||
Returns:
|
||||
dict: The converted dictionary.
|
||||
|
||||
Note: This method is kept for backward compatibility but is no longer used internally.
|
||||
"""
|
||||
return {
|
||||
"id": row[0],
|
||||
@@ -116,6 +362,7 @@ class DocumentStorage:
|
||||
|
||||
async def close(self):
|
||||
"""Close the connection to the SQLite database."""
|
||||
if self.connection:
|
||||
await self.connection.close()
|
||||
self.connection = None
|
||||
if self.engine:
|
||||
await self.engine.dispose()
|
||||
self.engine = None
|
||||
self.async_session_maker = None
|
||||
|
||||
@@ -9,7 +9,7 @@ import numpy as np
|
||||
|
||||
|
||||
class EmbeddingStorage:
|
||||
def __init__(self, dimension: int, path: str = None):
|
||||
def __init__(self, dimension: int, path: str | None = None):
|
||||
self.dimension = dimension
|
||||
self.path = path
|
||||
self.index = None
|
||||
@@ -18,7 +18,6 @@ class EmbeddingStorage:
|
||||
else:
|
||||
base_index = faiss.IndexFlatL2(dimension)
|
||||
self.index = faiss.IndexIDMap(base_index)
|
||||
self.storage = {}
|
||||
|
||||
async def insert(self, vector: np.ndarray, id: int):
|
||||
"""插入向量
|
||||
@@ -29,12 +28,29 @@ class EmbeddingStorage:
|
||||
Raises:
|
||||
ValueError: 如果向量的维度与存储的维度不匹配
|
||||
"""
|
||||
assert self.index is not None, "FAISS index is not initialized."
|
||||
if vector.shape[0] != self.dimension:
|
||||
raise ValueError(
|
||||
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}"
|
||||
)
|
||||
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
|
||||
self.storage[id] = vector
|
||||
await self.save_index()
|
||||
|
||||
async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
|
||||
"""批量插入向量
|
||||
|
||||
Args:
|
||||
vectors (np.ndarray): 要插入的向量数组
|
||||
ids (list[int]): 向量的ID列表
|
||||
Raises:
|
||||
ValueError: 如果向量的维度与存储的维度不匹配
|
||||
"""
|
||||
assert self.index is not None, "FAISS index is not initialized."
|
||||
if vectors.shape[1] != self.dimension:
|
||||
raise ValueError(
|
||||
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}"
|
||||
)
|
||||
self.index.add_with_ids(vectors, np.array(ids))
|
||||
await self.save_index()
|
||||
|
||||
async def search(self, vector: np.ndarray, k: int) -> tuple:
|
||||
@@ -46,10 +62,22 @@ class EmbeddingStorage:
|
||||
Returns:
|
||||
tuple: (距离, 索引)
|
||||
"""
|
||||
assert self.index is not None, "FAISS index is not initialized."
|
||||
faiss.normalize_L2(vector)
|
||||
distances, indices = self.index.search(vector, k)
|
||||
return distances, indices
|
||||
|
||||
async def delete(self, ids: list[int]):
|
||||
"""删除向量
|
||||
|
||||
Args:
|
||||
ids (list[int]): 要删除的向量ID列表
|
||||
"""
|
||||
assert self.index is not None, "FAISS index is not initialized."
|
||||
id_array = np.array(ids, dtype=np.int64)
|
||||
self.index.remove_ids(id_array)
|
||||
await self.save_index()
|
||||
|
||||
async def save_index(self):
|
||||
"""保存索引
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
import numpy as np
|
||||
from .document_storage import DocumentStorage
|
||||
from .embedding_storage import EmbeddingStorage
|
||||
from ..base import Result, BaseVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
class FaissVecDB(BaseVecDB):
|
||||
@@ -44,18 +45,56 @@ class FaissVecDB(BaseVecDB):
|
||||
|
||||
vector = await self.embedding_provider.get_embedding(content)
|
||||
vector = np.array(vector, dtype=np.float32)
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"INSERT INTO documents (doc_id, text, metadata) VALUES (?, ?, ?)",
|
||||
(str_id, content, json.dumps(metadata)),
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
result = await self.document_storage.get_document_by_doc_id(str_id)
|
||||
int_id = result["id"]
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
# 使用 DocumentStorage 的方法插入文档
|
||||
int_id = await self.document_storage.insert_document(str_id, content, metadata)
|
||||
|
||||
# 插入向量到 FAISS
|
||||
await self.embedding_storage.insert(vector, int_id)
|
||||
return int_id
|
||||
|
||||
async def insert_batch(
|
||||
self,
|
||||
contents: list[str],
|
||||
metadatas: list[dict] | None = None,
|
||||
ids: list[str] | None = None,
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> list[int]:
|
||||
"""
|
||||
批量插入文本和其对应向量,自动生成 ID 并保持一致性。
|
||||
|
||||
Args:
|
||||
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||
"""
|
||||
metadatas = metadatas or [{} for _ in contents]
|
||||
ids = ids or [str(uuid.uuid4()) for _ in contents]
|
||||
|
||||
start = time.time()
|
||||
logger.debug(f"Generating embeddings for {len(contents)} contents...")
|
||||
vectors = await self.embedding_provider.get_embeddings_batch(
|
||||
contents,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
end = time.time()
|
||||
logger.debug(
|
||||
f"Generated embeddings for {len(contents)} contents in {end - start:.2f} seconds."
|
||||
)
|
||||
|
||||
# 使用 DocumentStorage 的批量插入方法
|
||||
int_ids = await self.document_storage.insert_documents_batch(
|
||||
ids, contents, metadatas
|
||||
)
|
||||
|
||||
# 批量插入向量到 FAISS
|
||||
vectors_array = np.array(vectors).astype("float32")
|
||||
await self.embedding_storage.insert_batch(vectors_array, int_ids)
|
||||
return int_ids
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
@@ -119,23 +158,42 @@ class FaissVecDB(BaseVecDB):
|
||||
|
||||
return top_k_results
|
||||
|
||||
async def delete(self, doc_id: int):
|
||||
async def delete(self, doc_id: str):
|
||||
"""
|
||||
删除一条文档
|
||||
删除一条文档块(chunk)
|
||||
"""
|
||||
await self.document_storage.connection.execute(
|
||||
"DELETE FROM documents WHERE doc_id = ?", (doc_id,)
|
||||
)
|
||||
await self.document_storage.connection.commit()
|
||||
# 获得对应的 int id
|
||||
result = await self.document_storage.get_document_by_doc_id(doc_id)
|
||||
int_id = result["id"] if result else None
|
||||
if int_id is None:
|
||||
return
|
||||
|
||||
# 使用 DocumentStorage 的删除方法
|
||||
await self.document_storage.delete_document_by_doc_id(doc_id)
|
||||
await self.embedding_storage.delete([int_id])
|
||||
|
||||
async def close(self):
|
||||
await self.document_storage.close()
|
||||
|
||||
async def count_documents(self) -> int:
|
||||
async def count_documents(self, metadata_filter: dict | None = None) -> int:
|
||||
"""
|
||||
计算文档数量
|
||||
|
||||
Args:
|
||||
metadata_filter (dict | None): 元数据过滤器
|
||||
"""
|
||||
async with self.document_storage.connection.cursor() as cursor:
|
||||
await cursor.execute("SELECT COUNT(*) FROM documents")
|
||||
count = await cursor.fetchone()
|
||||
return count[0] if count else 0
|
||||
count = await self.document_storage.count_documents(
|
||||
metadata_filters=metadata_filter or {}
|
||||
)
|
||||
return count
|
||||
|
||||
async def delete_documents(self, metadata_filters: dict):
|
||||
"""
|
||||
根据元数据过滤器删除文档
|
||||
"""
|
||||
docs = await self.document_storage.get_documents(
|
||||
metadata_filters=metadata_filters, offset=None, limit=None
|
||||
)
|
||||
doc_ids: list[int] = [doc["id"] for doc in docs]
|
||||
await self.embedding_storage.delete(doc_ids)
|
||||
await self.document_storage.delete_documents(metadata_filters=metadata_filters)
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
"""
|
||||
知识库管理模块
|
||||
|
||||
提供文档上传、解析、分块、向量化、检索等功能
|
||||
"""
|
||||
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KBSessionConfig,
|
||||
KnowledgeBase,
|
||||
)
|
||||
|
||||
# 注意: 以下导入在对应模块实现后取消注释
|
||||
from .database import KBDatabase
|
||||
from .manager import KBManager
|
||||
from .manager_ops import KBManagerOps
|
||||
from .session_config_db import SessionConfigDB
|
||||
|
||||
# from .injector import KnowledgeBaseInjector
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeBase",
|
||||
"KBDocument",
|
||||
"KBChunk",
|
||||
"KBMedia",
|
||||
"KBSessionConfig",
|
||||
"KBDatabase",
|
||||
"SessionConfigDB",
|
||||
"KBManager",
|
||||
"KBManagerOps",
|
||||
# "KnowledgeBaseInjector",
|
||||
]
|
||||
@@ -13,7 +13,7 @@ class BaseChunker(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def chunk(self, text: str) -> list[str]:
|
||||
async def chunk(self, text: str, **kwargs) -> list[str]:
|
||||
"""将文本分块
|
||||
|
||||
Args:
|
||||
|
||||
@@ -22,7 +22,7 @@ class FixedSizeChunker(BaseChunker):
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
|
||||
async def chunk(self, text: str) -> list[str]:
|
||||
async def chunk(self, text: str, **kwargs) -> list[str]:
|
||||
"""固定大小分块
|
||||
|
||||
Args:
|
||||
@@ -31,22 +31,25 @@ class FixedSizeChunker(BaseChunker):
|
||||
Returns:
|
||||
list[str]: 分块后的文本列表
|
||||
"""
|
||||
chunk_size = kwargs.get("chunk_size", self.chunk_size)
|
||||
chunk_overlap = kwargs.get("chunk_overlap", self.chunk_overlap)
|
||||
|
||||
chunks = []
|
||||
start = 0
|
||||
text_len = len(text)
|
||||
|
||||
while start < text_len:
|
||||
end = start + self.chunk_size
|
||||
end = start + chunk_size
|
||||
chunk = text[start:end]
|
||||
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
# 移动窗口,保留重叠部分
|
||||
start = end - self.chunk_overlap
|
||||
start = end - chunk_overlap
|
||||
|
||||
# 防止无限循环: 如果重叠过大,直接移到end
|
||||
if start >= end or self.chunk_overlap >= self.chunk_size:
|
||||
if start >= end or chunk_overlap >= chunk_size:
|
||||
start = end
|
||||
|
||||
return chunks
|
||||
|
||||
@@ -1,347 +0,0 @@
|
||||
"""知识库数据库操作类
|
||||
|
||||
该模块封装知识库、文档、块、多媒体和会话配置相关的数据库查询操作。
|
||||
|
||||
注意:
|
||||
- 该模块操作的是独立的知识库数据库 (data/knowledge_base/kb.db)
|
||||
- 会话配置也存储在此数据库中,会话ID来源于主数据库
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KBSessionConfig,
|
||||
KnowledgeBase,
|
||||
)
|
||||
|
||||
|
||||
class KBDatabase:
|
||||
"""知识库数据库操作类
|
||||
|
||||
职责:
|
||||
- 封装知识库、文档、块、多媒体和会话配置的数据库查询操作
|
||||
- 统一异常处理
|
||||
|
||||
注意:
|
||||
- 该类操作独立的知识库数据库 (kb.db)
|
||||
- 会话配置存储会话ID与知识库的绑定关系,会话ID来源于主数据库
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBSQLiteDatabase):
|
||||
"""初始化知识库数据库操作类
|
||||
|
||||
Args:
|
||||
kb_db: 知识库独立数据库实例,而非主数据库
|
||||
"""
|
||||
self.db = kb_db
|
||||
|
||||
# ===== 知识库查询 =====
|
||||
|
||||
async def get_kb_by_id(self, kb_id: str) -> Optional[KnowledgeBase]:
|
||||
"""根据 ID 获取知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_kb_by_name(self, kb_name: str) -> Optional[KnowledgeBase]:
|
||||
"""根据名称获取知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_name == kb_name)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KnowledgeBase)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KnowledgeBase.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_kbs(self) -> int:
|
||||
"""统计知识库数量"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(func.count(KnowledgeBase.id))
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
# ===== 文档查询 =====
|
||||
|
||||
async def get_document_by_id(self, doc_id: str) -> Optional[KBDocument]:
|
||||
"""根据 ID 获取文档"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_documents_by_kb(
|
||||
self, kb_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument)
|
||||
.where(KBDocument.kb_id == kb_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBDocument.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_documents_by_kb(self, kb_id: str) -> int:
|
||||
"""统计知识库的文档数量"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(func.count(KBDocument.id)).where(KBDocument.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
# ===== 块查询 =====
|
||||
|
||||
async def get_chunk_by_id(self, chunk_id: str) -> Optional[KBChunk]:
|
||||
"""根据 ID 获取块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_chunks_by_kb_ids(self, kb_ids: list[str]) -> list[KBChunk]:
|
||||
"""根据知识库 ID 列表获取所有块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.kb_id.in_(kb_ids))
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_chunk_by_vec_doc_id(self, vec_doc_id: str) -> Optional[KBChunk]:
|
||||
"""根据向量文档 ID 获取块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.vec_doc_id == vec_doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_chunk_with_metadata(self, chunk_id: str) -> Optional[dict]:
|
||||
"""获取块及其关联的文档和知识库元数据"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk, KBDocument, KnowledgeBase)
|
||||
.join(KBDocument, KBChunk.doc_id == KBDocument.doc_id)
|
||||
.join(KnowledgeBase, KBChunk.kb_id == KnowledgeBase.kb_id)
|
||||
.where(KBChunk.chunk_id == chunk_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
chunk, doc, kb = row
|
||||
return {
|
||||
"chunk": chunk,
|
||||
"document": doc,
|
||||
"knowledge_base": kb,
|
||||
}
|
||||
|
||||
async def list_chunks_by_doc(
|
||||
self, doc_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBChunk]:
|
||||
"""列出文档的所有块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk)
|
||||
.where(KBChunk.doc_id == doc_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBChunk.chunk_index)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
# ===== 多媒体查询 =====
|
||||
|
||||
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
|
||||
"""列出文档的所有多媒体资源"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_media_by_id(self, media_id: str) -> Optional[KBMedia]:
|
||||
"""根据 ID 获取多媒体资源"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
# ===== 会话配置查询 =====
|
||||
|
||||
async def get_session_kb_ids(self, session_id: str) -> list[str]:
|
||||
"""获取会话关联的知识库 ID 列表
|
||||
|
||||
查找顺序:
|
||||
1. 会话级别配置 (优先)
|
||||
2. 平台级别配置
|
||||
3. 返回空列表
|
||||
|
||||
Args:
|
||||
session_id: 会话ID(来自主数据库)
|
||||
|
||||
Returns:
|
||||
知识库ID列表
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
# 1. 查找会话级别配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == "session",
|
||||
KBSessionConfig.scope_id == session_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
return json.loads(config.kb_ids)
|
||||
|
||||
# 2. 提取平台 ID (格式: platform:xxx:session_id)
|
||||
parts = session_id.split(":")
|
||||
if len(parts) >= 2:
|
||||
platform_id = parts[0]
|
||||
|
||||
# 查找平台级别配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == "platform",
|
||||
KBSessionConfig.scope_id == platform_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
return json.loads(config.kb_ids)
|
||||
|
||||
# 3. 无配置
|
||||
return []
|
||||
|
||||
async def set_session_kb_ids(
|
||||
self,
|
||||
scope: str,
|
||||
scope_id: str,
|
||||
kb_ids: list[str],
|
||||
top_k: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> KBSessionConfig:
|
||||
"""设置会话知识库配置
|
||||
|
||||
Args:
|
||||
scope: 配置范围 (session/platform)
|
||||
scope_id: 范围标识 (会话 ID 或平台 ID,来自主数据库)
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k: 返回结果数量 (可选)
|
||||
enable_rerank: 是否启用 Rerank (可选)
|
||||
|
||||
Returns:
|
||||
配置对象
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
# 查找现有配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == scope,
|
||||
KBSessionConfig.scope_id == scope_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
# 更新现有配置
|
||||
config.kb_ids = json.dumps(kb_ids)
|
||||
if top_k is not None:
|
||||
config.top_k = top_k
|
||||
if enable_rerank is not None:
|
||||
config.enable_rerank = enable_rerank
|
||||
else:
|
||||
# 创建新配置
|
||||
config = KBSessionConfig(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
kb_ids=json.dumps(kb_ids),
|
||||
top_k=top_k,
|
||||
enable_rerank=enable_rerank,
|
||||
)
|
||||
session.add(config)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
|
||||
"""删除会话知识库配置
|
||||
|
||||
Args:
|
||||
scope: 配置范围 (session/platform)
|
||||
scope_id: 范围标识 (会话 ID 或平台 ID)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == scope,
|
||||
KBSessionConfig.scope_id == scope_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if not config:
|
||||
return False
|
||||
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def delete_session_kb_config_by_session_id(self, session_id: str) -> bool:
|
||||
"""根据会话ID删除会话配置(用于主数据库会话删除时的级联清理)
|
||||
|
||||
Args:
|
||||
session_id: 会话ID(来自主数据库)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
return await self.delete_session_kb_config("session", session_id)
|
||||
|
||||
async def list_all_session_configs(
|
||||
self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
|
||||
) -> list[KBSessionConfig]:
|
||||
"""列出所有会话配置
|
||||
|
||||
Args:
|
||||
offset: 偏移量
|
||||
limit: 限制数量
|
||||
scope: 可选的范围过滤 (session/platform)
|
||||
|
||||
Returns:
|
||||
会话配置列表
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBSessionConfig)
|
||||
|
||||
if scope:
|
||||
stmt = stmt.where(KBSessionConfig.scope == scope)
|
||||
|
||||
stmt = (
|
||||
stmt.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBSessionConfig.created_at.desc())
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
@@ -1,115 +0,0 @@
|
||||
"""知识库上下文注入器
|
||||
|
||||
负责检索相关知识并格式化为 LLM 可用的上下文文本
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.manager import (
|
||||
RetrievalManager,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBaseInjector:
|
||||
"""知识库上下文注入器
|
||||
|
||||
职责:
|
||||
- 检索相关知识
|
||||
- 格式化为上下文文本
|
||||
- 注入到 LLM Prompt
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kb_db: KBDatabase,
|
||||
retrieval_manager: RetrievalManager,
|
||||
):
|
||||
"""初始化知识库上下文注入器
|
||||
|
||||
Args:
|
||||
kb_db: 知识库数据库实例
|
||||
retrieval_manager: 检索管理器实例
|
||||
"""
|
||||
self.kb_db = kb_db
|
||||
self.retrieval_manager = retrieval_manager
|
||||
|
||||
async def retrieve_and_inject(
|
||||
self,
|
||||
unified_msg_origin: str,
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> Optional[dict]:
|
||||
"""检索并注入知识库上下文
|
||||
|
||||
Args:
|
||||
unified_msg_origin: 统一消息来源 ID (会话 ID)
|
||||
query: 用户查询
|
||||
top_k: 返回结果数量
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 包含检索结果和格式化上下文的字典,如果无结果则返回 None
|
||||
{
|
||||
"context_text": str, # 格式化的上下文文本
|
||||
"results": List[dict], # 原始检索结果列表
|
||||
}
|
||||
"""
|
||||
# 1. 获取会话关联的知识库
|
||||
kb_ids = await self.kb_db.get_session_kb_ids(unified_msg_origin)
|
||||
|
||||
if not kb_ids:
|
||||
return None
|
||||
|
||||
# 2. 检索知识
|
||||
results = await self.retrieval_manager.retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_m_final=top_k,
|
||||
)
|
||||
|
||||
if not results:
|
||||
return None
|
||||
|
||||
# 3. 格式化上下文
|
||||
context_text = self._format_context(results)
|
||||
|
||||
# 4. 转换结果为字典格式
|
||||
results_dict = [
|
||||
{
|
||||
"chunk_id": r.chunk_id,
|
||||
"doc_id": r.doc_id,
|
||||
"kb_id": r.kb_id,
|
||||
"kb_name": r.kb_name,
|
||||
"doc_name": r.doc_name,
|
||||
"chunk_index": r.metadata.get("chunk_index", 0),
|
||||
"content": r.content,
|
||||
"score": r.score,
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
return {
|
||||
"context_text": context_text,
|
||||
"results": results_dict,
|
||||
}
|
||||
|
||||
def _format_context(self, results: List[RetrievalResult]) -> str:
|
||||
"""格式化知识上下文
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
|
||||
Returns:
|
||||
str: 格式化的上下文文本
|
||||
"""
|
||||
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
lines.append(f"【知识 {i}】")
|
||||
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
|
||||
lines.append(f"内容: {result.content}")
|
||||
lines.append(f"相关度: {result.score:.2f}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -0,0 +1,299 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from sqlmodel import col, desc
|
||||
from sqlalchemy import text, func, select, update, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
BaseKBModel,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
|
||||
|
||||
class KBSQLiteDatabase:
|
||||
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
|
||||
"""初始化知识库数据库
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径, 默认为 data/knowledge_base/kb.db
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.inited = False
|
||||
|
||||
# 确保目录存在
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建异步引擎
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
self.async_session = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db(self):
|
||||
"""获取数据库会话
|
||||
|
||||
用法:
|
||||
async with kb_db.get_db() as session:
|
||||
# 执行数据库操作
|
||||
result = await session.execute(stmt)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
yield session
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化数据库,创建表并配置 SQLite 参数"""
|
||||
async with self.engine.begin() as conn:
|
||||
# 创建所有知识库相关表
|
||||
await conn.run_sync(BaseKBModel.metadata.create_all)
|
||||
|
||||
# 配置 SQLite 性能优化参数
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
self.inited = True
|
||||
|
||||
async def migrate_to_v1(self) -> None:
|
||||
"""执行知识库数据库 v1 迁移
|
||||
|
||||
创建所有必要的索引以优化查询性能
|
||||
"""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
# 创建知识库表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
|
||||
"ON knowledge_bases(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_name "
|
||||
"ON knowledge_bases(kb_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
|
||||
"ON knowledge_bases(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建文档表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
|
||||
"ON kb_documents(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
|
||||
"ON kb_documents(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_name "
|
||||
"ON kb_documents(doc_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_type "
|
||||
"ON kb_documents(file_type)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
|
||||
"ON kb_documents(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建多媒体表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
|
||||
"ON kb_media(media_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
|
||||
"ON kb_media(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_type "
|
||||
"ON kb_media(media_type)"
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭数据库连接"""
|
||||
await self.engine.dispose()
|
||||
logger.info(f"知识库数据库已关闭: {self.db_path}")
|
||||
|
||||
async def get_kb_by_id(self, kb_id: str) -> KnowledgeBase | None:
|
||||
"""根据 ID 获取知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_id) == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_kb_by_name(self, kb_name: str) -> KnowledgeBase | None:
|
||||
"""根据名称获取知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(col(KnowledgeBase.kb_name) == kb_name)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库"""
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KnowledgeBase)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(desc(KnowledgeBase.created_at))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_kbs(self) -> int:
|
||||
"""统计知识库数量"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(func.count(col(KnowledgeBase.id)))
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
# ===== 文档查询 =====
|
||||
|
||||
async def get_document_by_id(self, doc_id: str) -> KBDocument | None:
|
||||
"""根据 ID 获取文档"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBDocument).where(col(KBDocument.doc_id) == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_documents_by_kb(
|
||||
self, kb_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument)
|
||||
.where(col(KBDocument.kb_id) == kb_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(desc(KBDocument.created_at))
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def count_documents_by_kb(self, kb_id: str) -> int:
|
||||
"""统计知识库的文档数量"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(func.count(col(KBDocument.id))).where(
|
||||
col(KBDocument.kb_id) == kb_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar() or 0
|
||||
|
||||
async def get_document_with_metadata(self, doc_id: str) -> dict | None:
|
||||
async with self.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument, KnowledgeBase)
|
||||
.join(KnowledgeBase, col(KBDocument.kb_id) == col(KnowledgeBase.kb_id))
|
||||
.where(col(KBDocument.doc_id) == doc_id)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
return {
|
||||
"document": row[0],
|
||||
"knowledge_base": row[1],
|
||||
}
|
||||
|
||||
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
|
||||
"""删除单个文档及其相关数据"""
|
||||
# 在知识库表中删除
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
# 删除文档记录
|
||||
delete_stmt = delete(KBDocument).where(col(KBDocument.doc_id) == doc_id)
|
||||
await session.execute(delete_stmt)
|
||||
await session.commit()
|
||||
|
||||
# 在 vec db 中删除相关向量
|
||||
await vec_db.delete_documents(metadata_filters={"kb_doc_id": doc_id})
|
||||
|
||||
# ===== 多媒体查询 =====
|
||||
|
||||
async def list_media_by_doc(self, doc_id: str) -> list[KBMedia]:
|
||||
"""列出文档的所有多媒体资源"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBMedia).where(col(KBMedia.doc_id) == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_media_by_id(self, media_id: str) -> KBMedia | None:
|
||||
"""根据 ID 获取多媒体资源"""
|
||||
async with self.get_db() as session:
|
||||
stmt = select(KBMedia).where(col(KBMedia.media_id) == media_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
|
||||
"""更新知识库统计信息"""
|
||||
chunk_cnt = await vec_db.count_documents()
|
||||
|
||||
async with self.get_db() as session:
|
||||
async with session.begin():
|
||||
update_stmt = (
|
||||
update(KnowledgeBase)
|
||||
.where(col(KnowledgeBase.kb_id) == kb_id)
|
||||
.values(
|
||||
doc_count=select(func.count(col(KBDocument.id)))
|
||||
.where(col(KBDocument.kb_id) == kb_id)
|
||||
.scalar_subquery(),
|
||||
chunk_count=chunk_cnt,
|
||||
)
|
||||
)
|
||||
|
||||
await session.execute(update_stmt)
|
||||
await session.commit()
|
||||
@@ -0,0 +1,352 @@
|
||||
import uuid
|
||||
import aiofiles
|
||||
import json
|
||||
from pathlib import Path
|
||||
from .models import KnowledgeBase, KBDocument, KBMedia
|
||||
from .kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from .parsers.base import BaseParser
|
||||
from .chunking.base import BaseChunker
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class KBHelper:
|
||||
vec_db: BaseVecDB
|
||||
kb: KnowledgeBase
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kb_db: KBSQLiteDatabase,
|
||||
kb: KnowledgeBase,
|
||||
provider_manager: ProviderManager,
|
||||
kb_root_dir: str,
|
||||
chunker: BaseChunker,
|
||||
parsers: dict[str, BaseParser],
|
||||
):
|
||||
self.kb_db = kb_db
|
||||
self.kb = kb
|
||||
self.prov_mgr = provider_manager
|
||||
self.kb_root_dir = kb_root_dir
|
||||
self.parsers = parsers
|
||||
self.chunker = chunker
|
||||
|
||||
self.kb_dir = Path(self.kb_root_dir) / self.kb.kb_id
|
||||
self.kb_medias_dir = Path(self.kb_dir) / "medias" / self.kb.kb_id
|
||||
self.kb_files_dir = Path(self.kb_dir) / "files" / self.kb.kb_id
|
||||
|
||||
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def initialize(self):
|
||||
await self._ensure_vec_db()
|
||||
|
||||
async def get_ep(self) -> EmbeddingProvider:
|
||||
if not self.kb.embedding_provider_id:
|
||||
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
|
||||
ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
|
||||
self.kb.embedding_provider_id
|
||||
) # type: ignore
|
||||
if not ep:
|
||||
raise ValueError(
|
||||
f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider"
|
||||
)
|
||||
return ep
|
||||
|
||||
async def get_rp(self) -> RerankProvider | None:
|
||||
if not self.kb.rerank_provider_id:
|
||||
return None
|
||||
rp: RerankProvider = await self.prov_mgr.get_provider_by_id(
|
||||
self.kb.rerank_provider_id
|
||||
) # type: ignore
|
||||
if not rp:
|
||||
raise ValueError(
|
||||
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider"
|
||||
)
|
||||
return rp
|
||||
|
||||
async def _ensure_vec_db(self) -> FaissVecDB:
|
||||
if not self.kb.embedding_provider_id:
|
||||
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
|
||||
|
||||
ep = await self.get_ep()
|
||||
rp = await self.get_rp()
|
||||
|
||||
vec_db = FaissVecDB(
|
||||
doc_store_path=str(self.kb_dir / "doc.db"),
|
||||
index_store_path=str(self.kb_dir / "index.faiss"),
|
||||
embedding_provider=ep,
|
||||
rerank_provider=rp,
|
||||
)
|
||||
await vec_db.initialize()
|
||||
self.vec_db = vec_db
|
||||
return vec_db
|
||||
|
||||
async def delete_vec_db(self):
|
||||
await self.terminate()
|
||||
if self.kb_dir.exists():
|
||||
for item in self.kb_dir.iterdir():
|
||||
if item.is_file():
|
||||
item.unlink()
|
||||
self.kb_dir.rmdir()
|
||||
|
||||
async def terminate(self):
|
||||
if self.vec_db:
|
||||
await self.vec_db.close()
|
||||
|
||||
async def upload_document(
|
||||
self,
|
||||
file_name: str,
|
||||
file_content: bytes,
|
||||
file_type: str,
|
||||
chunk_size: int = 512,
|
||||
chunk_overlap: int = 50,
|
||||
batch_size: int = 32,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> KBDocument:
|
||||
"""上传并处理文档(带原子性保证和失败清理)
|
||||
|
||||
流程:
|
||||
1. 保存原始文件
|
||||
2. 解析文档内容
|
||||
3. 提取多媒体资源
|
||||
4. 分块处理
|
||||
5. 生成向量并存储
|
||||
6. 保存元数据(事务)
|
||||
7. 更新统计
|
||||
|
||||
Args:
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total)
|
||||
- stage: 当前阶段 ('parsing', 'chunking', 'embedding')
|
||||
- current: 当前进度
|
||||
- total: 总数
|
||||
"""
|
||||
await self._ensure_vec_db()
|
||||
doc_id = str(uuid.uuid4())
|
||||
media_paths: list[Path] = []
|
||||
|
||||
# file_path = self.kb_files_dir / f"{doc_id}.{file_type}"
|
||||
# async with aiofiles.open(file_path, "wb") as f:
|
||||
# await f.write(file_content)
|
||||
|
||||
try:
|
||||
# 阶段1: 解析文档
|
||||
if progress_callback:
|
||||
await progress_callback("parsing", 0, 100)
|
||||
|
||||
parser = self.parsers.get(file_type)
|
||||
if not parser:
|
||||
raise ValueError(f"不支持的文件类型: {file_type}")
|
||||
parse_result = await parser.parse(file_content, file_name)
|
||||
text_content = parse_result.text
|
||||
media_items = parse_result.media
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("parsing", 100, 100)
|
||||
|
||||
# 保存媒体文件
|
||||
saved_media = []
|
||||
for media_item in media_items:
|
||||
media = await self._save_media(
|
||||
doc_id=doc_id,
|
||||
media_type=media_item.media_type,
|
||||
file_name=media_item.file_name,
|
||||
content=media_item.content,
|
||||
mime_type=media_item.mime_type,
|
||||
)
|
||||
saved_media.append(media)
|
||||
media_paths.append(Path(media.file_path))
|
||||
|
||||
# 阶段2: 分块
|
||||
if progress_callback:
|
||||
await progress_callback("chunking", 0, 100)
|
||||
|
||||
chunks_text = await self.chunker.chunk(
|
||||
text_content, chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
contents = []
|
||||
metadatas = []
|
||||
for idx, chunk_text in enumerate(chunks_text):
|
||||
contents.append(chunk_text)
|
||||
metadatas.append(
|
||||
{
|
||||
"kb_id": self.kb.kb_id,
|
||||
"kb_doc_id": doc_id,
|
||||
"chunk_index": idx,
|
||||
}
|
||||
)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("chunking", 100, 100)
|
||||
|
||||
# 阶段3: 生成向量(带进度回调)
|
||||
async def embedding_progress_callback(current, total):
|
||||
if progress_callback:
|
||||
await progress_callback("embedding", current, total)
|
||||
|
||||
await self.vec_db.insert_batch(
|
||||
contents=contents,
|
||||
metadatas=metadatas,
|
||||
batch_size=batch_size,
|
||||
tasks_limit=tasks_limit,
|
||||
max_retries=max_retries,
|
||||
progress_callback=embedding_progress_callback,
|
||||
)
|
||||
|
||||
# 保存文档的元数据
|
||||
doc = KBDocument(
|
||||
doc_id=doc_id,
|
||||
kb_id=self.kb.kb_id,
|
||||
doc_name=file_name,
|
||||
file_type=file_type,
|
||||
file_size=len(file_content),
|
||||
# file_path=str(file_path),
|
||||
file_path="",
|
||||
chunk_count=len(chunks_text),
|
||||
media_count=0,
|
||||
)
|
||||
async with self.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
session.add(doc)
|
||||
for media in saved_media:
|
||||
session.add(media)
|
||||
await session.commit()
|
||||
|
||||
await session.refresh(doc)
|
||||
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
await self.kb_db.update_kb_stats(kb_id=self.kb.kb_id, vec_db=vec_db)
|
||||
await self.refresh_kb()
|
||||
await self.refresh_document(doc_id)
|
||||
return doc
|
||||
except Exception as e:
|
||||
logger.error(f"上传文档失败: {e}")
|
||||
# if file_path.exists():
|
||||
# file_path.unlink()
|
||||
|
||||
for media_path in media_paths:
|
||||
try:
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as me:
|
||||
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
|
||||
|
||||
raise e
|
||||
|
||||
async def list_documents(
|
||||
self, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
docs = await self.kb_db.list_documents_by_kb(self.kb.kb_id, offset, limit)
|
||||
return docs
|
||||
|
||||
async def get_document(self, doc_id: str) -> KBDocument | None:
|
||||
"""获取单个文档"""
|
||||
doc = await self.kb_db.get_document_by_id(doc_id)
|
||||
return doc
|
||||
|
||||
async def delete_document(self, doc_id: str):
|
||||
"""删除单个文档及其相关数据"""
|
||||
await self.kb_db.delete_document_by_id(
|
||||
doc_id=doc_id,
|
||||
vec_db=self.vec_db, # type: ignore
|
||||
)
|
||||
await self.kb_db.update_kb_stats(
|
||||
kb_id=self.kb.kb_id,
|
||||
vec_db=self.vec_db, # type: ignore
|
||||
)
|
||||
await self.refresh_kb()
|
||||
|
||||
async def delete_chunk(self, chunk_id: str, doc_id: str):
|
||||
"""删除单个文本块及其相关数据"""
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
await vec_db.delete(chunk_id)
|
||||
await self.kb_db.update_kb_stats(
|
||||
kb_id=self.kb.kb_id,
|
||||
vec_db=self.vec_db, # type: ignore
|
||||
)
|
||||
await self.refresh_kb()
|
||||
await self.refresh_document(doc_id)
|
||||
|
||||
async def refresh_kb(self):
|
||||
if self.kb:
|
||||
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
|
||||
if kb:
|
||||
self.kb = kb
|
||||
|
||||
async def refresh_document(self, doc_id: str) -> None:
|
||||
"""更新文档的元数据"""
|
||||
doc = await self.get_document(doc_id)
|
||||
if not doc:
|
||||
raise ValueError(f"无法找到 ID 为 {doc_id} 的文档")
|
||||
chunk_count = await self.get_chunk_count_by_doc_id(doc_id)
|
||||
doc.chunk_count = chunk_count
|
||||
async with self.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
session.add(doc)
|
||||
await session.commit()
|
||||
await session.refresh(doc)
|
||||
|
||||
async def get_chunks_by_doc_id(
|
||||
self, doc_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[dict]:
|
||||
"""获取文档的所有块及其元数据"""
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
chunks = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={"kb_doc_id": doc_id}, offset=offset, limit=limit
|
||||
)
|
||||
result = []
|
||||
for chunk in chunks:
|
||||
chunk_md = json.loads(chunk["metadata"])
|
||||
result.append(
|
||||
{
|
||||
"chunk_id": chunk["doc_id"],
|
||||
"doc_id": chunk_md["kb_doc_id"],
|
||||
"kb_id": chunk_md["kb_id"],
|
||||
"chunk_index": chunk_md["chunk_index"],
|
||||
"content": chunk["text"],
|
||||
"char_count": len(chunk["text"]),
|
||||
}
|
||||
)
|
||||
return result
|
||||
|
||||
async def get_chunk_count_by_doc_id(self, doc_id: str) -> int:
|
||||
"""获取文档的块数量"""
|
||||
vec_db: FaissVecDB = self.vec_db # type: ignore
|
||||
count = await vec_db.count_documents(metadata_filter={"kb_doc_id": doc_id})
|
||||
return count
|
||||
|
||||
async def _save_media(
|
||||
self,
|
||||
doc_id: str,
|
||||
media_type: str,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
mime_type: str,
|
||||
) -> KBMedia:
|
||||
"""保存多媒体资源"""
|
||||
media_id = str(uuid.uuid4())
|
||||
ext = Path(file_name).suffix
|
||||
|
||||
# 保存文件
|
||||
file_path = self.kb_medias_dir / doc_id / f"{media_id}{ext}"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
media = KBMedia(
|
||||
media_id=media_id,
|
||||
doc_id=doc_id,
|
||||
kb_id=self.kb.kb_id,
|
||||
media_type=media_type,
|
||||
file_name=file_name,
|
||||
file_path=str(file_path),
|
||||
file_size=len(content),
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
return media
|
||||
@@ -1,364 +0,0 @@
|
||||
"""
|
||||
知识库管理器
|
||||
负责知识库模块的初始化、配置和资源管理
|
||||
|
||||
架构说明:
|
||||
- 知识库数据存储在独立的数据库 (data/knowledge_base/kb.db)
|
||||
- 会话配置存储在主数据库 (data/astrbot.db) 以便于会话关联
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
|
||||
|
||||
class KnowledgeBaseManager:
|
||||
"""知识库管理器
|
||||
|
||||
职责:
|
||||
- 知识库模块的初始化
|
||||
- Embedding Provider 和 Rerank Provider 的选择
|
||||
- 各个子组件的协调管理
|
||||
- 注册会话删除回调,实现级联清理
|
||||
|
||||
架构说明:
|
||||
- 知识库数据存储在独立数据库 (kb.db)
|
||||
- 会话配置存储在独立数据库 (kb.db),会话ID来自主数据库
|
||||
- 通过回调机制实现与主数据库的生命周期同步
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: dict,
|
||||
main_db: BaseDatabase,
|
||||
provider_manager: ProviderManager,
|
||||
):
|
||||
"""初始化知识库管理器
|
||||
|
||||
Args:
|
||||
config: 配置字典
|
||||
main_db: 主数据库实例 (不直接使用,仅用于类型引用)
|
||||
provider_manager: Provider 管理器
|
||||
"""
|
||||
self.config = config.get("knowledge_base", {})
|
||||
self.provider_manager = provider_manager
|
||||
|
||||
# 知识库独立数据库
|
||||
self.kb_db = None
|
||||
|
||||
# 组件实例
|
||||
self.kb_database = None
|
||||
self.kb_manager = None
|
||||
self.kb_vec_db = None
|
||||
self.retrieval_manager = None
|
||||
self.kb_injector = None
|
||||
|
||||
self._initialized = False
|
||||
self._session_deleted_callback_registered = False
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化知识库模块"""
|
||||
if not self.config.get("enabled", False):
|
||||
logger.info("知识库功能未启用")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("正在初始化知识库模块...")
|
||||
|
||||
# 1. 检查并选择 Embedding Provider
|
||||
embedding_provider = self._select_embedding_provider()
|
||||
if not embedding_provider:
|
||||
logger.warning("未配置 Embedding Provider,知识库功能无法使用")
|
||||
return
|
||||
|
||||
# 2. 初始化数据库
|
||||
await self._init_kb_database()
|
||||
await self._init_database()
|
||||
|
||||
# 3. 初始化向量数据库
|
||||
await self._init_vector_db(embedding_provider)
|
||||
|
||||
# 4. 初始化解析器和分块器
|
||||
parsers = self._init_parsers()
|
||||
chunker = self._init_chunker()
|
||||
|
||||
# 5. 初始化知识库管理器
|
||||
await self._init_kb_manager(parsers, chunker)
|
||||
|
||||
# 6. 初始化检索管理器
|
||||
await self._init_retrieval_manager()
|
||||
|
||||
# 7. 初始化上下文注入器
|
||||
await self._init_injector()
|
||||
|
||||
self._initialized = True
|
||||
logger.info("知识库模块初始化完成")
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"知识库模块导入失败: {e}")
|
||||
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
|
||||
except Exception as e:
|
||||
logger.error(f"知识库模块初始化失败: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _init_kb_database(self):
|
||||
"""初始化知识库独立数据库"""
|
||||
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
|
||||
|
||||
db_path = self.config.get("storage", {}).get(
|
||||
"kb_db_path", "data/knowledge_base/kb.db"
|
||||
)
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.kb_db = KBSQLiteDatabase(db_path)
|
||||
await self.kb_db.initialize()
|
||||
await self.kb_db.migrate_to_v1()
|
||||
|
||||
logger.info(f"知识库独立数据库已初始化: {db_path}")
|
||||
|
||||
async def _init_database(self):
|
||||
"""初始化知识库数据库操作类"""
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
|
||||
self.kb_database = KBDatabase(self.kb_db)
|
||||
|
||||
async def _init_vector_db(self, embedding_provider):
|
||||
"""初始化向量数据库"""
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
|
||||
storage_path = self.config.get("storage", {}).get(
|
||||
"vector_db_path", "data/knowledge_base/vectors"
|
||||
)
|
||||
Path(storage_path).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.kb_vec_db = FaissVecDB(
|
||||
doc_store_path=f"{storage_path}/documents.db",
|
||||
index_store_path=f"{storage_path}/index.faiss",
|
||||
embedding_provider=embedding_provider,
|
||||
)
|
||||
await self.kb_vec_db.initialize()
|
||||
|
||||
def _init_parsers(self) -> dict:
|
||||
"""初始化文档解析器"""
|
||||
from astrbot.core.knowledge_base.parsers.text_parser import TextParser
|
||||
from astrbot.core.knowledge_base.parsers.pdf_parser import PDFParser
|
||||
|
||||
return {
|
||||
"txt": TextParser(),
|
||||
"md": TextParser(),
|
||||
"markdown": TextParser(),
|
||||
"pdf": PDFParser(),
|
||||
}
|
||||
|
||||
def _init_chunker(self):
|
||||
"""初始化分块器"""
|
||||
from astrbot.core.knowledge_base.chunking.fixed_size import FixedSizeChunker
|
||||
|
||||
chunking_config = self.config.get("chunking", {})
|
||||
return FixedSizeChunker(
|
||||
chunk_size=chunking_config.get("chunk_size", 512),
|
||||
chunk_overlap=chunking_config.get("chunk_overlap", 50),
|
||||
)
|
||||
|
||||
async def _init_kb_manager(self, parsers: dict, chunker):
|
||||
"""初始化知识库管理器"""
|
||||
from astrbot.core.knowledge_base.manager import KBManager
|
||||
|
||||
files_path = self.config.get("storage", {}).get(
|
||||
"files_path", "data/knowledge_base"
|
||||
)
|
||||
|
||||
self.kb_manager = KBManager(
|
||||
db=self.kb_db, # 使用独立的知识库数据库
|
||||
vec_db=self.kb_vec_db,
|
||||
storage_path=files_path,
|
||||
parsers=parsers,
|
||||
chunker=chunker,
|
||||
provider_manager=self.provider_manager,
|
||||
)
|
||||
|
||||
async def _init_retrieval_manager(self):
|
||||
"""初始化检索管理器"""
|
||||
from astrbot.core.knowledge_base.retrieval.manager import RetrievalManager
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import (
|
||||
SparseRetriever,
|
||||
)
|
||||
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
||||
|
||||
sparse_retriever = SparseRetriever(self.kb_database)
|
||||
rank_fusion = RankFusion(self.kb_database)
|
||||
|
||||
# 选择 Rerank Provider (可选)
|
||||
rerank_provider = self._select_rerank_provider()
|
||||
|
||||
self.retrieval_manager = RetrievalManager(
|
||||
vec_db=self.kb_vec_db,
|
||||
sparse_retriever=sparse_retriever,
|
||||
rank_fusion=rank_fusion,
|
||||
kb_db=self.kb_database,
|
||||
rerank_provider=rerank_provider,
|
||||
)
|
||||
|
||||
async def _init_injector(self):
|
||||
"""初始化上下文注入器"""
|
||||
from astrbot.core.knowledge_base.injector import KnowledgeBaseInjector
|
||||
|
||||
self.kb_injector = KnowledgeBaseInjector(
|
||||
kb_db=self.kb_database,
|
||||
retrieval_manager=self.retrieval_manager,
|
||||
)
|
||||
|
||||
def _select_embedding_provider(self):
|
||||
"""选择 Embedding Provider
|
||||
|
||||
逻辑:
|
||||
- 如果配置了 embedding_provider_id,则使用指定的 provider
|
||||
- 如果没有配置,但有 embedding provider,则使用第一个
|
||||
- 如果有多个 embedding provider 但没有指定,则警告并使用第一个
|
||||
"""
|
||||
embedding_providers = self.provider_manager.embedding_provider_insts
|
||||
|
||||
if not embedding_providers:
|
||||
return None
|
||||
|
||||
configured_provider_id = self.config.get("embedding_provider_id")
|
||||
|
||||
if configured_provider_id:
|
||||
# 按 ID 查找
|
||||
for provider in embedding_providers:
|
||||
provider_id = provider.meta().id
|
||||
if provider_id == configured_provider_id:
|
||||
logger.info(f"知识库使用 Embedding Provider: {provider_id}")
|
||||
return provider
|
||||
logger.warning(
|
||||
f"未找到配置的 Embedding Provider ID: {configured_provider_id},"
|
||||
f"将使用第一个可用的"
|
||||
)
|
||||
|
||||
if len(embedding_providers) > 1 and not configured_provider_id:
|
||||
provider = embedding_providers[0]
|
||||
provider_id = provider.meta().id
|
||||
logger.info(
|
||||
f"检测到 {len(embedding_providers)} 个 Embedding Provider,"
|
||||
f"未在配置文件中指定 embedding_provider_id,将使用第一个: {provider_id}"
|
||||
)
|
||||
return provider
|
||||
|
||||
provider = embedding_providers[0]
|
||||
provider_id = provider.meta().id
|
||||
logger.info(f"知识库使用 Embedding Provider: {provider_id}")
|
||||
return provider
|
||||
|
||||
def _select_rerank_provider(self):
|
||||
"""选择 Rerank Provider (可选)"""
|
||||
if not self.config.get("retrieval", {}).get("enable_rerank", True):
|
||||
return None
|
||||
|
||||
rerank_providers = self.provider_manager.rerank_provider_insts
|
||||
if not rerank_providers:
|
||||
return None
|
||||
|
||||
configured_provider_id = self.config.get("rerank_provider_id")
|
||||
|
||||
if configured_provider_id:
|
||||
for provider in rerank_providers:
|
||||
provider_id = provider.meta().id
|
||||
if provider_id == configured_provider_id:
|
||||
logger.info(f"知识库使用 Rerank Provider: {provider_id}")
|
||||
return provider
|
||||
logger.warning(f"未找到配置的 Rerank Provider ID: {configured_provider_id}")
|
||||
|
||||
if len(rerank_providers) > 0:
|
||||
provider = rerank_providers[0]
|
||||
provider_id = provider.meta().id
|
||||
logger.info(f"知识库使用 Rerank Provider: {provider_id}")
|
||||
return provider
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""检查是否已初始化"""
|
||||
return self._initialized
|
||||
|
||||
def get_kb_manager(self):
|
||||
"""获取知识库管理器"""
|
||||
return self.kb_manager if self._initialized else None
|
||||
|
||||
def get_kb_injector(self):
|
||||
"""获取知识库上下文注入器"""
|
||||
return self.kb_injector if self._initialized else None
|
||||
|
||||
def register_session_lifecycle_hooks(self, conversation_manager):
|
||||
"""注册会话生命周期钩子
|
||||
|
||||
在会话删除时自动清理知识库配置,实现零侵入的级联清理。
|
||||
|
||||
Args:
|
||||
conversation_manager: 会话管理器实例
|
||||
"""
|
||||
if self._session_deleted_callback_registered or not self._initialized:
|
||||
return
|
||||
|
||||
async def on_session_deleted(session_id: str):
|
||||
"""会话删除回调:清理知识库配置"""
|
||||
try:
|
||||
await self.kb_database.delete_session_kb_config_by_session_id(
|
||||
session_id
|
||||
)
|
||||
logger.info(f"已清理会话知识库配置: {session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"清理会话知识库配置失败 ({session_id}): {e}")
|
||||
|
||||
conversation_manager.register_on_session_deleted(on_session_deleted)
|
||||
self._session_deleted_callback_registered = True
|
||||
logger.info("已注册知识库会话删除回调")
|
||||
|
||||
async def reinitialize(self):
|
||||
"""重新初始化知识库模块
|
||||
|
||||
用于在运行时动态初始化知识库模块(例如用户添加了 embedding provider 后)
|
||||
"""
|
||||
if self._initialized:
|
||||
logger.info("知识库模块已初始化,将重新初始化")
|
||||
await self.terminate()
|
||||
|
||||
await self.initialize()
|
||||
return self._initialized
|
||||
|
||||
async def terminate(self):
|
||||
"""终止知识库模块,清理资源"""
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
logger.info("正在终止知识库模块...")
|
||||
|
||||
# 关闭向量数据库连接
|
||||
if self.kb_vec_db:
|
||||
try:
|
||||
await self.kb_vec_db.close()
|
||||
logger.debug("向量数据库已关闭")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭向量数据库时出错: {e}")
|
||||
|
||||
# 关闭知识库独立数据库连接
|
||||
if self.kb_db:
|
||||
try:
|
||||
await self.kb_db.close()
|
||||
logger.debug("知识库数据库已关闭")
|
||||
except Exception as e:
|
||||
logger.warning(f"关闭知识库数据库时出错: {e}")
|
||||
|
||||
# 清理资源
|
||||
self._initialized = False
|
||||
self.kb_db = None
|
||||
self.kb_database = None
|
||||
self.kb_manager = None
|
||||
self.kb_vec_db = None
|
||||
self.retrieval_manager = None
|
||||
self.kb_injector = None
|
||||
|
||||
logger.info("知识库模块已终止")
|
||||
@@ -0,0 +1,279 @@
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
|
||||
from .retrieval.manager import RetrievalManager, RetrievalResult
|
||||
from .retrieval.sparse_retriever import SparseRetriever
|
||||
from .retrieval.rank_fusion import RankFusion
|
||||
from .kb_db_sqlite import KBSQLiteDatabase
|
||||
|
||||
from .parsers.text_parser import TextParser
|
||||
from .parsers.pdf_parser import PDFParser
|
||||
from .chunking.fixed_size import FixedSizeChunker
|
||||
from .kb_helper import KBHelper
|
||||
|
||||
from .models import KnowledgeBase
|
||||
|
||||
|
||||
FILES_PATH = "data/knowledge_base"
|
||||
DB_PATH = Path(FILES_PATH) / "kb.db"
|
||||
"""Knowledge Base storage root directory"""
|
||||
PARSERS = {
|
||||
"txt": TextParser(),
|
||||
"md": TextParser(),
|
||||
"markdown": TextParser(),
|
||||
"pdf": PDFParser(),
|
||||
}
|
||||
CHUNKER = FixedSizeChunker()
|
||||
|
||||
|
||||
class KnowledgeBaseManager:
|
||||
kb_db: KBSQLiteDatabase
|
||||
retrieval_manager: RetrievalManager
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_manager: ProviderManager,
|
||||
):
|
||||
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
|
||||
self.provider_manager = provider_manager
|
||||
self._session_deleted_callback_registered = False
|
||||
|
||||
self.kb_insts: dict[str, KBHelper] = {}
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化知识库模块"""
|
||||
try:
|
||||
logger.info("正在初始化知识库模块...")
|
||||
|
||||
# 初始化数据库
|
||||
await self._init_kb_database()
|
||||
|
||||
# 初始化检索管理器
|
||||
sparse_retriever = SparseRetriever(self.kb_db)
|
||||
rank_fusion = RankFusion(self.kb_db)
|
||||
self.retrieval_manager = RetrievalManager(
|
||||
sparse_retriever=sparse_retriever,
|
||||
rank_fusion=rank_fusion,
|
||||
kb_db=self.kb_db,
|
||||
)
|
||||
await self.load_kbs()
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"知识库模块导入失败: {e}")
|
||||
logger.warning("请确保已安装所需依赖: pypdf, aiofiles, Pillow, rank-bm25")
|
||||
except Exception as e:
|
||||
logger.error(f"知识库模块初始化失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _init_kb_database(self):
|
||||
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
|
||||
await self.kb_db.initialize()
|
||||
await self.kb_db.migrate_to_v1()
|
||||
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
|
||||
|
||||
async def load_kbs(self):
|
||||
"""加载所有知识库实例"""
|
||||
kb_records = await self.kb_db.list_kbs()
|
||||
for record in kb_records:
|
||||
kb_helper = KBHelper(
|
||||
kb_db=self.kb_db,
|
||||
kb=record,
|
||||
provider_manager=self.provider_manager,
|
||||
kb_root_dir=FILES_PATH,
|
||||
chunker=CHUNKER,
|
||||
parsers=PARSERS,
|
||||
)
|
||||
await kb_helper.initialize()
|
||||
self.kb_insts[record.kb_id] = kb_helper
|
||||
|
||||
async def create_kb(
|
||||
self,
|
||||
kb_name: str,
|
||||
description: str | None = None,
|
||||
emoji: str | None = None,
|
||||
embedding_provider_id: str | None = None,
|
||||
rerank_provider_id: str | None = None,
|
||||
chunk_size: int | None = None,
|
||||
chunk_overlap: int | None = None,
|
||||
top_k_dense: int | None = None,
|
||||
top_k_sparse: int | None = None,
|
||||
top_m_final: int | None = None,
|
||||
) -> KBHelper:
|
||||
"""创建新的知识库实例"""
|
||||
kb = KnowledgeBase(
|
||||
kb_name=kb_name,
|
||||
description=description,
|
||||
emoji=emoji or "📚",
|
||||
embedding_provider_id=embedding_provider_id,
|
||||
rerank_provider_id=rerank_provider_id,
|
||||
chunk_size=chunk_size if chunk_size is not None else 512,
|
||||
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
|
||||
top_k_dense=top_k_dense if top_k_dense is not None else 50,
|
||||
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
|
||||
top_m_final=top_m_final if top_m_final is not None else 5,
|
||||
)
|
||||
async with self.kb_db.get_db() as session:
|
||||
session.add(kb)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
|
||||
kb_helper = KBHelper(
|
||||
kb_db=self.kb_db,
|
||||
kb=kb,
|
||||
provider_manager=self.provider_manager,
|
||||
kb_root_dir=FILES_PATH,
|
||||
chunker=CHUNKER,
|
||||
parsers=PARSERS,
|
||||
)
|
||||
await kb_helper.initialize()
|
||||
self.kb_insts[kb.kb_id] = kb_helper
|
||||
return kb_helper
|
||||
|
||||
async def get_kb(self, kb_id: str) -> KBHelper | None:
|
||||
"""获取知识库实例"""
|
||||
if kb_id in self.kb_insts:
|
||||
return self.kb_insts[kb_id]
|
||||
|
||||
async def get_kb_by_name(self, kb_name: str) -> KBHelper | None:
|
||||
"""通过名称获取知识库实例"""
|
||||
for kb_helper in self.kb_insts.values():
|
||||
if kb_helper.kb.kb_name == kb_name:
|
||||
return kb_helper
|
||||
return None
|
||||
|
||||
async def delete_kb(self, kb_id: str) -> bool:
|
||||
"""删除知识库实例"""
|
||||
kb_helper = await self.get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
return False
|
||||
|
||||
await kb_helper.delete_vec_db()
|
||||
async with self.kb_db.get_db() as session:
|
||||
await session.delete(kb_helper.kb)
|
||||
await session.commit()
|
||||
|
||||
self.kb_insts.pop(kb_id, None)
|
||||
return True
|
||||
|
||||
async def list_kbs(self) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库实例"""
|
||||
kbs = [kb_helper.kb for kb_helper in self.kb_insts.values()]
|
||||
return kbs
|
||||
|
||||
async def update_kb(
|
||||
self,
|
||||
kb_id: str,
|
||||
kb_name: str,
|
||||
description: str | None = None,
|
||||
emoji: str | None = None,
|
||||
embedding_provider_id: str | None = None,
|
||||
rerank_provider_id: str | None = None,
|
||||
chunk_size: int | None = None,
|
||||
chunk_overlap: int | None = None,
|
||||
top_k_dense: int | None = None,
|
||||
top_k_sparse: int | None = None,
|
||||
top_m_final: int | None = None,
|
||||
) -> KBHelper | None:
|
||||
"""更新知识库实例"""
|
||||
kb_helper = await self.get_kb(kb_id)
|
||||
if not kb_helper:
|
||||
return None
|
||||
|
||||
kb = kb_helper.kb
|
||||
if kb_name is not None:
|
||||
kb.kb_name = kb_name
|
||||
if description is not None:
|
||||
kb.description = description
|
||||
if emoji is not None:
|
||||
kb.emoji = emoji
|
||||
if embedding_provider_id is not None:
|
||||
kb.embedding_provider_id = embedding_provider_id
|
||||
kb.rerank_provider_id = rerank_provider_id # 允许设置为 None
|
||||
if chunk_size is not None:
|
||||
kb.chunk_size = chunk_size
|
||||
if chunk_overlap is not None:
|
||||
kb.chunk_overlap = chunk_overlap
|
||||
if top_k_dense is not None:
|
||||
kb.top_k_dense = top_k_dense
|
||||
if top_k_sparse is not None:
|
||||
kb.top_k_sparse = top_k_sparse
|
||||
if top_m_final is not None:
|
||||
kb.top_m_final = top_m_final
|
||||
async with self.kb_db.get_db() as session:
|
||||
session.add(kb)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
|
||||
return kb_helper
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_names: list[str],
|
||||
top_k_fusion: int = 20,
|
||||
top_m_final: int = 5,
|
||||
) -> dict | None:
|
||||
"""从指定知识库中检索相关内容"""
|
||||
kb_ids = []
|
||||
kb_id_helper_map = {}
|
||||
for kb_name in kb_names:
|
||||
if kb_helper := await self.get_kb_by_name(kb_name):
|
||||
kb_ids.append(kb_helper.kb.kb_id)
|
||||
kb_id_helper_map[kb_helper.kb.kb_id] = kb_helper
|
||||
|
||||
if not kb_ids:
|
||||
return {}
|
||||
|
||||
results = await self.retrieval_manager.retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
kb_id_helper_map=kb_id_helper_map,
|
||||
top_k_fusion=top_k_fusion,
|
||||
top_m_final=top_m_final,
|
||||
)
|
||||
if not results:
|
||||
return None
|
||||
|
||||
context_text = self._format_context(results)
|
||||
|
||||
results_dict = [
|
||||
{
|
||||
"chunk_id": r.chunk_id,
|
||||
"doc_id": r.doc_id,
|
||||
"kb_id": r.kb_id,
|
||||
"kb_name": r.kb_name,
|
||||
"doc_name": r.doc_name,
|
||||
"chunk_index": r.metadata.get("chunk_index", 0),
|
||||
"content": r.content,
|
||||
"score": r.score,
|
||||
"char_count": r.metadata.get("char_count", 0),
|
||||
}
|
||||
for r in results
|
||||
]
|
||||
|
||||
return {
|
||||
"context_text": context_text,
|
||||
"results": results_dict,
|
||||
}
|
||||
|
||||
def _format_context(self, results: list[RetrievalResult]) -> str:
|
||||
"""格式化知识上下文
|
||||
|
||||
Args:
|
||||
results: 检索结果列表
|
||||
|
||||
Returns:
|
||||
str: 格式化的上下文文本
|
||||
"""
|
||||
lines = ["以下是相关的知识库内容,请参考这些信息回答用户的问题:\n"]
|
||||
|
||||
for i, result in enumerate(results, 1):
|
||||
lines.append(f"【知识 {i}】")
|
||||
lines.append(f"来源: {result.kb_name} / {result.doc_name}")
|
||||
lines.append(f"内容: {result.content}")
|
||||
lines.append(f"相关度: {result.score:.2f}")
|
||||
lines.append("")
|
||||
|
||||
return "\n".join(lines)
|
||||
@@ -1,230 +0,0 @@
|
||||
"""
|
||||
知识库独立 SQLite 数据库
|
||||
|
||||
该模块提供知识库专用的独立 SQLite 数据库,与主数据库 (astrbot.db) 完全隔离。
|
||||
职责:
|
||||
- 管理知识库相关表 (knowledge_bases, kb_documents, kb_chunks, kb_media)
|
||||
- 提供数据库连接和会话管理
|
||||
- 执行数据库迁移和初始化
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class KBSQLiteDatabase:
|
||||
"""知识库独立 SQLite 数据库
|
||||
|
||||
与主数据库 (astrbot.db) 完全隔离的独立数据库,专门用于存储知识库数据。
|
||||
|
||||
特点:
|
||||
- 数据隔离: 知识库数据不会影响主数据库格式
|
||||
- 独立备份: 可以单独备份和恢复知识库数据
|
||||
- 性能隔离: 大量知识库查询不会影响主业务性能
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str = "data/knowledge_base/kb.db") -> None:
|
||||
"""初始化知识库数据库
|
||||
|
||||
Args:
|
||||
db_path: 数据库文件路径,默认为 data/knowledge_base/kb.db
|
||||
"""
|
||||
self.db_path = db_path
|
||||
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
|
||||
self.inited = False
|
||||
|
||||
# 确保目录存在
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 创建异步引擎
|
||||
self.engine = create_async_engine(
|
||||
self.DATABASE_URL,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=3600,
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
self.async_session = async_sessionmaker(
|
||||
self.engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db(self):
|
||||
"""获取数据库会话
|
||||
|
||||
用法:
|
||||
async with kb_db.get_db() as session:
|
||||
# 执行数据库操作
|
||||
result = await session.execute(stmt)
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
yield session
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""初始化数据库,创建表并配置 SQLite 参数"""
|
||||
# noqa: F401 - 这些导入是必需的,用于触发 SQLModel 创建对应的数据库表
|
||||
from astrbot.core.knowledge_base.models import ( # noqa: F401
|
||||
KBChunk,
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KBSessionConfig,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
async with self.engine.begin() as conn:
|
||||
# 创建所有知识库相关表
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
# 配置 SQLite 性能优化参数
|
||||
await conn.execute(text("PRAGMA journal_mode=WAL"))
|
||||
await conn.execute(text("PRAGMA synchronous=NORMAL"))
|
||||
await conn.execute(text("PRAGMA cache_size=20000"))
|
||||
await conn.execute(text("PRAGMA temp_store=MEMORY"))
|
||||
await conn.execute(text("PRAGMA mmap_size=134217728"))
|
||||
await conn.execute(text("PRAGMA optimize"))
|
||||
await conn.commit()
|
||||
|
||||
self.inited = True
|
||||
logger.info(f"知识库数据库已初始化: {self.db_path}")
|
||||
|
||||
async def migrate_to_v1(self) -> None:
|
||||
"""执行知识库数据库 v1 迁移
|
||||
|
||||
创建所有必要的索引以优化查询性能
|
||||
"""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
# 创建知识库表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_kb_id "
|
||||
"ON knowledge_bases(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_name "
|
||||
"ON knowledge_bases(kb_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_kb_created_at "
|
||||
"ON knowledge_bases(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建文档表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_doc_id "
|
||||
"ON kb_documents(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_kb_id "
|
||||
"ON kb_documents(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_name "
|
||||
"ON kb_documents(doc_name)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_type "
|
||||
"ON kb_documents(file_type)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_doc_created_at "
|
||||
"ON kb_documents(created_at)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建块表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_chunk_id "
|
||||
"ON kb_chunks(chunk_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_doc_id "
|
||||
"ON kb_chunks(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_kb_id ON kb_chunks(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_chunk_vec_doc_id "
|
||||
"ON kb_chunks(vec_doc_id)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建多媒体表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_media_id "
|
||||
"ON kb_media(media_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_doc_id "
|
||||
"ON kb_media(doc_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_kb_id ON kb_media(kb_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_media_type "
|
||||
"ON kb_media(media_type)"
|
||||
)
|
||||
)
|
||||
|
||||
# 创建会话配置表索引
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_config_scope_id "
|
||||
"ON kb_session_config(scope_id)"
|
||||
)
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS idx_session_config_scope "
|
||||
"ON kb_session_config(scope)"
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
logger.info("知识库数据库迁移 v1 完成")
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭数据库连接"""
|
||||
await self.engine.dispose()
|
||||
logger.info(f"知识库数据库已关闭: {self.db_path}")
|
||||
@@ -1,375 +0,0 @@
|
||||
"""知识库管理器
|
||||
|
||||
该模块提供知识库的CRUD操作和文档上传处理流程。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
from sqlalchemy import func, select, update
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.knowledge_base.chunking.base import BaseChunker
|
||||
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KnowledgeBase
|
||||
from astrbot.core.knowledge_base.parsers.base import BaseParser
|
||||
|
||||
|
||||
class KBManager:
|
||||
"""知识库管理器
|
||||
|
||||
职责:
|
||||
- 知识库的 CRUD 操作
|
||||
- 文档上传与解析
|
||||
- 文档块生成与存储
|
||||
- 多媒体资源管理
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db: BaseDatabase,
|
||||
vec_db: BaseVecDB,
|
||||
storage_path: str,
|
||||
parsers: dict[str, BaseParser],
|
||||
chunker: BaseChunker,
|
||||
provider_manager=None,
|
||||
):
|
||||
self.db = db
|
||||
self.vec_db = vec_db
|
||||
self.storage_path = Path(storage_path)
|
||||
self.media_path = self.storage_path / "media"
|
||||
self.files_path = self.storage_path / "files"
|
||||
self.parsers = parsers
|
||||
self.chunker = chunker
|
||||
self.provider_manager = provider_manager
|
||||
|
||||
# 确保目录存在
|
||||
self.media_path.mkdir(parents=True, exist_ok=True)
|
||||
self.files_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ===== 知识库操作 =====
|
||||
|
||||
async def create_kb(
|
||||
self,
|
||||
kb_name: str,
|
||||
description: Optional[str] = None,
|
||||
emoji: Optional[str] = None,
|
||||
embedding_provider_id: Optional[str] = None,
|
||||
rerank_provider_id: Optional[str] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
chunk_overlap: Optional[int] = None,
|
||||
top_k_dense: Optional[int] = None,
|
||||
top_k_sparse: Optional[int] = None,
|
||||
top_m_final: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> KnowledgeBase:
|
||||
"""创建知识库
|
||||
|
||||
Args:
|
||||
enable_rerank: 是否启用重排序。
|
||||
- 如果明确传入 True/False,则使用该值
|
||||
- 如果为 None,则根据是否有可用的 rerank provider 自动决定
|
||||
"""
|
||||
# 智能决定 enable_rerank 的默认值
|
||||
if enable_rerank is None:
|
||||
# 检查是否有可用的 rerank provider
|
||||
has_rerank_provider = (
|
||||
self.provider_manager
|
||||
and hasattr(self.provider_manager, 'rerank_provider_insts')
|
||||
and len(self.provider_manager.rerank_provider_insts) > 0
|
||||
)
|
||||
enable_rerank = has_rerank_provider
|
||||
|
||||
kb = KnowledgeBase(
|
||||
kb_name=kb_name,
|
||||
description=description,
|
||||
emoji=emoji or "📚",
|
||||
embedding_provider_id=embedding_provider_id,
|
||||
rerank_provider_id=rerank_provider_id,
|
||||
chunk_size=chunk_size if chunk_size is not None else 512,
|
||||
chunk_overlap=chunk_overlap if chunk_overlap is not None else 50,
|
||||
top_k_dense=top_k_dense if top_k_dense is not None else 50,
|
||||
top_k_sparse=top_k_sparse if top_k_sparse is not None else 50,
|
||||
top_m_final=top_m_final if top_m_final is not None else 5,
|
||||
enable_rerank=enable_rerank,
|
||||
)
|
||||
async with self.db.get_db() as session:
|
||||
session.add(kb)
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
return kb
|
||||
|
||||
async def get_kb(self, kb_id: str) -> Optional[KnowledgeBase]:
|
||||
"""获取知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_kbs(self, offset: int = 0, limit: int = 100) -> list[KnowledgeBase]:
|
||||
"""列出所有知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KnowledgeBase)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KnowledgeBase.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def update_kb(
|
||||
self,
|
||||
kb_id: str,
|
||||
kb_name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
emoji: Optional[str] = None,
|
||||
embedding_provider_id: Optional[str] = None,
|
||||
rerank_provider_id: Optional[str] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
chunk_overlap: Optional[int] = None,
|
||||
top_k_dense: Optional[int] = None,
|
||||
top_k_sparse: Optional[int] = None,
|
||||
top_m_final: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> Optional[KnowledgeBase]:
|
||||
"""更新知识库"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
kb = result.scalar_one_or_none()
|
||||
if not kb:
|
||||
return None
|
||||
|
||||
if kb_name is not None:
|
||||
kb.kb_name = kb_name
|
||||
if description is not None:
|
||||
kb.description = description
|
||||
if emoji is not None:
|
||||
kb.emoji = emoji
|
||||
if embedding_provider_id is not None:
|
||||
kb.embedding_provider_id = embedding_provider_id
|
||||
if rerank_provider_id is not None:
|
||||
kb.rerank_provider_id = rerank_provider_id
|
||||
if chunk_size is not None:
|
||||
kb.chunk_size = chunk_size
|
||||
if chunk_overlap is not None:
|
||||
kb.chunk_overlap = chunk_overlap
|
||||
if top_k_dense is not None:
|
||||
kb.top_k_dense = top_k_dense
|
||||
if top_k_sparse is not None:
|
||||
kb.top_k_sparse = top_k_sparse
|
||||
if top_m_final is not None:
|
||||
kb.top_m_final = top_m_final
|
||||
if enable_rerank is not None:
|
||||
kb.enable_rerank = enable_rerank
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(kb)
|
||||
return kb
|
||||
|
||||
async def delete_kb(self, kb_id: str) -> bool:
|
||||
"""删除知识库(级联删除所有文档和资源)"""
|
||||
# 1. 获取所有文档
|
||||
from astrbot.core.knowledge_base.manager_ops import KBManagerOps
|
||||
|
||||
ops = KBManagerOps(self)
|
||||
docs = await ops.list_documents(kb_id)
|
||||
|
||||
# 2. 删除所有文档(包括文件和向量)
|
||||
for doc in docs:
|
||||
await ops.delete_document(doc.doc_id)
|
||||
|
||||
# 3. 删除知识库记录
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KnowledgeBase).where(KnowledgeBase.kb_id == kb_id)
|
||||
result = await session.execute(stmt)
|
||||
kb = result.scalar_one_or_none()
|
||||
if not kb:
|
||||
return False
|
||||
|
||||
await session.delete(kb)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
# ===== 文档上传 =====
|
||||
|
||||
async def upload_document(
|
||||
self,
|
||||
kb_id: str,
|
||||
file_name: str,
|
||||
file_content: bytes,
|
||||
file_type: str,
|
||||
) -> KBDocument:
|
||||
"""上传并处理文档(带原子性保证和失败清理)
|
||||
|
||||
流程:
|
||||
1. 保存原始文件
|
||||
2. 解析文档内容
|
||||
3. 提取多媒体资源
|
||||
4. 分块处理
|
||||
5. 生成向量并存储
|
||||
6. 保存元数据(事务)
|
||||
7. 更新统计
|
||||
"""
|
||||
doc_id = str(uuid.uuid4())
|
||||
file_path = None
|
||||
media_paths = []
|
||||
vec_doc_ids = []
|
||||
|
||||
try:
|
||||
# 1. 保存原始文件
|
||||
file_path = self.files_path / kb_id / f"{doc_id}.{file_type}"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(file_content)
|
||||
|
||||
# 2. 解析文档
|
||||
parser = self.parsers.get(file_type)
|
||||
if not parser:
|
||||
raise ValueError(f"不支持的文件类型: {file_type}")
|
||||
|
||||
parse_result = await parser.parse(file_content, file_name)
|
||||
text_content = parse_result.text
|
||||
media_items = parse_result.media
|
||||
|
||||
# 3. 保存多媒体资源
|
||||
from astrbot.core.knowledge_base.manager_ops import KBManagerOps
|
||||
|
||||
ops = KBManagerOps(self)
|
||||
saved_media = []
|
||||
for media_item in media_items:
|
||||
media = await ops._save_media(
|
||||
kb_id=kb_id,
|
||||
doc_id=doc_id,
|
||||
media_type=media_item.media_type,
|
||||
file_name=media_item.file_name,
|
||||
content=media_item.content,
|
||||
mime_type=media_item.mime_type,
|
||||
)
|
||||
saved_media.append(media)
|
||||
media_paths.append(Path(media.file_path))
|
||||
|
||||
# 4. 文档分块
|
||||
chunks_text = await self.chunker.chunk(text_content)
|
||||
|
||||
# 5. 生成向量并存储
|
||||
saved_chunks = []
|
||||
for idx, chunk_text in enumerate(chunks_text):
|
||||
# 存储到向量数据库
|
||||
vec_doc_id = await self.vec_db.insert(
|
||||
content=chunk_text,
|
||||
metadata={
|
||||
"kb_id": kb_id,
|
||||
"doc_id": doc_id,
|
||||
"chunk_index": idx,
|
||||
},
|
||||
)
|
||||
vec_doc_ids.append(str(vec_doc_id))
|
||||
|
||||
# 保存块元数据
|
||||
chunk = KBChunk(
|
||||
doc_id=doc_id,
|
||||
kb_id=kb_id,
|
||||
chunk_index=idx,
|
||||
content=chunk_text,
|
||||
char_count=len(chunk_text),
|
||||
vec_doc_id=str(vec_doc_id),
|
||||
)
|
||||
saved_chunks.append(chunk)
|
||||
|
||||
# 6. 保存文档元数据(事务)
|
||||
doc = KBDocument(
|
||||
doc_id=doc_id,
|
||||
kb_id=kb_id,
|
||||
doc_name=file_name,
|
||||
file_type=file_type,
|
||||
file_size=len(file_content),
|
||||
file_path=str(file_path),
|
||||
chunk_count=len(saved_chunks),
|
||||
media_count=len(saved_media),
|
||||
)
|
||||
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
session.add(doc)
|
||||
for chunk in saved_chunks:
|
||||
session.add(chunk)
|
||||
for media in saved_media:
|
||||
session.add(media)
|
||||
await session.commit()
|
||||
|
||||
await session.refresh(doc)
|
||||
|
||||
# 7. 更新知识库统计
|
||||
await self._update_kb_stats(kb_id)
|
||||
|
||||
return doc
|
||||
|
||||
except Exception as e:
|
||||
# 失败清理:删除已创建的资源
|
||||
from astrbot.core import logger
|
||||
|
||||
logger.error(f"文档上传失败,开始清理资源: {e}")
|
||||
|
||||
# 清理向量数据库
|
||||
for vec_id in vec_doc_ids:
|
||||
try:
|
||||
await self.vec_db.delete(vec_id)
|
||||
except Exception as ve:
|
||||
logger.warning(f"清理向量失败 {vec_id}: {ve}")
|
||||
|
||||
# 清理多媒体文件
|
||||
for media_path in media_paths:
|
||||
try:
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as me:
|
||||
logger.warning(f"清理多媒体文件失败 {media_path}: {me}")
|
||||
|
||||
# 清理文档文件
|
||||
if file_path and file_path.exists():
|
||||
try:
|
||||
file_path.unlink()
|
||||
except Exception as fe:
|
||||
logger.warning(f"清理文档文件失败 {file_path}: {fe}")
|
||||
|
||||
# 重新抛出原始异常
|
||||
raise
|
||||
|
||||
# ===== 统计更新 =====
|
||||
|
||||
async def _update_kb_stats(self, kb_id: str):
|
||||
"""更新知识库统计信息(事务中执行)"""
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
# 统计文档数(在事务中查询)
|
||||
doc_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBDocument.id)).where(
|
||||
KBDocument.kb_id == kb_id
|
||||
)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
# 统计块数(在事务中查询)
|
||||
chunk_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBChunk.id)).where(KBChunk.kb_id == kb_id)
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
# 更新知识库(在同一事务中)
|
||||
await session.execute(
|
||||
update(KnowledgeBase)
|
||||
.where(KnowledgeBase.kb_id == kb_id)
|
||||
.values(doc_count=doc_count, chunk_count=chunk_count)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
@@ -1,312 +0,0 @@
|
||||
"""知识库管理器辅助操作
|
||||
|
||||
该模块提供文档、块和多媒体的管理操作。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiofiles
|
||||
from sqlalchemy import delete, func, select
|
||||
|
||||
from astrbot.core.knowledge_base.models import KBChunk, KBDocument, KBMedia
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.manager import KBManager
|
||||
|
||||
|
||||
class KBManagerOps:
|
||||
"""知识库管理器辅助操作类
|
||||
|
||||
职责:
|
||||
- 文档管理操作
|
||||
- 块管理操作
|
||||
- 多媒体管理操作
|
||||
"""
|
||||
|
||||
def __init__(self, manager: "KBManager"):
|
||||
self.manager = manager
|
||||
self.db = manager.db
|
||||
self.vec_db = manager.vec_db
|
||||
self.media_path = manager.media_path
|
||||
self.files_path = manager.files_path
|
||||
|
||||
# ===== 文档操作 =====
|
||||
|
||||
async def list_documents(
|
||||
self, kb_id: str, offset: int = 0, limit: int = 100
|
||||
) -> list[KBDocument]:
|
||||
"""列出知识库的所有文档"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBDocument)
|
||||
.where(KBDocument.kb_id == kb_id)
|
||||
.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBDocument.created_at.desc())
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_document(self, doc_id: str) -> KBDocument | None:
|
||||
"""获取文档详情"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def delete_document(self, doc_id: str) -> bool:
|
||||
"""删除文档(级联删除块、多媒体、向量)
|
||||
|
||||
采用三阶段删除策略:
|
||||
1. 删除向量数据库中的向量(允许部分失败)
|
||||
2. 删除SQL数据库中的记录(事务保证原子性)
|
||||
3. 删除文件系统中的文件(失败不影响数据一致性)
|
||||
"""
|
||||
from astrbot.core import logger
|
||||
|
||||
# 0. 获取文档信息
|
||||
doc = await self.get_document(doc_id)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# 收集所有需要删除的资源
|
||||
chunks = await self.list_chunks(doc_id)
|
||||
media_list = await self.list_media(doc_id)
|
||||
|
||||
# ===== 第一阶段: 删除向量(可重试) =====
|
||||
vec_ids_to_delete = [chunk.vec_doc_id for chunk in chunks]
|
||||
deleted_vec_ids = []
|
||||
failed_vec_ids = []
|
||||
|
||||
for vec_id in vec_ids_to_delete:
|
||||
try:
|
||||
await self.vec_db.delete(vec_id)
|
||||
deleted_vec_ids.append(vec_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除向量失败: {vec_id}, {e}")
|
||||
failed_vec_ids.append(vec_id)
|
||||
|
||||
# 如果向量删除失败过多(超过50%),中止操作
|
||||
if len(failed_vec_ids) > len(vec_ids_to_delete) * 0.5:
|
||||
logger.error(
|
||||
f"向量删除失败过多 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 中止文档删除"
|
||||
)
|
||||
return False
|
||||
|
||||
# 记录部分失败但继续执行
|
||||
if failed_vec_ids:
|
||||
logger.warning(
|
||||
f"部分向量删除失败 ({len(failed_vec_ids)}/{len(vec_ids_to_delete)}), 但继续执行删除操作"
|
||||
)
|
||||
|
||||
# ===== 第二阶段: 删除数据库记录(事务) =====
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
# 删除块记录
|
||||
await session.execute(delete(KBChunk).where(KBChunk.doc_id == doc_id))
|
||||
|
||||
# 删除多媒体记录
|
||||
await session.execute(delete(KBMedia).where(KBMedia.doc_id == doc_id))
|
||||
|
||||
# 删除文档记录
|
||||
await session.execute(
|
||||
delete(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# ===== 第三阶段: 删除文件(失败不影响) =====
|
||||
# 删除多媒体文件
|
||||
for media in media_list:
|
||||
try:
|
||||
media_path = Path(media.file_path)
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"删除多媒体文件失败: {media.file_path}, {e}")
|
||||
|
||||
# 删除文档文件
|
||||
try:
|
||||
file_path = Path(doc.file_path)
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"删除文档文件失败: {doc.file_path}, {e}")
|
||||
|
||||
# ===== 更新统计 =====
|
||||
await self.manager._update_kb_stats(doc.kb_id)
|
||||
|
||||
return True
|
||||
|
||||
# ===== 块操作 =====
|
||||
|
||||
async def list_chunks(self, doc_id: str) -> list[KBChunk]:
|
||||
"""列出文档的所有块"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = (
|
||||
select(KBChunk)
|
||||
.where(KBChunk.doc_id == doc_id)
|
||||
.order_by(KBChunk.chunk_index)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_chunk(self, chunk_id: str) -> bool:
|
||||
"""删除单个块
|
||||
|
||||
流程:
|
||||
1. 查询块信息
|
||||
2. 删除向量
|
||||
3. 删除数据库记录
|
||||
4. 更新文档统计
|
||||
"""
|
||||
from astrbot.core import logger
|
||||
|
||||
# 1. 查询块信息
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBChunk).where(KBChunk.chunk_id == chunk_id)
|
||||
result = await session.execute(stmt)
|
||||
chunk = result.scalar_one_or_none()
|
||||
if not chunk:
|
||||
return False
|
||||
|
||||
doc_id = chunk.doc_id
|
||||
vec_doc_id = chunk.vec_doc_id
|
||||
|
||||
# 2. 删除向量
|
||||
try:
|
||||
await self.vec_db.delete(vec_doc_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除向量失败: {vec_doc_id}, {e}")
|
||||
return False
|
||||
|
||||
# 3. 删除数据库记录
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(KBChunk).where(KBChunk.chunk_id == chunk_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# 4. 更新文档统计
|
||||
await self._update_doc_stats(doc_id)
|
||||
|
||||
return True
|
||||
|
||||
# ===== 多媒体操作 =====
|
||||
|
||||
async def list_media(self, doc_id: str) -> list[KBMedia]:
|
||||
"""列出文档的所有多媒体资源"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.doc_id == doc_id)
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def delete_media(self, media_id: str) -> bool:
|
||||
"""删除多媒体资源
|
||||
|
||||
流程:
|
||||
1. 查询媒体信息
|
||||
2. 删除数据库记录
|
||||
3. 删除文件(失败不影响)
|
||||
4. 更新文档统计
|
||||
"""
|
||||
from astrbot.core import logger
|
||||
|
||||
# 1. 查询媒体信息
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBMedia).where(KBMedia.media_id == media_id)
|
||||
result = await session.execute(stmt)
|
||||
media = result.scalar_one_or_none()
|
||||
if not media:
|
||||
return False
|
||||
|
||||
doc_id = media.doc_id
|
||||
file_path_str = media.file_path
|
||||
|
||||
# 2. 删除数据库记录
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(KBMedia).where(KBMedia.media_id == media_id)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# 3. 删除文件(失败不影响)
|
||||
try:
|
||||
media_path = Path(file_path_str)
|
||||
if media_path.exists():
|
||||
media_path.unlink()
|
||||
except Exception as e:
|
||||
logger.warning(f"删除多媒体文件失败: {file_path_str}, {e}")
|
||||
|
||||
# 4. 更新文档统计
|
||||
await self._update_doc_stats(doc_id)
|
||||
|
||||
return True
|
||||
|
||||
# ===== 内部辅助方法 =====
|
||||
|
||||
async def _save_media(
|
||||
self,
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
media_type: str,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
mime_type: str,
|
||||
) -> KBMedia:
|
||||
"""保存多媒体资源"""
|
||||
media_id = str(uuid.uuid4())
|
||||
ext = Path(file_name).suffix
|
||||
|
||||
# 保存文件
|
||||
file_path = self.media_path / kb_id / doc_id / f"{media_id}{ext}"
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(content)
|
||||
|
||||
# 创建记录
|
||||
media = KBMedia(
|
||||
media_id=media_id,
|
||||
doc_id=doc_id,
|
||||
kb_id=kb_id,
|
||||
media_type=media_type,
|
||||
file_name=file_name,
|
||||
file_path=str(file_path),
|
||||
file_size=len(content),
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
return media
|
||||
|
||||
async def _update_doc_stats(self, doc_id: str):
|
||||
"""更新文档统计信息(事务中执行)"""
|
||||
async with self.db.get_db() as session:
|
||||
async with session.begin():
|
||||
# 统计块数
|
||||
chunk_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBChunk.id)).where(KBChunk.doc_id == doc_id)
|
||||
)
|
||||
) or 0
|
||||
|
||||
# 统计多媒体数
|
||||
media_count = (
|
||||
await session.scalar(
|
||||
select(func.count(KBMedia.id)).where(KBMedia.doc_id == doc_id)
|
||||
)
|
||||
) or 0
|
||||
|
||||
# 更新文档
|
||||
doc = await session.scalar(
|
||||
select(KBDocument).where(KBDocument.doc_id == doc_id)
|
||||
)
|
||||
if doc:
|
||||
doc.chunk_count = chunk_count
|
||||
doc.media_count = media_count
|
||||
|
||||
await session.commit()
|
||||
@@ -1,31 +1,20 @@
|
||||
"""知识库管理功能的数据模型定义
|
||||
|
||||
该模块定义了知识库系统所需的数据模型,包括:
|
||||
- KnowledgeBase: 知识库表 (存储在独立的 kb.db)
|
||||
- KBDocument: 文档表 (存储在独立的 kb.db)
|
||||
- KBChunk: 文档块表 (存储在独立的 kb.db)
|
||||
- KBMedia: 多媒体资源表 (存储在独立的 kb.db)
|
||||
- KBSessionConfig: 会话配置表 (存储在独立的 kb.db)
|
||||
|
||||
注意:
|
||||
- 所有模型存储在独立的知识库数据库 (data/knowledge_base/kb.db)
|
||||
- 与主数据库 (astrbot.db) 完全解耦
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import Field, SQLModel, Text, UniqueConstraint
|
||||
from sqlmodel import Field, SQLModel, Text, UniqueConstraint, MetaData
|
||||
|
||||
|
||||
class KnowledgeBase(SQLModel, table=True):
|
||||
class BaseKBModel(SQLModel, table=False):
|
||||
metadata = MetaData()
|
||||
|
||||
|
||||
class KnowledgeBase(BaseKBModel, table=True):
|
||||
"""知识库表
|
||||
|
||||
存储知识库的基本信息和统计数据。
|
||||
"""
|
||||
|
||||
__tablename__ = "knowledge_bases"
|
||||
__tablename__ = "knowledge_bases" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
@@ -38,18 +27,17 @@ class KnowledgeBase(SQLModel, table=True):
|
||||
index=True,
|
||||
)
|
||||
kb_name: str = Field(max_length=100, nullable=False)
|
||||
description: Optional[str] = Field(default=None, sa_type=Text)
|
||||
emoji: Optional[str] = Field(default="📚", max_length=10)
|
||||
embedding_provider_id: Optional[str] = Field(default=None, max_length=100)
|
||||
rerank_provider_id: Optional[str] = Field(default=None, max_length=100)
|
||||
description: str | None = Field(default=None, sa_type=Text)
|
||||
emoji: str | None = Field(default="📚", max_length=10)
|
||||
embedding_provider_id: str | None = Field(default=None, max_length=100)
|
||||
rerank_provider_id: str | None = Field(default=None, max_length=100)
|
||||
# 分块配置参数
|
||||
chunk_size: Optional[int] = Field(default=512, nullable=True)
|
||||
chunk_overlap: Optional[int] = Field(default=50, nullable=True)
|
||||
chunk_size: int | None = Field(default=512, nullable=True)
|
||||
chunk_overlap: int | None = Field(default=50, nullable=True)
|
||||
# 检索配置参数
|
||||
top_k_dense: Optional[int] = Field(default=50, nullable=True)
|
||||
top_k_sparse: Optional[int] = Field(default=50, nullable=True)
|
||||
top_m_final: Optional[int] = Field(default=5, nullable=True)
|
||||
enable_rerank: Optional[bool] = Field(default=False, nullable=True)
|
||||
top_k_dense: int | None = Field(default=50, nullable=True)
|
||||
top_k_sparse: int | None = Field(default=50, nullable=True)
|
||||
top_m_final: int | None = Field(default=5, nullable=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
@@ -58,14 +46,21 @@ class KnowledgeBase(SQLModel, table=True):
|
||||
doc_count: int = Field(default=0, nullable=False)
|
||||
chunk_count: int = Field(default=0, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"kb_name",
|
||||
name="uix_kb_name",
|
||||
),
|
||||
)
|
||||
|
||||
class KBDocument(SQLModel, table=True):
|
||||
|
||||
class KBDocument(BaseKBModel, table=True):
|
||||
"""文档表
|
||||
|
||||
存储上传到知识库的文档元数据。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_documents"
|
||||
__tablename__ = "kb_documents" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
@@ -91,40 +86,13 @@ class KBDocument(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class KBChunk(SQLModel, table=True):
|
||||
"""文档块表
|
||||
|
||||
存储文档分块后的文本内容和向量索引关联信息。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_chunks"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
chunk_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
index=True,
|
||||
)
|
||||
doc_id: str = Field(max_length=36, nullable=False, index=True)
|
||||
kb_id: str = Field(max_length=36, nullable=False, index=True)
|
||||
chunk_index: int = Field(nullable=False)
|
||||
content: str = Field(sa_type=Text, nullable=False)
|
||||
char_count: int = Field(nullable=False)
|
||||
vec_doc_id: str = Field(max_length=100, nullable=False, index=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class KBMedia(SQLModel, table=True):
|
||||
class KBMedia(BaseKBModel, table=True):
|
||||
"""多媒体资源表
|
||||
|
||||
存储从文档中提取的图片、视频等多媒体资源。
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_media"
|
||||
__tablename__ = "kb_media" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
@@ -144,39 +112,3 @@ class KBMedia(SQLModel, table=True):
|
||||
file_size: int = Field(nullable=False)
|
||||
mime_type: str = Field(max_length=100, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class KBSessionConfig(SQLModel, table=True):
|
||||
"""会话知识库配置表
|
||||
|
||||
存储会话或平台级别的知识库关联配置。
|
||||
该表存储在知识库独立数据库中,保持完全解耦。
|
||||
|
||||
支持两种配置范围:
|
||||
- platform: 平台级别配置 (如 'qq', 'telegram')
|
||||
- session: 会话级别配置 (如 'qq:group:12345')
|
||||
"""
|
||||
|
||||
__tablename__ = "kb_session_config"
|
||||
|
||||
id: int | None = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}, default=None
|
||||
)
|
||||
config_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
scope: str = Field(max_length=20, nullable=False)
|
||||
scope_id: str = Field(max_length=255, nullable=False, index=True)
|
||||
kb_ids: str = Field(sa_type=Text, nullable=False)
|
||||
top_k: Optional[int] = Field(default=None, nullable=True)
|
||||
enable_rerank: Optional[bool] = Field(default=None, nullable=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (UniqueConstraint("scope", "scope_id", name="uix_scope_scope_id"),)
|
||||
|
||||
@@ -51,10 +51,10 @@ class PDFParser(BaseParser):
|
||||
continue
|
||||
|
||||
resources = page["/Resources"]
|
||||
if not resources or "/XObject" not in resources:
|
||||
if not resources or "/XObject" not in resources: # type: ignore
|
||||
continue
|
||||
|
||||
xobjects = resources["/XObject"].get_object()
|
||||
xobjects = resources["/XObject"].get_object() # type: ignore
|
||||
if not xobjects:
|
||||
continue
|
||||
|
||||
|
||||
@@ -0,0 +1,767 @@
|
||||
———
|
||||
》),
|
||||
)÷(1-
|
||||
”,
|
||||
)、
|
||||
=(
|
||||
:
|
||||
→
|
||||
℃
|
||||
&
|
||||
*
|
||||
一一
|
||||
~~~~
|
||||
’
|
||||
.
|
||||
『
|
||||
.一
|
||||
./
|
||||
--
|
||||
』
|
||||
=″
|
||||
【
|
||||
[*]
|
||||
}>
|
||||
[⑤]]
|
||||
[①D]
|
||||
c]
|
||||
ng昉
|
||||
*
|
||||
//
|
||||
[
|
||||
]
|
||||
[②e]
|
||||
[②g]
|
||||
={
|
||||
}
|
||||
,也
|
||||
‘
|
||||
A
|
||||
[①⑥]
|
||||
[②B]
|
||||
[①a]
|
||||
[④a]
|
||||
[①③]
|
||||
[③h]
|
||||
③]
|
||||
1.
|
||||
--
|
||||
[②b]
|
||||
’‘
|
||||
×××
|
||||
[①⑧]
|
||||
0:2
|
||||
=[
|
||||
[⑤b]
|
||||
[②c]
|
||||
[④b]
|
||||
[②③]
|
||||
[③a]
|
||||
[④c]
|
||||
[①⑤]
|
||||
[①⑦]
|
||||
[①g]
|
||||
∈[
|
||||
[①⑨]
|
||||
[①④]
|
||||
[①c]
|
||||
[②f]
|
||||
[②⑧]
|
||||
[②①]
|
||||
[①C]
|
||||
[③c]
|
||||
[③g]
|
||||
[②⑤]
|
||||
[②②]
|
||||
一.
|
||||
[①h]
|
||||
.数
|
||||
[]
|
||||
[①B]
|
||||
数/
|
||||
[①i]
|
||||
[③e]
|
||||
[①①]
|
||||
[④d]
|
||||
[④e]
|
||||
[③b]
|
||||
[⑤a]
|
||||
[①A]
|
||||
[②⑧]
|
||||
[②⑦]
|
||||
[①d]
|
||||
[②j]
|
||||
〕〔
|
||||
][
|
||||
://
|
||||
′∈
|
||||
[②④
|
||||
[⑤e]
|
||||
12%
|
||||
b]
|
||||
...
|
||||
...................
|
||||
…………………………………………………③
|
||||
ZXFITL
|
||||
[③F]
|
||||
」
|
||||
[①o]
|
||||
]∧′=[
|
||||
∪φ∈
|
||||
′|
|
||||
{-
|
||||
②c
|
||||
}
|
||||
[③①]
|
||||
R.L.
|
||||
[①E]
|
||||
Ψ
|
||||
-[*]-
|
||||
↑
|
||||
.日
|
||||
[②d]
|
||||
[②
|
||||
[②⑦]
|
||||
[②②]
|
||||
[③e]
|
||||
[①i]
|
||||
[①B]
|
||||
[①h]
|
||||
[①d]
|
||||
[①g]
|
||||
[①②]
|
||||
[②a]
|
||||
f]
|
||||
[⑩]
|
||||
a]
|
||||
[①e]
|
||||
[②h]
|
||||
[②⑥]
|
||||
[③d]
|
||||
[②⑩]
|
||||
e]
|
||||
〉
|
||||
】
|
||||
元/吨
|
||||
[②⑩]
|
||||
2.3%
|
||||
5:0
|
||||
[①]
|
||||
::
|
||||
[②]
|
||||
[③]
|
||||
[④]
|
||||
[⑤]
|
||||
[⑥]
|
||||
[⑦]
|
||||
[⑧]
|
||||
[⑨]
|
||||
……
|
||||
——
|
||||
?
|
||||
、
|
||||
。
|
||||
“
|
||||
”
|
||||
《
|
||||
》
|
||||
!
|
||||
,
|
||||
:
|
||||
;
|
||||
?
|
||||
.
|
||||
,
|
||||
.
|
||||
'
|
||||
?
|
||||
·
|
||||
———
|
||||
──
|
||||
?
|
||||
—
|
||||
<
|
||||
>
|
||||
(
|
||||
)
|
||||
〔
|
||||
〕
|
||||
[
|
||||
]
|
||||
(
|
||||
)
|
||||
-
|
||||
+
|
||||
~
|
||||
×
|
||||
/
|
||||
/
|
||||
①
|
||||
②
|
||||
③
|
||||
④
|
||||
⑤
|
||||
⑥
|
||||
⑦
|
||||
⑧
|
||||
⑨
|
||||
⑩
|
||||
Ⅲ
|
||||
В
|
||||
"
|
||||
;
|
||||
#
|
||||
@
|
||||
γ
|
||||
μ
|
||||
φ
|
||||
φ.
|
||||
×
|
||||
Δ
|
||||
■
|
||||
▲
|
||||
sub
|
||||
exp
|
||||
sup
|
||||
sub
|
||||
Lex
|
||||
#
|
||||
%
|
||||
&
|
||||
'
|
||||
+
|
||||
+ξ
|
||||
++
|
||||
-
|
||||
-β
|
||||
<
|
||||
<±
|
||||
<Δ
|
||||
<λ
|
||||
<φ
|
||||
<<
|
||||
=
|
||||
=
|
||||
=☆
|
||||
=-
|
||||
>
|
||||
>λ
|
||||
_
|
||||
~±
|
||||
~+
|
||||
[⑤f]
|
||||
[⑤d]
|
||||
[②i]
|
||||
≈
|
||||
[②G]
|
||||
[①f]
|
||||
LI
|
||||
㈧
|
||||
[-
|
||||
......
|
||||
〉
|
||||
[③⑩]
|
||||
第二
|
||||
一番
|
||||
一直
|
||||
一个
|
||||
一些
|
||||
许多
|
||||
种
|
||||
有的是
|
||||
也就是说
|
||||
末##末
|
||||
啊
|
||||
阿
|
||||
哎
|
||||
哎呀
|
||||
哎哟
|
||||
唉
|
||||
俺
|
||||
俺们
|
||||
按
|
||||
按照
|
||||
吧
|
||||
吧哒
|
||||
把
|
||||
罢了
|
||||
被
|
||||
本
|
||||
本着
|
||||
比
|
||||
比方
|
||||
比如
|
||||
鄙人
|
||||
彼
|
||||
彼此
|
||||
边
|
||||
别
|
||||
别的
|
||||
别说
|
||||
并
|
||||
并且
|
||||
不比
|
||||
不成
|
||||
不单
|
||||
不但
|
||||
不独
|
||||
不管
|
||||
不光
|
||||
不过
|
||||
不仅
|
||||
不拘
|
||||
不论
|
||||
不怕
|
||||
不然
|
||||
不如
|
||||
不特
|
||||
不惟
|
||||
不问
|
||||
不只
|
||||
朝
|
||||
朝着
|
||||
趁
|
||||
趁着
|
||||
乘
|
||||
冲
|
||||
除
|
||||
除此之外
|
||||
除非
|
||||
除了
|
||||
此
|
||||
此间
|
||||
此外
|
||||
从
|
||||
从而
|
||||
打
|
||||
待
|
||||
但
|
||||
但是
|
||||
当
|
||||
当着
|
||||
到
|
||||
得
|
||||
的
|
||||
的话
|
||||
等
|
||||
等等
|
||||
地
|
||||
第
|
||||
叮咚
|
||||
对
|
||||
对于
|
||||
多
|
||||
多少
|
||||
而
|
||||
而况
|
||||
而且
|
||||
而是
|
||||
而外
|
||||
而言
|
||||
而已
|
||||
尔后
|
||||
反过来
|
||||
反过来说
|
||||
反之
|
||||
非但
|
||||
非徒
|
||||
否则
|
||||
嘎
|
||||
嘎登
|
||||
该
|
||||
赶
|
||||
个
|
||||
各
|
||||
各个
|
||||
各位
|
||||
各种
|
||||
各自
|
||||
给
|
||||
根据
|
||||
跟
|
||||
故
|
||||
故此
|
||||
固然
|
||||
关于
|
||||
管
|
||||
归
|
||||
果然
|
||||
果真
|
||||
过
|
||||
哈
|
||||
哈哈
|
||||
呵
|
||||
和
|
||||
何
|
||||
何处
|
||||
何况
|
||||
何时
|
||||
嘿
|
||||
哼
|
||||
哼唷
|
||||
呼哧
|
||||
乎
|
||||
哗
|
||||
还是
|
||||
还有
|
||||
换句话说
|
||||
换言之
|
||||
或
|
||||
或是
|
||||
或者
|
||||
极了
|
||||
及
|
||||
及其
|
||||
及至
|
||||
即
|
||||
即便
|
||||
即或
|
||||
即令
|
||||
即若
|
||||
即使
|
||||
几
|
||||
几时
|
||||
己
|
||||
既
|
||||
既然
|
||||
既是
|
||||
继而
|
||||
加之
|
||||
假如
|
||||
假若
|
||||
假使
|
||||
鉴于
|
||||
将
|
||||
较
|
||||
较之
|
||||
叫
|
||||
接着
|
||||
结果
|
||||
借
|
||||
紧接着
|
||||
进而
|
||||
尽
|
||||
尽管
|
||||
经
|
||||
经过
|
||||
就
|
||||
就是
|
||||
就是说
|
||||
据
|
||||
具体地说
|
||||
具体说来
|
||||
开始
|
||||
开外
|
||||
靠
|
||||
咳
|
||||
可
|
||||
可见
|
||||
可是
|
||||
可以
|
||||
况且
|
||||
啦
|
||||
来
|
||||
来着
|
||||
离
|
||||
例如
|
||||
哩
|
||||
连
|
||||
连同
|
||||
两者
|
||||
了
|
||||
临
|
||||
另
|
||||
另外
|
||||
另一方面
|
||||
论
|
||||
嘛
|
||||
吗
|
||||
慢说
|
||||
漫说
|
||||
冒
|
||||
么
|
||||
每
|
||||
每当
|
||||
们
|
||||
莫若
|
||||
某
|
||||
某个
|
||||
某些
|
||||
拿
|
||||
哪
|
||||
哪边
|
||||
哪儿
|
||||
哪个
|
||||
哪里
|
||||
哪年
|
||||
哪怕
|
||||
哪天
|
||||
哪些
|
||||
哪样
|
||||
那
|
||||
那边
|
||||
那儿
|
||||
那个
|
||||
那会儿
|
||||
那里
|
||||
那么
|
||||
那么些
|
||||
那么样
|
||||
那时
|
||||
那些
|
||||
那样
|
||||
乃
|
||||
乃至
|
||||
呢
|
||||
能
|
||||
你
|
||||
你们
|
||||
您
|
||||
宁
|
||||
宁可
|
||||
宁肯
|
||||
宁愿
|
||||
哦
|
||||
呕
|
||||
啪达
|
||||
旁人
|
||||
呸
|
||||
凭
|
||||
凭借
|
||||
其
|
||||
其次
|
||||
其二
|
||||
其他
|
||||
其它
|
||||
其一
|
||||
其余
|
||||
其中
|
||||
起
|
||||
起见
|
||||
起见
|
||||
岂但
|
||||
恰恰相反
|
||||
前后
|
||||
前者
|
||||
且
|
||||
然而
|
||||
然后
|
||||
然则
|
||||
让
|
||||
人家
|
||||
任
|
||||
任何
|
||||
任凭
|
||||
如
|
||||
如此
|
||||
如果
|
||||
如何
|
||||
如其
|
||||
如若
|
||||
如上所述
|
||||
若
|
||||
若非
|
||||
若是
|
||||
啥
|
||||
上下
|
||||
尚且
|
||||
设若
|
||||
设使
|
||||
甚而
|
||||
甚么
|
||||
甚至
|
||||
省得
|
||||
时候
|
||||
什么
|
||||
什么样
|
||||
使得
|
||||
是
|
||||
是的
|
||||
首先
|
||||
谁
|
||||
谁知
|
||||
顺
|
||||
顺着
|
||||
似的
|
||||
虽
|
||||
虽然
|
||||
虽说
|
||||
虽则
|
||||
随
|
||||
随着
|
||||
所
|
||||
所以
|
||||
他
|
||||
他们
|
||||
他人
|
||||
它
|
||||
它们
|
||||
她
|
||||
她们
|
||||
倘
|
||||
倘或
|
||||
倘然
|
||||
倘若
|
||||
倘使
|
||||
腾
|
||||
替
|
||||
通过
|
||||
同
|
||||
同时
|
||||
哇
|
||||
万一
|
||||
往
|
||||
望
|
||||
为
|
||||
为何
|
||||
为了
|
||||
为什么
|
||||
为着
|
||||
喂
|
||||
嗡嗡
|
||||
我
|
||||
我们
|
||||
呜
|
||||
呜呼
|
||||
乌乎
|
||||
无论
|
||||
无宁
|
||||
毋宁
|
||||
嘻
|
||||
吓
|
||||
相对而言
|
||||
像
|
||||
向
|
||||
向着
|
||||
嘘
|
||||
呀
|
||||
焉
|
||||
沿
|
||||
沿着
|
||||
要
|
||||
要不
|
||||
要不然
|
||||
要不是
|
||||
要么
|
||||
要是
|
||||
也
|
||||
也罢
|
||||
也好
|
||||
一
|
||||
一般
|
||||
一旦
|
||||
一方面
|
||||
一来
|
||||
一切
|
||||
一样
|
||||
一则
|
||||
依
|
||||
依照
|
||||
矣
|
||||
以
|
||||
以便
|
||||
以及
|
||||
以免
|
||||
以至
|
||||
以至于
|
||||
以致
|
||||
抑或
|
||||
因
|
||||
因此
|
||||
因而
|
||||
因为
|
||||
哟
|
||||
用
|
||||
由
|
||||
由此可见
|
||||
由于
|
||||
有
|
||||
有的
|
||||
有关
|
||||
有些
|
||||
又
|
||||
于
|
||||
于是
|
||||
于是乎
|
||||
与
|
||||
与此同时
|
||||
与否
|
||||
与其
|
||||
越是
|
||||
云云
|
||||
哉
|
||||
再说
|
||||
再者
|
||||
在
|
||||
在下
|
||||
咱
|
||||
咱们
|
||||
则
|
||||
怎
|
||||
怎么
|
||||
怎么办
|
||||
怎么样
|
||||
怎样
|
||||
咋
|
||||
照
|
||||
照着
|
||||
者
|
||||
这
|
||||
这边
|
||||
这儿
|
||||
这个
|
||||
这会儿
|
||||
这就是说
|
||||
这里
|
||||
这么
|
||||
这么点儿
|
||||
这么些
|
||||
这么样
|
||||
这时
|
||||
这些
|
||||
这样
|
||||
正如
|
||||
吱
|
||||
之
|
||||
之类
|
||||
之所以
|
||||
之一
|
||||
只是
|
||||
只限
|
||||
只要
|
||||
只有
|
||||
至
|
||||
至于
|
||||
诸位
|
||||
着
|
||||
着呢
|
||||
自
|
||||
自从
|
||||
自个儿
|
||||
自各儿
|
||||
自己
|
||||
自家
|
||||
自身
|
||||
综上所述
|
||||
总的来看
|
||||
总的来说
|
||||
总的说来
|
||||
总而言之
|
||||
总之
|
||||
纵
|
||||
纵令
|
||||
纵然
|
||||
纵使
|
||||
遵照
|
||||
作为
|
||||
兮
|
||||
呃
|
||||
呗
|
||||
咚
|
||||
咦
|
||||
喏
|
||||
啐
|
||||
喔唷
|
||||
嗬
|
||||
嗯
|
||||
嗳
|
||||
@@ -3,15 +3,19 @@
|
||||
协调稠密检索、稀疏检索和 Rerank,提供统一的检索接口
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
import time
|
||||
|
||||
from astrbot.core.db.vec_db.base import BaseVecDB
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
from astrbot.core.db.vec_db.base import Result
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
from ..kb_helper import KBHelper
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,36 +42,29 @@ class RetrievalManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vec_db: BaseVecDB,
|
||||
sparse_retriever: SparseRetriever,
|
||||
rank_fusion: RankFusion,
|
||||
kb_db: KBDatabase,
|
||||
rerank_provider: Optional[RerankProvider] = None,
|
||||
kb_db: KBSQLiteDatabase,
|
||||
):
|
||||
"""初始化检索管理器
|
||||
|
||||
Args:
|
||||
vec_db: 向量数据库实例
|
||||
vec_db_factory: 向量数据库工厂
|
||||
sparse_retriever: 稀疏检索器
|
||||
rank_fusion: 结果融合器
|
||||
kb_db: 知识库数据库实例
|
||||
rerank_provider: Rerank 提供商 (可选)
|
||||
"""
|
||||
self.vec_db = vec_db
|
||||
self.sparse_retriever = sparse_retriever
|
||||
self.rank_fusion = rank_fusion
|
||||
self.kb_db = kb_db
|
||||
self.rerank_provider = rerank_provider
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k_dense: int = 50,
|
||||
top_k_sparse: int = 50,
|
||||
top_n_fusion: int = 20,
|
||||
kb_id_helper_map: dict[str, KBHelper],
|
||||
top_k_fusion: int = 20,
|
||||
top_m_final: int = 5,
|
||||
enable_rerank: bool = True,
|
||||
) -> List[RetrievalResult]:
|
||||
"""混合检索
|
||||
|
||||
@@ -80,40 +77,74 @@ class RetrievalManager:
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k_dense: 稠密检索返回数量
|
||||
top_k_sparse: 稀疏检索返回数量
|
||||
top_n_fusion: 融合后返回数量
|
||||
top_m_final: 最终返回数量
|
||||
enable_rerank: 是否启用 Rerank
|
||||
|
||||
Returns:
|
||||
List[RetrievalResult]: 检索结果列表
|
||||
"""
|
||||
if not kb_ids:
|
||||
return []
|
||||
|
||||
kb_options: dict = {}
|
||||
new_kb_ids = []
|
||||
for kb_id in kb_ids:
|
||||
kb_helper = kb_id_helper_map.get(kb_id)
|
||||
if kb_helper:
|
||||
kb = kb_helper.kb
|
||||
kb_options[kb_id] = {
|
||||
"top_k_dense": kb.top_k_dense or 50,
|
||||
"top_k_sparse": kb.top_k_sparse or 50,
|
||||
"top_m_final": kb.top_m_final or 5,
|
||||
"vec_db": kb_helper.vec_db,
|
||||
"rerank_provider_id": kb.rerank_provider_id,
|
||||
}
|
||||
new_kb_ids.append(kb_id)
|
||||
else:
|
||||
logger.warning(f"知识库 ID {kb_id} 实例未找到, 已跳过该知识库的检索")
|
||||
|
||||
kb_ids = new_kb_ids
|
||||
|
||||
# 1. 稠密检索
|
||||
time_start = time.time()
|
||||
dense_results = await self._dense_retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_k=top_k_dense,
|
||||
kb_options=kb_options,
|
||||
)
|
||||
time_end = time.time()
|
||||
logger.debug(
|
||||
f"Dense retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(dense_results)} results."
|
||||
)
|
||||
|
||||
# 2. 稀疏检索
|
||||
time_start = time.time()
|
||||
sparse_results = await self.sparse_retriever.retrieve(
|
||||
query=query,
|
||||
kb_ids=kb_ids,
|
||||
top_k=top_k_sparse,
|
||||
kb_options=kb_options,
|
||||
)
|
||||
time_end = time.time()
|
||||
logger.debug(
|
||||
f"Sparse retrieval across {len(kb_ids)} bases took {time_end - time_start:.2f}s and returned {len(sparse_results)} results."
|
||||
)
|
||||
|
||||
# 3. 结果融合
|
||||
time_start = time.time()
|
||||
fused_results = await self.rank_fusion.fuse(
|
||||
dense_results=dense_results,
|
||||
sparse_results=sparse_results,
|
||||
top_k=top_n_fusion,
|
||||
top_k=top_k_fusion,
|
||||
)
|
||||
time_end = time.time()
|
||||
logger.debug(
|
||||
f"Rank fusion took {time_end - time_start:.2f}s and returned {len(fused_results)} results."
|
||||
)
|
||||
|
||||
# 4. 转换为 RetrievalResult (获取元数据)
|
||||
retrieval_results = []
|
||||
for fr in fused_results:
|
||||
metadata_dict = await self.kb_db.get_chunk_with_metadata(fr.chunk_id)
|
||||
metadata_dict = await self.kb_db.get_document_with_metadata(fr.doc_id)
|
||||
if metadata_dict:
|
||||
retrieval_results.append(
|
||||
RetrievalResult(
|
||||
@@ -125,33 +156,45 @@ class RetrievalManager:
|
||||
content=fr.content,
|
||||
score=fr.score,
|
||||
metadata={
|
||||
"chunk_index": metadata_dict["chunk"].chunk_index,
|
||||
"char_count": metadata_dict["chunk"].char_count,
|
||||
"chunk_index": fr.chunk_index,
|
||||
"char_count": len(fr.content),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# 5. Rerank (可选)
|
||||
if enable_rerank and self.rerank_provider and retrieval_results:
|
||||
# 5. Rerank
|
||||
first_rerank = None
|
||||
for kb_id in kb_ids:
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
|
||||
if (
|
||||
vec_db
|
||||
and vec_db.rerank_provider
|
||||
and rerank_pi
|
||||
and rerank_pi == vec_db.rerank_provider.meta().id
|
||||
):
|
||||
first_rerank = vec_db.rerank_provider
|
||||
break
|
||||
if first_rerank and retrieval_results:
|
||||
retrieval_results = await self._rerank(
|
||||
query=query,
|
||||
results=retrieval_results,
|
||||
top_k=top_m_final,
|
||||
rerank_provider=self.rerank_provider,
|
||||
rerank_provider=first_rerank,
|
||||
)
|
||||
else:
|
||||
retrieval_results = retrieval_results[:top_m_final]
|
||||
|
||||
return retrieval_results
|
||||
return retrieval_results[:top_m_final]
|
||||
|
||||
async def _dense_retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k: int,
|
||||
kb_options: dict,
|
||||
):
|
||||
"""稠密检索 (向量相似度)
|
||||
|
||||
为每个知识库使用独立的向量数据库进行检索,然后合并结果。
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
@@ -160,28 +203,32 @@ class RetrievalManager:
|
||||
Returns:
|
||||
List[Result]: 检索结果列表
|
||||
"""
|
||||
# 直接调用向量数据库检索
|
||||
vec_results = await self.vec_db.retrieve(
|
||||
query=query,
|
||||
top_k=top_k * len(kb_ids) * 2, # 增加候选数量以便过滤
|
||||
)
|
||||
|
||||
# 过滤:只保留指定知识库的结果
|
||||
filtered_results = []
|
||||
for result in vec_results:
|
||||
metadata_str = result.data.get("metadata", "{}")
|
||||
all_results: list[Result] = []
|
||||
for kb_id in kb_ids:
|
||||
if kb_id not in kb_options:
|
||||
continue
|
||||
try:
|
||||
metadata = json.loads(metadata_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
metadata = {}
|
||||
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
|
||||
dense_k = int(kb_options[kb_id]["top_k_dense"])
|
||||
vec_results = await vec_db.retrieve(
|
||||
query=query,
|
||||
k=dense_k,
|
||||
fetch_k=dense_k * 2,
|
||||
rerank=False, # 稠密检索阶段不进行 rerank
|
||||
metadata_filters={"kb_id": kb_id},
|
||||
)
|
||||
|
||||
if metadata.get("kb_id") in kb_ids:
|
||||
filtered_results.append(result)
|
||||
all_results.extend(vec_results)
|
||||
except Exception as e:
|
||||
from astrbot.core import logger
|
||||
|
||||
if len(filtered_results) >= top_k:
|
||||
break
|
||||
logger.warning(f"知识库 {kb_id} 稠密检索失败: {e}")
|
||||
continue
|
||||
|
||||
return filtered_results[:top_k]
|
||||
# 按相似度排序并返回 top_k
|
||||
all_results.sort(key=lambda x: x.similarity, reverse=True)
|
||||
# return all_results[: len(all_results) // len(kb_ids)]
|
||||
return all_results
|
||||
|
||||
async def _rerank(
|
||||
self,
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
使用 Reciprocal Rank Fusion (RRF) 算法融合稠密检索和稀疏检索的结果
|
||||
"""
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
from astrbot.core.db.vec_db.base import Result
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseResult
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ class FusedResult:
|
||||
"""融合后的检索结果"""
|
||||
|
||||
chunk_id: str
|
||||
chunk_index: int
|
||||
doc_id: str
|
||||
kb_id: str
|
||||
content: str
|
||||
@@ -30,7 +31,7 @@ class RankFusion:
|
||||
- 使用 Reciprocal Rank Fusion (RRF) 算法
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBDatabase, k: int = 60):
|
||||
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
|
||||
"""初始化结果融合器
|
||||
|
||||
Args:
|
||||
@@ -42,10 +43,10 @@ class RankFusion:
|
||||
|
||||
async def fuse(
|
||||
self,
|
||||
dense_results: List[Result],
|
||||
sparse_results: List[SparseResult],
|
||||
dense_results: list[Result],
|
||||
sparse_results: list[SparseResult],
|
||||
top_k: int = 20,
|
||||
) -> List[FusedResult]:
|
||||
) -> list[FusedResult]:
|
||||
"""融合稠密和稀疏检索结果
|
||||
|
||||
RRF 公式:
|
||||
@@ -62,14 +63,14 @@ class RankFusion:
|
||||
# 1. 构建排名映射
|
||||
dense_ranks = {
|
||||
r.data["doc_id"]: (idx + 1) for idx, r in enumerate(dense_results)
|
||||
}
|
||||
} # 这里的 doc_id 实际上是 chunk_id
|
||||
sparse_ranks = {r.chunk_id: (idx + 1) for idx, r in enumerate(sparse_results)}
|
||||
|
||||
# 2. 收集所有唯一的 ID (来自稠密检索的是 vec_doc_id, 稀疏检索的是 chunk_id)
|
||||
# 2. 收集所有唯一的 ID
|
||||
# 需要统一为 chunk_id
|
||||
all_chunk_ids = set()
|
||||
vec_doc_id_to_dense = {} # vec_doc_id -> Result
|
||||
chunk_id_to_sparse = {} # chunk_id -> SparseResult
|
||||
vec_doc_id_to_dense: dict[str, Result] = {} # vec_doc_id -> Result
|
||||
chunk_id_to_sparse: dict[str, SparseResult] = {} # chunk_id -> SparseResult
|
||||
|
||||
# 处理稀疏检索结果
|
||||
for r in sparse_results:
|
||||
@@ -83,7 +84,7 @@ class RankFusion:
|
||||
vec_doc_id_to_dense[vec_doc_id] = r
|
||||
|
||||
# 3. 计算 RRF 分数
|
||||
rrf_scores: Dict[str, float] = {}
|
||||
rrf_scores: dict[str, float] = {}
|
||||
|
||||
for identifier in all_chunk_ids:
|
||||
score = 0.0
|
||||
@@ -112,6 +113,7 @@ class RankFusion:
|
||||
fused_results.append(
|
||||
FusedResult(
|
||||
chunk_id=sr.chunk_id,
|
||||
chunk_index=sr.chunk_index,
|
||||
doc_id=sr.doc_id,
|
||||
kb_id=sr.kb_id,
|
||||
content=sr.content,
|
||||
@@ -120,16 +122,17 @@ class RankFusion:
|
||||
)
|
||||
elif identifier in vec_doc_id_to_dense:
|
||||
# 从向量检索获取信息,需要从数据库获取块的详细信息
|
||||
chunk = await self.kb_db.get_chunk_by_vec_doc_id(identifier)
|
||||
if chunk:
|
||||
fused_results.append(
|
||||
FusedResult(
|
||||
chunk_id=chunk.chunk_id,
|
||||
doc_id=chunk.doc_id,
|
||||
kb_id=chunk.kb_id,
|
||||
content=chunk.content,
|
||||
score=rrf_scores[identifier],
|
||||
)
|
||||
vec_result = vec_doc_id_to_dense[identifier]
|
||||
chunk_md = json.loads(vec_result.data["metadata"])
|
||||
fused_results.append(
|
||||
FusedResult(
|
||||
chunk_id=identifier,
|
||||
chunk_index=chunk_md["chunk_index"],
|
||||
doc_id=chunk_md["kb_doc_id"],
|
||||
kb_id=chunk_md["kb_id"],
|
||||
content=vec_result.data["text"],
|
||||
score=rrf_scores[identifier],
|
||||
)
|
||||
)
|
||||
|
||||
return fused_results
|
||||
|
||||
@@ -3,18 +3,20 @@
|
||||
使用 BM25 算法进行基于关键词的文档检索
|
||||
"""
|
||||
|
||||
import jieba
|
||||
import os
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
from astrbot.core.knowledge_base.database import KBDatabase
|
||||
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseResult:
|
||||
"""稀疏检索结果"""
|
||||
|
||||
chunk_index: int
|
||||
chunk_id: str
|
||||
doc_id: str
|
||||
kb_id: str
|
||||
@@ -30,7 +32,7 @@ class SparseRetriever:
|
||||
- 使用 BM25 算法计算相关度
|
||||
"""
|
||||
|
||||
def __init__(self, kb_db: KBDatabase):
|
||||
def __init__(self, kb_db: KBSQLiteDatabase):
|
||||
"""初始化稀疏检索器
|
||||
|
||||
Args:
|
||||
@@ -39,37 +41,73 @@ class SparseRetriever:
|
||||
self.kb_db = kb_db
|
||||
self._index_cache = {} # 缓存 BM25 索引
|
||||
|
||||
with open(
|
||||
os.path.join(os.path.dirname(__file__), "hit_stopwords.txt"),
|
||||
encoding="utf-8",
|
||||
) as f:
|
||||
self.hit_stopwords = {
|
||||
word.strip() for word in set(f.read().splitlines()) if word.strip()
|
||||
}
|
||||
|
||||
async def retrieve(
|
||||
self,
|
||||
query: str,
|
||||
kb_ids: List[str],
|
||||
top_k: int = 50,
|
||||
) -> List[SparseResult]:
|
||||
kb_ids: list[str],
|
||||
kb_options: dict,
|
||||
) -> list[SparseResult]:
|
||||
"""执行稀疏检索
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k: 返回结果数量
|
||||
kb_options: 每个知识库的检索选项
|
||||
|
||||
Returns:
|
||||
List[SparseResult]: 检索结果列表
|
||||
"""
|
||||
# 1. 获取所有相关块
|
||||
chunks = await self.kb_db.get_chunks_by_kb_ids(kb_ids)
|
||||
top_k_sparse = 0
|
||||
chunks = []
|
||||
for kb_id in kb_ids:
|
||||
vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
|
||||
if not vec_db:
|
||||
continue
|
||||
result = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={}, limit=None, offset=None
|
||||
)
|
||||
chunk_mds = [json.loads(doc["metadata"]) for doc in result]
|
||||
result = [
|
||||
{
|
||||
"chunk_id": doc["doc_id"],
|
||||
"chunk_index": chunk_md["chunk_index"],
|
||||
"doc_id": chunk_md["kb_doc_id"],
|
||||
"kb_id": kb_id,
|
||||
"text": doc["text"],
|
||||
}
|
||||
for doc, chunk_md in zip(result, chunk_mds)
|
||||
]
|
||||
chunks.extend(result)
|
||||
top_k_sparse += kb_options.get(kb_id, {}).get("top_k_sparse", 50)
|
||||
|
||||
if not chunks:
|
||||
return []
|
||||
|
||||
# 2. 准备文档和索引
|
||||
corpus = [chunk.content for chunk in chunks]
|
||||
tokenized_corpus = [doc.split() for doc in corpus]
|
||||
corpus = [chunk["text"] for chunk in chunks]
|
||||
tokenized_corpus = [list(jieba.cut(doc)) for doc in corpus]
|
||||
tokenized_corpus = [
|
||||
[word for word in doc if word not in self.hit_stopwords]
|
||||
for doc in tokenized_corpus
|
||||
]
|
||||
|
||||
# 3. 构建 BM25 索引
|
||||
bm25 = BM25Okapi(tokenized_corpus)
|
||||
|
||||
# 4. 执行检索
|
||||
tokenized_query = query.split()
|
||||
tokenized_query = list(jieba.cut(query))
|
||||
tokenized_query = [
|
||||
word for word in tokenized_query if word not in self.hit_stopwords
|
||||
]
|
||||
scores = bm25.get_scores(tokenized_query)
|
||||
|
||||
# 5. 排序并返回 Top-K
|
||||
@@ -78,13 +116,15 @@ class SparseRetriever:
|
||||
chunk = chunks[idx]
|
||||
results.append(
|
||||
SparseResult(
|
||||
chunk_id=chunk.chunk_id,
|
||||
doc_id=chunk.doc_id,
|
||||
kb_id=chunk.kb_id,
|
||||
content=chunk.content,
|
||||
chunk_id=chunk["chunk_id"],
|
||||
chunk_index=chunk["chunk_index"],
|
||||
doc_id=chunk["doc_id"],
|
||||
kb_id=chunk["kb_id"],
|
||||
content=chunk["text"],
|
||||
score=float(score),
|
||||
)
|
||||
)
|
||||
|
||||
results.sort(key=lambda x: x.score, reverse=True)
|
||||
return results[:top_k]
|
||||
# return results[: len(results) // len(kb_ids)]
|
||||
return results[:top_k_sparse]
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
"""会话知识库配置数据库操作
|
||||
|
||||
该模块封装会话知识库配置的数据库查询操作。
|
||||
|
||||
注意: 会话配置表 (kb_session_config) 存储在知识库独立数据库 (kb.db) 中,
|
||||
而不是主数据库 (astrbot.db) 中,以实现完全解耦。
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from astrbot.core.knowledge_base.kb_sqlite import KBSQLiteDatabase
|
||||
from astrbot.core.knowledge_base.models import KBSessionConfig
|
||||
|
||||
|
||||
class SessionConfigDB:
|
||||
"""会话知识库配置数据库操作类
|
||||
|
||||
职责:
|
||||
- 提供会话知识库配置管理
|
||||
- 统一异常处理
|
||||
|
||||
注意: 该类操作知识库独立数据库,实现完全解耦
|
||||
"""
|
||||
|
||||
def __init__(self, db: KBSQLiteDatabase):
|
||||
"""初始化会话配置数据库操作类
|
||||
|
||||
Args:
|
||||
db: 知识库独立数据库实例 (kb.db),不是主数据库
|
||||
"""
|
||||
self.db = db
|
||||
|
||||
async def get_session_kb_ids(self, session_id: str) -> list[str]:
|
||||
"""获取会话关联的知识库 ID 列表
|
||||
|
||||
查找顺序:
|
||||
1. 会话级别配置 (优先)
|
||||
2. 平台级别配置
|
||||
3. 返回空列表
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
# 1. 查找会话级别配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == "session",
|
||||
KBSessionConfig.scope_id == session_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
return json.loads(config.kb_ids)
|
||||
|
||||
# 2. 提取平台 ID (格式: platform:xxx:session_id)
|
||||
parts = session_id.split(":")
|
||||
if len(parts) >= 2:
|
||||
platform_id = parts[0]
|
||||
|
||||
# 查找平台级别配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == "platform",
|
||||
KBSessionConfig.scope_id == platform_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
return json.loads(config.kb_ids)
|
||||
|
||||
# 3. 无配置
|
||||
return []
|
||||
|
||||
async def set_session_kb_ids(
|
||||
self,
|
||||
scope: str,
|
||||
scope_id: str,
|
||||
kb_ids: list[str],
|
||||
top_k: Optional[int] = None,
|
||||
enable_rerank: Optional[bool] = None,
|
||||
) -> KBSessionConfig:
|
||||
"""设置会话知识库配置
|
||||
|
||||
Args:
|
||||
scope: 配置范围 (session/platform)
|
||||
scope_id: 范围标识 (会话 ID 或平台 ID)
|
||||
kb_ids: 知识库 ID 列表
|
||||
top_k: 返回结果数量 (可选)
|
||||
enable_rerank: 是否启用 Rerank (可选)
|
||||
"""
|
||||
async with self.db.get_db() as session:
|
||||
# 查找现有配置
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == scope,
|
||||
KBSessionConfig.scope_id == scope_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config:
|
||||
# 更新现有配置
|
||||
config.kb_ids = json.dumps(kb_ids)
|
||||
if top_k is not None:
|
||||
config.top_k = top_k
|
||||
if enable_rerank is not None:
|
||||
config.enable_rerank = enable_rerank
|
||||
else:
|
||||
# 创建新配置
|
||||
config = KBSessionConfig(
|
||||
scope=scope,
|
||||
scope_id=scope_id,
|
||||
kb_ids=json.dumps(kb_ids),
|
||||
top_k=top_k,
|
||||
enable_rerank=enable_rerank,
|
||||
)
|
||||
session.add(config)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
async def delete_session_kb_config(self, scope: str, scope_id: str) -> bool:
|
||||
"""删除会话知识库配置"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBSessionConfig).where(
|
||||
KBSessionConfig.scope == scope,
|
||||
KBSessionConfig.scope_id == scope_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if not config:
|
||||
return False
|
||||
|
||||
await session.delete(config)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def list_all_session_configs(
|
||||
self, offset: int = 0, limit: int = 100, scope: Optional[str] = None
|
||||
) -> list[KBSessionConfig]:
|
||||
"""列出所有会话配置"""
|
||||
async with self.db.get_db() as session:
|
||||
stmt = select(KBSessionConfig)
|
||||
|
||||
if scope:
|
||||
stmt = stmt.where(KBSessionConfig.scope == scope)
|
||||
|
||||
stmt = (
|
||||
stmt.offset(offset)
|
||||
.limit(limit)
|
||||
.order_by(KBSessionConfig.created_at.desc())
|
||||
)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
@@ -7,7 +7,7 @@ import copy
|
||||
import json
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
from typing import AsyncGenerator, Union
|
||||
from collections.abc import AsyncGenerator
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Image
|
||||
@@ -45,7 +45,7 @@ except (ModuleNotFoundError, ImportError):
|
||||
|
||||
|
||||
AgentContextWrapper = ContextWrapper[AstrAgentContext]
|
||||
AgentRunner = ToolLoopAgentRunner[AgentContextWrapper]
|
||||
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
|
||||
|
||||
|
||||
class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
@@ -103,7 +103,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
|
||||
request = ProviderRequest(
|
||||
prompt=input_,
|
||||
system_prompt=tool.description,
|
||||
system_prompt=tool.description or "",
|
||||
image_urls=[], # 暂时不传递原始 agent 的上下文
|
||||
contexts=[], # 暂时不传递原始 agent 的上下文
|
||||
func_tool=toolset,
|
||||
@@ -240,7 +240,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
||||
yield res
|
||||
|
||||
|
||||
class MainAgentHooks(BaseAgentRunHooks[AgentContextWrapper]):
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
await call_event_hook(
|
||||
@@ -338,7 +338,7 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent) -> Provider | None:
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
"""选择使用的 LLM 提供商"""
|
||||
sel_provider = event.get_extra("selected_provider")
|
||||
_ctx = self.ctx.plugin_manager.context
|
||||
@@ -368,7 +368,7 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, _nested: bool = False
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
@@ -383,6 +383,9 @@ class LLMRequestSubStage(Stage):
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
@@ -489,6 +492,9 @@ class LLMRequestSubStage(Stage):
|
||||
new_tool_set.add_tool(tool)
|
||||
req.func_tool = new_tool_set
|
||||
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
@@ -526,8 +532,10 @@ class LLMRequestSubStage(Stage):
|
||||
chain = (
|
||||
MessageChain().message(final_llm_resp.completion_text).chain
|
||||
)
|
||||
else:
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
@@ -538,6 +546,9 @@ class LLMRequestSubStage(Stage):
|
||||
async for _ in run_agent(agent_runner, self.max_step, self.show_tool_use):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
@@ -556,6 +567,8 @@ class LLMRequestSubStage(Stage):
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||
):
|
||||
"""处理 WebChat 平台的特殊情况,包括第一次 LLM 对话时总结对话内容生成 title"""
|
||||
if not req.conversation:
|
||||
return
|
||||
conversation = await self.conv_manager.get_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ async def inject_kb_context(
|
||||
umo: str,
|
||||
p_ctx: PipelineContext,
|
||||
req: ProviderRequest,
|
||||
top_k: int = 5,
|
||||
) -> None:
|
||||
"""inject knowledge base context into the provider request
|
||||
|
||||
@@ -15,13 +14,19 @@ async def inject_kb_context(
|
||||
p_ctx: Pipeline context
|
||||
req: Provider request
|
||||
"""
|
||||
kb_injector = p_ctx.plugin_manager.context.kb_manager.get_kb_injector()
|
||||
if not kb_injector:
|
||||
kb_mgr = p_ctx.plugin_manager.context.kb_manager
|
||||
kb_names = p_ctx.astrbot_config.get("kb_names", [])
|
||||
top_k_fusion = p_ctx.astrbot_config.get("kb_fusion_top_k", 20)
|
||||
top_k = p_ctx.astrbot_config.get("kb_final_top_k", 5)
|
||||
|
||||
if not kb_names:
|
||||
return
|
||||
kb_context = await kb_injector.retrieve_and_inject(
|
||||
unified_msg_origin=umo,
|
||||
|
||||
kb_context = await kb_mgr.retrieve(
|
||||
query=req.prompt,
|
||||
top_k=top_k,
|
||||
kb_names=kb_names,
|
||||
top_k_fusion=top_k_fusion,
|
||||
top_m_final=top_k,
|
||||
)
|
||||
if not kb_context:
|
||||
return
|
||||
|
||||
@@ -4,7 +4,7 @@ import re
|
||||
import hashlib
|
||||
import uuid
|
||||
|
||||
from typing import List, Union, Optional, AsyncGenerator, TypeVar, Any
|
||||
from typing import List, Union, Optional, AsyncGenerator, Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.db.po import Conversation
|
||||
@@ -26,8 +26,6 @@ from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
from .message_session import MessageSession, MessageSesion # noqa
|
||||
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class AstrMessageEvent(abc.ABC):
|
||||
def __init__(
|
||||
@@ -92,8 +90,10 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
return self.message_str
|
||||
|
||||
def _outline_chain(self, chain: List[BaseMessageComponent]) -> str:
|
||||
def _outline_chain(self, chain: Optional[List[BaseMessageComponent]]) -> str:
|
||||
outline = ""
|
||||
if not chain:
|
||||
return outline
|
||||
for i in chain:
|
||||
if isinstance(i, Plain):
|
||||
outline += i.text
|
||||
@@ -175,9 +175,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
self._extras[key] = value
|
||||
|
||||
def get_extra(
|
||||
self, key: str | None = None, default: _VT = None
|
||||
) -> dict[str, Any] | _VT:
|
||||
def get_extra(self, key: str | None = None, default=None) -> Any:
|
||||
"""
|
||||
获取额外的信息。
|
||||
"""
|
||||
@@ -265,6 +263,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
result = MessageEventResult().message(result)
|
||||
# 兼容外部插件或调用方传入的 chain=None 的情况,确保为可迭代列表
|
||||
if isinstance(result, MessageEventResult) and result.chain is None:
|
||||
result.chain = []
|
||||
self._result = result
|
||||
|
||||
def stop_event(self):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, Any, Optional, Awaitable
|
||||
import random
|
||||
from typing import Dict, Any, Optional, Awaitable, List
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
@@ -14,6 +14,13 @@ from astrbot.core.platform.astr_message_event import MessageSession
|
||||
import astrbot.api.message_components as Comp
|
||||
|
||||
from .misskey_api import MisskeyAPI
|
||||
import os
|
||||
|
||||
try:
|
||||
import magic # type: ignore
|
||||
except Exception:
|
||||
magic = None
|
||||
|
||||
from .misskey_event import MisskeyPlatformEvent
|
||||
from .misskey_utils import (
|
||||
serialize_message_chain,
|
||||
@@ -25,9 +32,15 @@ from .misskey_utils import (
|
||||
extract_sender_info,
|
||||
create_base_message,
|
||||
process_at_mention,
|
||||
format_poll,
|
||||
cache_user_info,
|
||||
cache_room_info,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
# Constants
|
||||
MAX_FILE_UPLOAD_COUNT = 16
|
||||
DEFAULT_UPLOAD_CONCURRENCY = 3
|
||||
|
||||
|
||||
@register_platform_adapter("misskey", "Misskey 平台适配器")
|
||||
@@ -46,6 +59,31 @@ class MisskeyPlatformAdapter(Platform):
|
||||
)
|
||||
self.local_only = self.config.get("misskey_local_only", False)
|
||||
self.enable_chat = self.config.get("misskey_enable_chat", True)
|
||||
self.enable_file_upload = self.config.get("misskey_enable_file_upload", True)
|
||||
self.upload_folder = self.config.get("misskey_upload_folder")
|
||||
|
||||
# download / security related options (exposed to platform_config)
|
||||
self.allow_insecure_downloads = bool(
|
||||
self.config.get("misskey_allow_insecure_downloads", False)
|
||||
)
|
||||
# parse download timeout and chunk size safely
|
||||
_dt = self.config.get("misskey_download_timeout")
|
||||
try:
|
||||
self.download_timeout = int(_dt) if _dt is not None else 15
|
||||
except Exception:
|
||||
self.download_timeout = 15
|
||||
|
||||
_chunk = self.config.get("misskey_download_chunk_size")
|
||||
try:
|
||||
self.download_chunk_size = int(_chunk) if _chunk is not None else 64 * 1024
|
||||
except Exception:
|
||||
self.download_chunk_size = 64 * 1024
|
||||
# parse max download bytes safely
|
||||
_md_bytes = self.config.get("misskey_max_download_bytes")
|
||||
try:
|
||||
self.max_download_bytes = int(_md_bytes) if _md_bytes is not None else None
|
||||
except Exception:
|
||||
self.max_download_bytes = None
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
@@ -63,6 +101,11 @@ class MisskeyPlatformAdapter(Platform):
|
||||
"misskey_default_visibility": "public",
|
||||
"misskey_local_only": False,
|
||||
"misskey_enable_chat": True,
|
||||
# download / security options
|
||||
"misskey_allow_insecure_downloads": False,
|
||||
"misskey_download_timeout": 15,
|
||||
"misskey_download_chunk_size": 65536,
|
||||
"misskey_max_download_bytes": None,
|
||||
}
|
||||
default_config.update(self.config)
|
||||
|
||||
@@ -78,7 +121,14 @@ class MisskeyPlatformAdapter(Platform):
|
||||
logger.error("[Misskey] 配置不完整,无法启动")
|
||||
return
|
||||
|
||||
self.api = MisskeyAPI(self.instance_url, self.access_token)
|
||||
self.api = MisskeyAPI(
|
||||
self.instance_url,
|
||||
self.access_token,
|
||||
allow_insecure_downloads=self.allow_insecure_downloads,
|
||||
download_timeout=self.download_timeout,
|
||||
chunk_size=self.download_chunk_size,
|
||||
max_download_bytes=self.max_download_bytes,
|
||||
)
|
||||
self._running = True
|
||||
|
||||
try:
|
||||
@@ -95,6 +145,80 @@ class MisskeyPlatformAdapter(Platform):
|
||||
|
||||
await self._start_websocket_connection()
|
||||
|
||||
def _register_event_handlers(self, streaming):
|
||||
"""注册事件处理器"""
|
||||
streaming.add_message_handler("notification", self._handle_notification)
|
||||
streaming.add_message_handler("main:notification", self._handle_notification)
|
||||
|
||||
if self.enable_chat:
|
||||
streaming.add_message_handler("newChatMessage", self._handle_chat_message)
|
||||
streaming.add_message_handler(
|
||||
"messaging:newChatMessage", self._handle_chat_message
|
||||
)
|
||||
streaming.add_message_handler("_debug", self._debug_handler)
|
||||
|
||||
async def _send_text_only_message(
|
||||
self, session_id: str, text: str, session, message_chain
|
||||
):
|
||||
"""发送纯文本消息(无文件上传)"""
|
||||
if not self.api:
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
if session_id and is_valid_user_session_id(session_id):
|
||||
from .misskey_utils import extract_user_id_from_session_id
|
||||
|
||||
user_id = extract_user_id_from_session_id(session_id)
|
||||
payload: Dict[str, Any] = {"toUserId": user_id, "text": text}
|
||||
await self.api.send_message(payload)
|
||||
elif session_id and is_valid_room_session_id(session_id):
|
||||
from .misskey_utils import extract_room_id_from_session_id
|
||||
|
||||
room_id = extract_room_id_from_session_id(session_id)
|
||||
payload = {"toRoomId": room_id, "text": text}
|
||||
await self.api.send_room_message(payload)
|
||||
|
||||
return await super().send_by_session(session, message_chain)
|
||||
|
||||
def _process_poll_data(
|
||||
self, message: AstrBotMessage, poll: Dict[str, Any], message_parts: List[str]
|
||||
):
|
||||
"""处理投票数据,将其添加到消息中"""
|
||||
try:
|
||||
if not isinstance(message.raw_message, dict):
|
||||
message.raw_message = {}
|
||||
message.raw_message["poll"] = poll
|
||||
setattr(message, "poll", poll)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
poll_text = format_poll(poll)
|
||||
if poll_text:
|
||||
message.message.append(Comp.Plain(poll_text))
|
||||
message_parts.append(poll_text)
|
||||
|
||||
def _extract_additional_fields(self, session, message_chain) -> Dict[str, Any]:
|
||||
"""从会话和消息链中提取额外字段"""
|
||||
fields = {"cw": None, "poll": None, "renote_id": None, "channel_id": None}
|
||||
|
||||
for comp in message_chain.chain:
|
||||
if hasattr(comp, "cw") and getattr(comp, "cw", None):
|
||||
fields["cw"] = getattr(comp, "cw")
|
||||
break
|
||||
|
||||
if hasattr(session, "extra_data") and isinstance(
|
||||
getattr(session, "extra_data", None), dict
|
||||
):
|
||||
extra_data = getattr(session, "extra_data")
|
||||
fields.update(
|
||||
{
|
||||
"poll": extra_data.get("poll"),
|
||||
"renote_id": extra_data.get("renote_id"),
|
||||
"channel_id": extra_data.get("channel_id"),
|
||||
}
|
||||
)
|
||||
|
||||
return fields
|
||||
|
||||
async def _start_websocket_connection(self):
|
||||
backoff_delay = 1.0
|
||||
max_backoff = 300.0
|
||||
@@ -109,25 +233,20 @@ class MisskeyPlatformAdapter(Platform):
|
||||
break
|
||||
|
||||
streaming = self.api.get_streaming_client()
|
||||
streaming.add_message_handler("notification", self._handle_notification)
|
||||
if self.enable_chat:
|
||||
streaming.add_message_handler(
|
||||
"newChatMessage", self._handle_chat_message
|
||||
)
|
||||
streaming.add_message_handler("_debug", self._debug_handler)
|
||||
self._register_event_handlers(streaming)
|
||||
|
||||
if await streaming.connect():
|
||||
logger.info(
|
||||
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
|
||||
)
|
||||
connection_attempts = 0 # 重置计数器
|
||||
connection_attempts = 0
|
||||
await streaming.subscribe_channel("main")
|
||||
if self.enable_chat:
|
||||
await streaming.subscribe_channel("messaging")
|
||||
await streaming.subscribe_channel("messagingIndex")
|
||||
logger.info("[Misskey] 聊天频道已订阅")
|
||||
|
||||
backoff_delay = 1.0 # 重置延迟
|
||||
backoff_delay = 1.0
|
||||
await streaming.listen()
|
||||
else:
|
||||
logger.error(
|
||||
@@ -140,18 +259,20 @@ class MisskeyPlatformAdapter(Platform):
|
||||
)
|
||||
|
||||
if self._running:
|
||||
jitter = random.uniform(0, 1.0)
|
||||
sleep_time = backoff_delay + jitter
|
||||
logger.info(
|
||||
f"[Misskey] {backoff_delay:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
|
||||
f"[Misskey] {sleep_time:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
|
||||
)
|
||||
await asyncio.sleep(backoff_delay)
|
||||
await asyncio.sleep(sleep_time)
|
||||
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
|
||||
|
||||
async def _handle_notification(self, data: Dict[str, Any]):
|
||||
try:
|
||||
logger.debug(
|
||||
f"[Misskey] 收到通知事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||
)
|
||||
notification_type = data.get("type")
|
||||
logger.debug(
|
||||
f"[Misskey] 收到通知事件: type={notification_type}, user_id={data.get('userId', 'unknown')}"
|
||||
)
|
||||
if notification_type in ["mention", "reply", "quote"]:
|
||||
note = data.get("note")
|
||||
if note and self._is_bot_mentioned(note):
|
||||
@@ -164,7 +285,7 @@ class MisskeyPlatformAdapter(Platform):
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.api,
|
||||
client=self,
|
||||
)
|
||||
self.commit_event(event)
|
||||
except Exception as e:
|
||||
@@ -172,17 +293,16 @@ class MisskeyPlatformAdapter(Platform):
|
||||
|
||||
async def _handle_chat_message(self, data: Dict[str, Any]):
|
||||
try:
|
||||
logger.debug(
|
||||
f"[Misskey] 收到聊天事件数据:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
sender_id = str(
|
||||
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
|
||||
)
|
||||
room_id = data.get("toRoomId")
|
||||
logger.debug(
|
||||
f"[Misskey] 收到聊天事件: sender_id={sender_id}, room_id={room_id}, is_self={sender_id == self.client_self_id}"
|
||||
)
|
||||
if sender_id == self.client_self_id:
|
||||
return
|
||||
|
||||
room_id = data.get("toRoomId")
|
||||
if room_id:
|
||||
raw_text = data.get("text", "")
|
||||
logger.debug(
|
||||
@@ -200,15 +320,16 @@ class MisskeyPlatformAdapter(Platform):
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.api,
|
||||
client=self,
|
||||
)
|
||||
self.commit_event(event)
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
|
||||
|
||||
async def _debug_handler(self, data: Dict[str, Any]):
|
||||
event_type = data.get("type", "unknown")
|
||||
logger.debug(
|
||||
f"[Misskey] 收到未处理事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||
f"[Misskey] 收到未处理事件: type={event_type}, channel={data.get('channel', 'unknown')}"
|
||||
)
|
||||
|
||||
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
|
||||
@@ -239,43 +360,250 @@ class MisskeyPlatformAdapter(Platform):
|
||||
|
||||
try:
|
||||
session_id = session.session_id
|
||||
|
||||
text, has_at_user = serialize_message_chain(message_chain.chain)
|
||||
|
||||
if not has_at_user and session_id:
|
||||
user_info = self._user_cache.get(session_id)
|
||||
# 从session_id中提取用户ID用于缓存查询
|
||||
# session_id格式为: "chat%<user_id>" 或 "room%<room_id>" 或 "note%<user_id>"
|
||||
user_id_for_cache = None
|
||||
if "%" in session_id:
|
||||
parts = session_id.split("%")
|
||||
if len(parts) >= 2:
|
||||
user_id_for_cache = parts[1]
|
||||
|
||||
user_info = None
|
||||
if user_id_for_cache:
|
||||
user_info = self._user_cache.get(user_id_for_cache)
|
||||
|
||||
text = add_at_mention_if_needed(text, user_info, has_at_user)
|
||||
|
||||
# 检查是否有文件组件
|
||||
has_file_components = any(
|
||||
isinstance(comp, Comp.Image)
|
||||
or isinstance(comp, Comp.File)
|
||||
or hasattr(comp, "convert_to_file_path")
|
||||
or hasattr(comp, "get_file")
|
||||
or any(
|
||||
hasattr(comp, a) for a in ("file", "url", "path", "src", "source")
|
||||
)
|
||||
for comp in message_chain.chain
|
||||
)
|
||||
|
||||
if not text or not text.strip():
|
||||
logger.warning("[Misskey] 消息内容为空,跳过发送")
|
||||
return await super().send_by_session(session, message_chain)
|
||||
if not has_file_components:
|
||||
logger.warning("[Misskey] 消息内容为空且无文件组件,跳过发送")
|
||||
return await super().send_by_session(session, message_chain)
|
||||
else:
|
||||
text = ""
|
||||
|
||||
if len(text) > self.max_message_length:
|
||||
text = text[: self.max_message_length] + "..."
|
||||
|
||||
if session_id and is_valid_user_session_id(session_id):
|
||||
from .misskey_utils import extract_user_id_from_session_id
|
||||
file_ids: List[str] = []
|
||||
fallback_urls: List[str] = []
|
||||
|
||||
user_id = extract_user_id_from_session_id(session_id)
|
||||
await self.api.send_message(user_id, text)
|
||||
elif session_id and is_valid_room_session_id(session_id):
|
||||
if not self.enable_file_upload:
|
||||
return await self._send_text_only_message(
|
||||
session_id, text, session, message_chain
|
||||
)
|
||||
|
||||
MAX_UPLOAD_CONCURRENCY = 10
|
||||
upload_concurrency = int(
|
||||
self.config.get(
|
||||
"misskey_upload_concurrency", DEFAULT_UPLOAD_CONCURRENCY
|
||||
)
|
||||
)
|
||||
upload_concurrency = min(upload_concurrency, MAX_UPLOAD_CONCURRENCY)
|
||||
sem = asyncio.Semaphore(upload_concurrency)
|
||||
|
||||
async def _upload_comp(comp) -> Optional[object]:
|
||||
"""组件上传函数:处理 URL(下载后上传)或本地文件(直接上传)"""
|
||||
from .misskey_utils import (
|
||||
resolve_component_url_or_path,
|
||||
upload_local_with_retries,
|
||||
)
|
||||
|
||||
local_path = None
|
||||
try:
|
||||
async with sem:
|
||||
if not self.api:
|
||||
return None
|
||||
|
||||
# 解析组件的 URL 或本地路径
|
||||
url_candidate, local_path = await resolve_component_url_or_path(
|
||||
comp
|
||||
)
|
||||
|
||||
if not url_candidate and not local_path:
|
||||
return None
|
||||
|
||||
preferred_name = getattr(comp, "name", None) or getattr(
|
||||
comp, "file", None
|
||||
)
|
||||
|
||||
# URL 上传:下载后本地上传
|
||||
if url_candidate:
|
||||
result = await self.api.upload_and_find_file(
|
||||
str(url_candidate),
|
||||
preferred_name,
|
||||
folder_id=self.upload_folder,
|
||||
)
|
||||
if isinstance(result, dict) and result.get("id"):
|
||||
return str(result["id"])
|
||||
|
||||
# 本地文件上传
|
||||
if local_path:
|
||||
file_id = await upload_local_with_retries(
|
||||
self.api,
|
||||
str(local_path),
|
||||
preferred_name,
|
||||
self.upload_folder,
|
||||
)
|
||||
if file_id:
|
||||
return file_id
|
||||
|
||||
# 所有上传都失败,尝试获取 URL 作为回退
|
||||
if hasattr(comp, "register_to_file_service"):
|
||||
try:
|
||||
url = await comp.register_to_file_service()
|
||||
if url:
|
||||
return {"fallback_url": url}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if local_path and isinstance(local_path, str):
|
||||
data_temp = os.path.join(get_astrbot_data_path(), "temp")
|
||||
if local_path.startswith(data_temp) and os.path.exists(
|
||||
local_path
|
||||
):
|
||||
try:
|
||||
os.remove(local_path)
|
||||
logger.debug(f"[Misskey] 已清理临时文件: {local_path}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 收集所有可能包含文件/URL信息的组件:支持异步接口或同步字段
|
||||
file_components = []
|
||||
for comp in message_chain.chain:
|
||||
try:
|
||||
if (
|
||||
isinstance(comp, Comp.Image)
|
||||
or isinstance(comp, Comp.File)
|
||||
or hasattr(comp, "convert_to_file_path")
|
||||
or hasattr(comp, "get_file")
|
||||
or any(
|
||||
hasattr(comp, a)
|
||||
for a in ("file", "url", "path", "src", "source")
|
||||
)
|
||||
):
|
||||
file_components.append(comp)
|
||||
except Exception:
|
||||
# 保守跳过无法访问属性的组件
|
||||
continue
|
||||
|
||||
if len(file_components) > MAX_FILE_UPLOAD_COUNT:
|
||||
logger.warning(
|
||||
f"[Misskey] 文件数量超过限制 ({len(file_components)} > {MAX_FILE_UPLOAD_COUNT}),只上传前{MAX_FILE_UPLOAD_COUNT}个文件"
|
||||
)
|
||||
file_components = file_components[:MAX_FILE_UPLOAD_COUNT]
|
||||
|
||||
upload_tasks = [_upload_comp(comp) for comp in file_components]
|
||||
|
||||
try:
|
||||
results = await asyncio.gather(*upload_tasks) if upload_tasks else []
|
||||
for r in results:
|
||||
if not r:
|
||||
continue
|
||||
if isinstance(r, dict) and r.get("fallback_url"):
|
||||
url = r.get("fallback_url")
|
||||
if url:
|
||||
fallback_urls.append(str(url))
|
||||
else:
|
||||
try:
|
||||
fid_str = str(r)
|
||||
if fid_str:
|
||||
file_ids.append(fid_str)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
logger.debug("[Misskey] 并发上传过程中出现异常,继续发送文本")
|
||||
|
||||
if session_id and is_valid_room_session_id(session_id):
|
||||
from .misskey_utils import extract_room_id_from_session_id
|
||||
|
||||
room_id = extract_room_id_from_session_id(session_id)
|
||||
await self.api.send_room_message(room_id, text)
|
||||
else:
|
||||
visibility, visible_user_ids = resolve_message_visibility(
|
||||
user_id=session_id,
|
||||
user_cache=self._user_cache,
|
||||
self_id=self.client_self_id,
|
||||
default_visibility=self.default_visibility,
|
||||
if fallback_urls:
|
||||
appended = "\n" + "\n".join(fallback_urls)
|
||||
text = (text or "") + appended
|
||||
payload: Dict[str, Any] = {"toRoomId": room_id, "text": text}
|
||||
if file_ids:
|
||||
payload["fileIds"] = file_ids
|
||||
await self.api.send_room_message(payload)
|
||||
elif session_id:
|
||||
from .misskey_utils import (
|
||||
extract_user_id_from_session_id,
|
||||
is_valid_chat_session_id,
|
||||
)
|
||||
|
||||
await self.api.create_note(
|
||||
text,
|
||||
visibility=visibility,
|
||||
visible_user_ids=visible_user_ids,
|
||||
local_only=self.local_only,
|
||||
)
|
||||
if is_valid_chat_session_id(session_id):
|
||||
user_id = extract_user_id_from_session_id(session_id)
|
||||
if fallback_urls:
|
||||
appended = "\n" + "\n".join(fallback_urls)
|
||||
text = (text or "") + appended
|
||||
payload: Dict[str, Any] = {"toUserId": user_id, "text": text}
|
||||
if file_ids:
|
||||
# 聊天消息只支持单个文件,使用 fileId 而不是 fileIds
|
||||
payload["fileId"] = file_ids[0]
|
||||
if len(file_ids) > 1:
|
||||
logger.warning(
|
||||
f"[Misskey] 聊天消息只支持单个文件,忽略其余 {len(file_ids) - 1} 个文件"
|
||||
)
|
||||
await self.api.send_message(payload)
|
||||
else:
|
||||
# 回退到发帖逻辑
|
||||
# 去掉 session_id 中的 note% 前缀以匹配 user_cache 的键格式
|
||||
user_id_for_cache = (
|
||||
session_id.split("%")[1] if "%" in session_id else session_id
|
||||
)
|
||||
|
||||
# 获取用户缓存信息(包含reply_to_note_id)
|
||||
user_info_for_reply = self._user_cache.get(user_id_for_cache, {})
|
||||
|
||||
visibility, visible_user_ids = resolve_message_visibility(
|
||||
user_id=user_id_for_cache,
|
||||
user_cache=self._user_cache,
|
||||
self_id=self.client_self_id,
|
||||
default_visibility=self.default_visibility,
|
||||
)
|
||||
logger.debug(
|
||||
f"[Misskey] 解析可见性: visibility={visibility}, visible_user_ids={visible_user_ids}, session_id={session_id}, user_id_for_cache={user_id_for_cache}"
|
||||
)
|
||||
|
||||
fields = self._extract_additional_fields(session, message_chain)
|
||||
if fallback_urls:
|
||||
appended = "\n" + "\n".join(fallback_urls)
|
||||
text = (text or "") + appended
|
||||
|
||||
# 从缓存中获取原消息ID作为reply_id
|
||||
reply_id = user_info_for_reply.get("reply_to_note_id")
|
||||
|
||||
await self.api.create_note(
|
||||
text=text,
|
||||
visibility=visibility,
|
||||
visible_user_ids=visible_user_ids,
|
||||
file_ids=file_ids or None,
|
||||
local_only=self.local_only,
|
||||
reply_id=reply_id, # 添加reply_id参数
|
||||
cw=fields["cw"],
|
||||
poll=fields["poll"],
|
||||
renote_id=fields["renote_id"],
|
||||
channel_id=fields["channel_id"],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey] 发送消息失败: {e}")
|
||||
@@ -309,6 +637,14 @@ class MisskeyPlatformAdapter(Platform):
|
||||
file_parts = process_files(message, files)
|
||||
message_parts.extend(file_parts)
|
||||
|
||||
poll = raw_data.get("poll") or (
|
||||
raw_data.get("note", {}).get("poll")
|
||||
if isinstance(raw_data.get("note"), dict)
|
||||
else None
|
||||
)
|
||||
if poll and isinstance(poll, dict):
|
||||
self._process_poll_data(message, poll, message_parts)
|
||||
|
||||
message.message_str = (
|
||||
" ".join(part for part in message_parts if part.strip())
|
||||
if message_parts
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
import random
|
||||
import asyncio
|
||||
from typing import Any, Optional, Dict, List, Callable, Awaitable
|
||||
import uuid
|
||||
|
||||
@@ -11,6 +13,7 @@ except ImportError as e:
|
||||
) from e
|
||||
|
||||
from astrbot.api import logger
|
||||
from .misskey_utils import FileIDExtractor
|
||||
|
||||
# Constants
|
||||
API_MAX_RETRIES = 3
|
||||
@@ -55,6 +58,7 @@ class StreamingClient:
|
||||
self.is_connected = False
|
||||
self.message_handlers: Dict[str, Callable] = {}
|
||||
self.channels: Dict[str, str] = {}
|
||||
self.desired_channels: Dict[str, Optional[Dict]] = {}
|
||||
self._running = False
|
||||
self._last_pong = None
|
||||
|
||||
@@ -72,6 +76,18 @@ class StreamingClient:
|
||||
self._running = True
|
||||
|
||||
logger.info("[Misskey WebSocket] 已连接")
|
||||
if self.desired_channels:
|
||||
try:
|
||||
desired = list(self.desired_channels.items())
|
||||
for channel_type, params in desired:
|
||||
try:
|
||||
await self.subscribe_channel(channel_type, params)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[Misskey WebSocket] 重新订阅 {channel_type} 失败: {e}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -112,9 +128,12 @@ class StreamingClient:
|
||||
return
|
||||
|
||||
message = {"type": "disconnect", "body": {"id": channel_id}}
|
||||
|
||||
await self.websocket.send(json.dumps(message))
|
||||
del self.channels[channel_id]
|
||||
channel_type = self.channels.get(channel_id)
|
||||
if channel_id in self.channels:
|
||||
del self.channels[channel_id]
|
||||
if channel_type and channel_type not in self.channels.values():
|
||||
self.desired_channels.pop(channel_type, None)
|
||||
|
||||
def add_message_handler(
|
||||
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
|
||||
@@ -141,25 +160,67 @@ class StreamingClient:
|
||||
except websockets.exceptions.ConnectionClosedError as e:
|
||||
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
|
||||
self.is_connected = False
|
||||
try:
|
||||
await self.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
logger.warning(
|
||||
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
|
||||
)
|
||||
self.is_connected = False
|
||||
try:
|
||||
await self.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
except websockets.exceptions.InvalidHandshake as e:
|
||||
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
|
||||
self.is_connected = False
|
||||
try:
|
||||
await self.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
|
||||
self.is_connected = False
|
||||
try:
|
||||
await self.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _handle_message(self, data: Dict[str, Any]):
|
||||
message_type = data.get("type")
|
||||
body = data.get("body", {})
|
||||
|
||||
logger.debug(
|
||||
f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||
)
|
||||
def _build_channel_summary(message_type: Optional[str], body: Any) -> str:
|
||||
try:
|
||||
if not isinstance(body, dict):
|
||||
return f"[Misskey WebSocket] 收到消息类型: {message_type}"
|
||||
|
||||
inner = body.get("body") if isinstance(body.get("body"), dict) else body
|
||||
note = (
|
||||
inner.get("note")
|
||||
if isinstance(inner, dict) and isinstance(inner.get("note"), dict)
|
||||
else None
|
||||
)
|
||||
|
||||
text = note.get("text") if note else None
|
||||
note_id = note.get("id") if note else None
|
||||
files = note.get("files") or [] if note else []
|
||||
has_files = bool(files)
|
||||
is_hidden = bool(note.get("isHidden")) if note else False
|
||||
user = note.get("user", {}) if note else None
|
||||
|
||||
return (
|
||||
f"[Misskey WebSocket] 收到消息类型: {message_type} | "
|
||||
f"note_id={note_id} | user={user.get('username') if user else None} | "
|
||||
f"text={text[:80] if text else '[no-text]'} | files={has_files} | hidden={is_hidden}"
|
||||
)
|
||||
except Exception:
|
||||
return f"[Misskey WebSocket] 收到消息类型: {message_type}"
|
||||
|
||||
channel_summary = _build_channel_summary(message_type, body)
|
||||
logger.info(channel_summary)
|
||||
|
||||
if message_type == "channel":
|
||||
channel_id = body.get("id")
|
||||
@@ -202,16 +263,60 @@ class StreamingClient:
|
||||
await self.message_handlers["_debug"](data)
|
||||
|
||||
|
||||
def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
|
||||
def retry_async(
|
||||
max_retries: int = 3,
|
||||
retryable_exceptions: tuple = (APIConnectionError, APIRateLimitError),
|
||||
backoff_base: float = 1.0,
|
||||
max_backoff: float = 30.0,
|
||||
):
|
||||
"""
|
||||
智能异步重试装饰器
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数
|
||||
retryable_exceptions: 可重试的异常类型
|
||||
backoff_base: 退避基数
|
||||
max_backoff: 最大退避时间
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
async def wrapper(*args, **kwargs):
|
||||
last_exc = None
|
||||
for _ in range(max_retries):
|
||||
func_name = getattr(func, "__name__", "unknown")
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except retryable_exceptions as e:
|
||||
last_exc = e
|
||||
if attempt == max_retries:
|
||||
logger.error(
|
||||
f"[Misskey API] {func_name} 重试 {max_retries} 次后仍失败: {e}"
|
||||
)
|
||||
break
|
||||
|
||||
# 智能退避策略
|
||||
if isinstance(e, APIRateLimitError):
|
||||
# 频率限制用更长的退避时间
|
||||
backoff = min(backoff_base * (3**attempt), max_backoff)
|
||||
else:
|
||||
# 其他错误用指数退避
|
||||
backoff = min(backoff_base * (2**attempt), max_backoff)
|
||||
|
||||
jitter = random.uniform(0.1, 0.5) # 随机抖动
|
||||
sleep_time = backoff + jitter
|
||||
|
||||
logger.warning(
|
||||
f"[Misskey API] {func_name} 第 {attempt} 次重试失败: {e},"
|
||||
f"{sleep_time:.1f}s后重试"
|
||||
)
|
||||
await asyncio.sleep(sleep_time)
|
||||
continue
|
||||
except Exception as e:
|
||||
# 非可重试异常直接抛出
|
||||
logger.error(f"[Misskey API] {func_name} 遇到不可重试异常: {e}")
|
||||
raise
|
||||
|
||||
if last_exc:
|
||||
raise last_exc
|
||||
|
||||
@@ -221,11 +326,27 @@ def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
|
||||
|
||||
|
||||
class MisskeyAPI:
|
||||
def __init__(self, instance_url: str, access_token: str):
|
||||
def __init__(
|
||||
self,
|
||||
instance_url: str,
|
||||
access_token: str,
|
||||
*,
|
||||
allow_insecure_downloads: bool = False,
|
||||
download_timeout: int = 15,
|
||||
chunk_size: int = 64 * 1024,
|
||||
max_download_bytes: Optional[int] = None,
|
||||
):
|
||||
self.instance_url = instance_url.rstrip("/")
|
||||
self.access_token = access_token
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self.streaming: Optional[StreamingClient] = None
|
||||
# download options
|
||||
self.allow_insecure_downloads = allow_insecure_downloads
|
||||
self.download_timeout = download_timeout
|
||||
self.chunk_size = chunk_size
|
||||
self.max_download_bytes = (
|
||||
int(max_download_bytes) if max_download_bytes is not None else None
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
@@ -258,16 +379,37 @@ class MisskeyAPI:
|
||||
def _handle_response_status(self, status: int, endpoint: str):
|
||||
"""处理 HTTP 响应状态码"""
|
||||
if status == 400:
|
||||
logger.error(f"API 请求错误: {endpoint} (状态码: {status})")
|
||||
logger.error(f"[Misskey API] 请求参数错误: {endpoint} (HTTP {status})")
|
||||
raise APIError(f"Bad request for {endpoint}")
|
||||
elif status in (401, 403):
|
||||
logger.error(f"API 认证失败: {endpoint} (状态码: {status})")
|
||||
raise AuthenticationError(f"Authentication failed for {endpoint}")
|
||||
elif status == 401:
|
||||
logger.error(f"[Misskey API] 未授权访问: {endpoint} (HTTP {status})")
|
||||
raise AuthenticationError(f"Unauthorized access for {endpoint}")
|
||||
elif status == 403:
|
||||
logger.error(f"[Misskey API] 访问被禁止: {endpoint} (HTTP {status})")
|
||||
raise AuthenticationError(f"Forbidden access for {endpoint}")
|
||||
elif status == 404:
|
||||
logger.error(f"[Misskey API] 资源不存在: {endpoint} (HTTP {status})")
|
||||
raise APIError(f"Resource not found for {endpoint}")
|
||||
elif status == 413:
|
||||
logger.error(f"[Misskey API] 请求体过大: {endpoint} (HTTP {status})")
|
||||
raise APIError(f"Request entity too large for {endpoint}")
|
||||
elif status == 429:
|
||||
logger.warning(f"API 频率限制: {endpoint} (状态码: {status})")
|
||||
logger.warning(f"[Misskey API] 请求频率限制: {endpoint} (HTTP {status})")
|
||||
raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
|
||||
elif status == 500:
|
||||
logger.error(f"[Misskey API] 服务器内部错误: {endpoint} (HTTP {status})")
|
||||
raise APIConnectionError(f"Internal server error for {endpoint}")
|
||||
elif status == 502:
|
||||
logger.error(f"[Misskey API] 网关错误: {endpoint} (HTTP {status})")
|
||||
raise APIConnectionError(f"Bad gateway for {endpoint}")
|
||||
elif status == 503:
|
||||
logger.error(f"[Misskey API] 服务不可用: {endpoint} (HTTP {status})")
|
||||
raise APIConnectionError(f"Service unavailable for {endpoint}")
|
||||
elif status == 504:
|
||||
logger.error(f"[Misskey API] 网关超时: {endpoint} (HTTP {status})")
|
||||
raise APIConnectionError(f"Gateway timeout for {endpoint}")
|
||||
else:
|
||||
logger.error(f"API 请求失败: {endpoint} (状态码: {status})")
|
||||
logger.error(f"[Misskey API] 未知错误: {endpoint} (HTTP {status})")
|
||||
raise APIConnectionError(f"HTTP {status} for {endpoint}")
|
||||
|
||||
async def _process_response(
|
||||
@@ -286,21 +428,25 @@ class MisskeyAPI:
|
||||
else []
|
||||
)
|
||||
if notifications_data:
|
||||
logger.debug(f"获取到 {len(notifications_data)} 条新通知")
|
||||
logger.debug(
|
||||
f"[Misskey API] 获取到 {len(notifications_data)} 条新通知"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"API 请求成功: {endpoint}")
|
||||
logger.debug(f"[Misskey API] 请求成功: {endpoint}")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"响应不是有效的 JSON 格式: {e}")
|
||||
logger.error(f"[Misskey API] 响应格式错误: {e}")
|
||||
raise APIConnectionError("Invalid JSON response") from e
|
||||
else:
|
||||
try:
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}"
|
||||
f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}, 响应: {error_text}"
|
||||
)
|
||||
except Exception:
|
||||
logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}")
|
||||
logger.error(
|
||||
f"[Misskey API] 请求失败: {endpoint} - HTTP {response.status}"
|
||||
)
|
||||
|
||||
self._handle_response_status(response.status, endpoint)
|
||||
raise APIConnectionError(f"Request failed for {endpoint}")
|
||||
@@ -321,53 +467,307 @@ class MisskeyAPI:
|
||||
async with self.session.post(url, json=payload) as response:
|
||||
return await self._process_response(response, endpoint)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"HTTP 请求错误: {e}")
|
||||
logger.error(f"[Misskey API] HTTP 请求错误: {e}")
|
||||
raise APIConnectionError(f"HTTP request failed: {e}") from e
|
||||
|
||||
async def create_note(
|
||||
self,
|
||||
text: str,
|
||||
text: Optional[str] = None,
|
||||
visibility: str = "public",
|
||||
reply_id: Optional[str] = None,
|
||||
visible_user_ids: Optional[List[str]] = None,
|
||||
file_ids: Optional[List[str]] = None,
|
||||
local_only: bool = False,
|
||||
cw: Optional[str] = None,
|
||||
poll: Optional[Dict[str, Any]] = None,
|
||||
renote_id: Optional[str] = None,
|
||||
channel_id: Optional[str] = None,
|
||||
reaction_acceptance: Optional[str] = None,
|
||||
no_extract_mentions: Optional[bool] = None,
|
||||
no_extract_hashtags: Optional[bool] = None,
|
||||
no_extract_emojis: Optional[bool] = None,
|
||||
media_ids: Optional[List[str]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""创建新贴文"""
|
||||
data: Dict[str, Any] = {
|
||||
"text": text,
|
||||
"visibility": visibility,
|
||||
"localOnly": local_only,
|
||||
}
|
||||
"""Create a note (wrapper for notes/create). All additional fields are optional and passed through to the API."""
|
||||
data: Dict[str, Any] = {}
|
||||
|
||||
if text is not None:
|
||||
data["text"] = text
|
||||
|
||||
data["visibility"] = visibility
|
||||
data["localOnly"] = local_only
|
||||
|
||||
if reply_id:
|
||||
data["replyId"] = reply_id
|
||||
|
||||
if visible_user_ids and visibility == "specified":
|
||||
data["visibleUserIds"] = visible_user_ids
|
||||
|
||||
if file_ids:
|
||||
data["fileIds"] = file_ids
|
||||
if media_ids:
|
||||
data["mediaIds"] = media_ids
|
||||
|
||||
if cw is not None:
|
||||
data["cw"] = cw
|
||||
if poll is not None:
|
||||
data["poll"] = poll
|
||||
if renote_id is not None:
|
||||
data["renoteId"] = renote_id
|
||||
if channel_id is not None:
|
||||
data["channelId"] = channel_id
|
||||
if reaction_acceptance is not None:
|
||||
data["reactionAcceptance"] = reaction_acceptance
|
||||
if no_extract_mentions is not None:
|
||||
data["noExtractMentions"] = bool(no_extract_mentions)
|
||||
if no_extract_hashtags is not None:
|
||||
data["noExtractHashtags"] = bool(no_extract_hashtags)
|
||||
if no_extract_emojis is not None:
|
||||
data["noExtractEmojis"] = bool(no_extract_emojis)
|
||||
|
||||
result = await self._make_request("notes/create", data)
|
||||
note_id = result.get("createdNote", {}).get("id", "unknown")
|
||||
logger.debug(f"发帖成功,note_id: {note_id}")
|
||||
note_id = (
|
||||
result.get("createdNote", {}).get("id", "unknown")
|
||||
if isinstance(result, dict)
|
||||
else "unknown"
|
||||
)
|
||||
logger.debug(f"[Misskey API] 发帖成功: {note_id}")
|
||||
return result
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
file_path: str,
|
||||
name: Optional[str] = None,
|
||||
folder_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Upload a file to Misskey drive/files/create and return a dict containing id and raw result."""
|
||||
if not file_path:
|
||||
raise APIError("No file path provided for upload")
|
||||
|
||||
url = f"{self.instance_url}/api/drive/files/create"
|
||||
form = aiohttp.FormData()
|
||||
form.add_field("i", self.access_token)
|
||||
|
||||
try:
|
||||
filename = name or file_path.split("/")[-1]
|
||||
if folder_id:
|
||||
form.add_field("folderId", str(folder_id))
|
||||
|
||||
try:
|
||||
f = open(file_path, "rb")
|
||||
except FileNotFoundError as e:
|
||||
logger.error(f"[Misskey API] 本地文件不存在: {file_path}")
|
||||
raise APIError(f"File not found: {file_path}") from e
|
||||
|
||||
try:
|
||||
form.add_field("file", f, filename=filename)
|
||||
async with self.session.post(url, data=form) as resp:
|
||||
result = await self._process_response(resp, "drive/files/create")
|
||||
file_id = FileIDExtractor.extract_file_id(result)
|
||||
logger.debug(
|
||||
f"[Misskey API] 本地文件上传成功: {filename} -> {file_id}"
|
||||
)
|
||||
return {"id": file_id, "raw": result}
|
||||
finally:
|
||||
f.close()
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"[Misskey API] 文件上传网络错误: {e}")
|
||||
raise APIConnectionError(f"Upload failed: {e}") from e
|
||||
|
||||
async def find_files_by_hash(self, md5_hash: str) -> List[Dict[str, Any]]:
|
||||
"""Find files by MD5 hash"""
|
||||
if not md5_hash:
|
||||
raise APIError("No MD5 hash provided for find-by-hash")
|
||||
|
||||
data = {"md5": md5_hash}
|
||||
|
||||
try:
|
||||
logger.debug(f"[Misskey API] find-by-hash 请求: md5={md5_hash}")
|
||||
result = await self._make_request("drive/files/find-by-hash", data)
|
||||
logger.debug(
|
||||
f"[Misskey API] find-by-hash 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
|
||||
)
|
||||
return result if isinstance(result, list) else []
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] 根据哈希查找文件失败: {e}")
|
||||
raise
|
||||
|
||||
async def find_files_by_name(
|
||||
self, name: str, folder_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Find files by name"""
|
||||
if not name:
|
||||
raise APIError("No name provided for find")
|
||||
|
||||
data: Dict[str, Any] = {"name": name}
|
||||
if folder_id:
|
||||
data["folderId"] = folder_id
|
||||
|
||||
try:
|
||||
logger.debug(f"[Misskey API] find 请求: name={name}, folder_id={folder_id}")
|
||||
result = await self._make_request("drive/files/find", data)
|
||||
logger.debug(
|
||||
f"[Misskey API] find 响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
|
||||
)
|
||||
return result if isinstance(result, list) else []
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] 根据名称查找文件失败: {e}")
|
||||
raise
|
||||
|
||||
async def find_files(
|
||||
self,
|
||||
limit: int = 10,
|
||||
folder_id: Optional[str] = None,
|
||||
type: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""List files with optional filters"""
|
||||
data: Dict[str, Any] = {"limit": limit}
|
||||
if folder_id is not None:
|
||||
data["folderId"] = folder_id
|
||||
if type is not None:
|
||||
data["type"] = type
|
||||
|
||||
try:
|
||||
logger.debug(
|
||||
f"[Misskey API] 列表文件请求: limit={limit}, folder_id={folder_id}, type={type}"
|
||||
)
|
||||
result = await self._make_request("drive/files", data)
|
||||
logger.debug(
|
||||
f"[Misskey API] 列表文件响应: 找到 {len(result) if isinstance(result, list) else 0} 个文件"
|
||||
)
|
||||
return result if isinstance(result, list) else []
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] 列表文件失败: {e}")
|
||||
raise
|
||||
|
||||
async def _download_with_existing_session(
|
||||
self, url: str, ssl_verify: bool = True
|
||||
) -> Optional[bytes]:
|
||||
"""使用现有会话下载文件"""
|
||||
if not (hasattr(self, "session") and self.session):
|
||||
raise APIConnectionError("No existing session available")
|
||||
|
||||
async with self.session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=15), ssl=ssl_verify
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.read()
|
||||
return None
|
||||
|
||||
async def _download_with_temp_session(
|
||||
self, url: str, ssl_verify: bool = True
|
||||
) -> Optional[bytes]:
|
||||
"""使用临时会话下载文件"""
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_verify)
|
||||
async with aiohttp.ClientSession(connector=connector) as temp_session:
|
||||
async with temp_session.get(
|
||||
url, timeout=aiohttp.ClientTimeout(total=15)
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.read()
|
||||
return None
|
||||
|
||||
async def upload_and_find_file(
|
||||
self,
|
||||
url: str,
|
||||
name: Optional[str] = None,
|
||||
folder_id: Optional[str] = None,
|
||||
max_wait_time: float = 30.0,
|
||||
check_interval: float = 2.0,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
简化的文件上传:尝试 URL 上传,失败则下载后本地上传
|
||||
|
||||
Args:
|
||||
url: 文件URL
|
||||
name: 文件名(可选)
|
||||
folder_id: 文件夹ID(可选)
|
||||
max_wait_time: 保留参数(未使用)
|
||||
check_interval: 保留参数(未使用)
|
||||
|
||||
Returns:
|
||||
包含文件ID和元信息的字典,失败时返回None
|
||||
"""
|
||||
if not url:
|
||||
raise APIError("URL不能为空")
|
||||
|
||||
# 通过本地上传获取即时文件 ID(下载文件 → 上传 → 返回 ID)
|
||||
try:
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
# SSL 验证下载,失败则重试不验证 SSL
|
||||
tmp_bytes = None
|
||||
try:
|
||||
tmp_bytes = await self._download_with_existing_session(
|
||||
url, ssl_verify=True
|
||||
) or await self._download_with_temp_session(url, ssl_verify=True)
|
||||
except Exception as ssl_error:
|
||||
logger.debug(
|
||||
f"[Misskey API] SSL 验证下载失败: {ssl_error},重试不验证 SSL"
|
||||
)
|
||||
try:
|
||||
tmp_bytes = await self._download_with_existing_session(
|
||||
url, ssl_verify=False
|
||||
) or await self._download_with_temp_session(url, ssl_verify=False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if tmp_bytes:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as tmpf:
|
||||
tmpf.write(tmp_bytes)
|
||||
tmp_path = tmpf.name
|
||||
|
||||
try:
|
||||
result = await self.upload_file(tmp_path, name, folder_id)
|
||||
logger.debug(f"[Misskey API] 本地上传成功: {result.get('id')}")
|
||||
return result
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] 本地上传失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def get_current_user(self) -> Dict[str, Any]:
|
||||
"""获取当前用户信息"""
|
||||
return await self._make_request("i", {})
|
||||
|
||||
async def send_message(self, user_id: str, text: str) -> Dict[str, Any]:
|
||||
"""发送聊天消息"""
|
||||
result = await self._make_request(
|
||||
"chat/messages/create-to-user", {"toUserId": user_id, "text": text}
|
||||
)
|
||||
async def send_message(
|
||||
self, user_id_or_payload: Any, text: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送聊天消息。
|
||||
|
||||
Accepts either (user_id: str, text: str) or a single dict payload prepared by caller.
|
||||
"""
|
||||
if isinstance(user_id_or_payload, dict):
|
||||
data = user_id_or_payload
|
||||
else:
|
||||
data = {"toUserId": user_id_or_payload, "text": text}
|
||||
|
||||
result = await self._make_request("chat/messages/create-to-user", data)
|
||||
message_id = result.get("id", "unknown")
|
||||
logger.debug(f"聊天发送成功,message_id: {message_id}")
|
||||
logger.debug(f"[Misskey API] 聊天消息发送成功: {message_id}")
|
||||
return result
|
||||
|
||||
async def send_room_message(self, room_id: str, text: str) -> Dict[str, Any]:
|
||||
"""发送房间消息"""
|
||||
result = await self._make_request(
|
||||
"chat/messages/create-to-room", {"toRoomId": room_id, "text": text}
|
||||
)
|
||||
async def send_room_message(
|
||||
self, room_id_or_payload: Any, text: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""发送房间消息。
|
||||
|
||||
Accepts either (room_id: str, text: str) or a single dict payload.
|
||||
"""
|
||||
if isinstance(room_id_or_payload, dict):
|
||||
data = room_id_or_payload
|
||||
else:
|
||||
data = {"toRoomId": room_id_or_payload, "text": text}
|
||||
|
||||
result = await self._make_request("chat/messages/create-to-room", data)
|
||||
message_id = result.get("id", "unknown")
|
||||
logger.debug(f"房间消息发送成功,message_id: {message_id}")
|
||||
logger.debug(f"[Misskey API] 房间消息发送成功: {message_id}")
|
||||
return result
|
||||
|
||||
async def get_messages(
|
||||
@@ -381,9 +781,8 @@ class MisskeyAPI:
|
||||
result = await self._make_request("chat/messages/user-timeline", data)
|
||||
if isinstance(result, list):
|
||||
return result
|
||||
else:
|
||||
logger.warning(f"获取聊天消息响应格式异常: {type(result)}")
|
||||
return []
|
||||
logger.warning(f"[Misskey API] 聊天消息响应格式异常: {type(result)}")
|
||||
return []
|
||||
|
||||
async def get_mentions(
|
||||
self, limit: int = 10, since_id: Optional[str] = None
|
||||
@@ -400,5 +799,142 @@ class MisskeyAPI:
|
||||
elif isinstance(result, dict) and "notifications" in result:
|
||||
return result["notifications"]
|
||||
else:
|
||||
logger.warning(f"获取提及通知响应格式异常: {type(result)}")
|
||||
logger.warning(f"[Misskey API] 提及通知响应格式异常: {type(result)}")
|
||||
return []
|
||||
|
||||
async def send_message_with_media(
|
||||
self,
|
||||
message_type: str,
|
||||
target_id: str,
|
||||
text: Optional[str] = None,
|
||||
media_urls: Optional[List[str]] = None,
|
||||
local_files: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
通用消息发送函数:统一处理文本+媒体发送
|
||||
|
||||
Args:
|
||||
message_type: 消息类型 ('chat', 'room', 'note')
|
||||
target_id: 目标ID (用户ID/房间ID/频道ID等)
|
||||
text: 文本内容
|
||||
media_urls: 媒体文件URL列表
|
||||
local_files: 本地文件路径列表
|
||||
**kwargs: 其他参数(如visibility等)
|
||||
|
||||
Returns:
|
||||
发送结果字典
|
||||
|
||||
Raises:
|
||||
APIError: 参数错误或发送失败
|
||||
"""
|
||||
if not text and not media_urls and not local_files:
|
||||
raise APIError("消息内容不能为空:需要文本或媒体文件")
|
||||
|
||||
file_ids = []
|
||||
|
||||
# 处理远程媒体文件
|
||||
if media_urls:
|
||||
file_ids.extend(await self._process_media_urls(media_urls))
|
||||
|
||||
# 处理本地文件
|
||||
if local_files:
|
||||
file_ids.extend(await self._process_local_files(local_files))
|
||||
|
||||
# 根据消息类型发送
|
||||
return await self._dispatch_message(
|
||||
message_type, target_id, text, file_ids, **kwargs
|
||||
)
|
||||
|
||||
async def _process_media_urls(self, urls: List[str]) -> List[str]:
|
||||
"""处理远程媒体文件URL列表,返回文件ID列表"""
|
||||
file_ids = []
|
||||
for url in urls:
|
||||
try:
|
||||
result = await self.upload_and_find_file(url)
|
||||
if result and result.get("id"):
|
||||
file_ids.append(result["id"])
|
||||
logger.debug(f"[Misskey API] URL媒体上传成功: {result['id']}")
|
||||
else:
|
||||
logger.error(f"[Misskey API] URL媒体上传失败: {url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] URL媒体处理失败 {url}: {e}")
|
||||
# 继续处理其他文件,不中断整个流程
|
||||
continue
|
||||
return file_ids
|
||||
|
||||
async def _process_local_files(self, file_paths: List[str]) -> List[str]:
|
||||
"""处理本地文件路径列表,返回文件ID列表"""
|
||||
file_ids = []
|
||||
for file_path in file_paths:
|
||||
try:
|
||||
result = await self.upload_file(file_path)
|
||||
if result and result.get("id"):
|
||||
file_ids.append(result["id"])
|
||||
logger.debug(f"[Misskey API] 本地文件上传成功: {result['id']}")
|
||||
else:
|
||||
logger.error(f"[Misskey API] 本地文件上传失败: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Misskey API] 本地文件处理失败 {file_path}: {e}")
|
||||
continue
|
||||
return file_ids
|
||||
|
||||
async def _dispatch_message(
|
||||
self,
|
||||
message_type: str,
|
||||
target_id: str,
|
||||
text: Optional[str],
|
||||
file_ids: List[str],
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""根据消息类型分发到对应的发送方法"""
|
||||
if message_type == "chat":
|
||||
# 聊天消息使用 fileId (单数)
|
||||
payload = {"toUserId": target_id}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if file_ids:
|
||||
if len(file_ids) == 1:
|
||||
payload["fileId"] = file_ids[0]
|
||||
else:
|
||||
# 多文件时逐个发送
|
||||
results = []
|
||||
for file_id in file_ids:
|
||||
single_payload = payload.copy()
|
||||
single_payload["fileId"] = file_id
|
||||
result = await self.send_message(single_payload)
|
||||
results.append(result)
|
||||
return {"multiple": True, "results": results}
|
||||
return await self.send_message(payload)
|
||||
|
||||
elif message_type == "room":
|
||||
# 房间消息使用 fileId (单数)
|
||||
payload = {"toRoomId": target_id}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if file_ids:
|
||||
if len(file_ids) == 1:
|
||||
payload["fileId"] = file_ids[0]
|
||||
else:
|
||||
# 多文件时逐个发送
|
||||
results = []
|
||||
for file_id in file_ids:
|
||||
single_payload = payload.copy()
|
||||
single_payload["fileId"] = file_id
|
||||
result = await self.send_room_message(single_payload)
|
||||
results.append(result)
|
||||
return {"multiple": True, "results": results}
|
||||
return await self.send_room_message(payload)
|
||||
|
||||
elif message_type == "note":
|
||||
# 发帖使用 fileIds (复数)
|
||||
note_kwargs = {
|
||||
"text": text,
|
||||
"file_ids": file_ids or None,
|
||||
}
|
||||
# 合并其他参数
|
||||
note_kwargs.update(kwargs)
|
||||
return await self.create_note(**note_kwargs)
|
||||
|
||||
else:
|
||||
raise APIError(f"不支持的消息类型: {message_type}")
|
||||
|
||||
@@ -40,48 +40,83 @@ class MisskeyPlatformEvent(AstrMessageEvent):
|
||||
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
content, has_at = serialize_message_chain(message.chain)
|
||||
|
||||
if not content:
|
||||
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
|
||||
return
|
||||
|
||||
"""发送消息,使用适配器的完整上传和发送逻辑"""
|
||||
try:
|
||||
original_message_id = getattr(self.message_obj, "message_id", None)
|
||||
raw_message = getattr(self.message_obj, "raw_message", {})
|
||||
logger.debug(
|
||||
f"[MisskeyEvent] send 方法被调用,消息链包含 {len(message.chain)} 个组件"
|
||||
)
|
||||
|
||||
if raw_message and not has_at:
|
||||
user_data = raw_message.get("user", {})
|
||||
user_info = {
|
||||
"username": user_data.get("username", ""),
|
||||
"nickname": user_data.get("name", user_data.get("username", "")),
|
||||
}
|
||||
content = add_at_mention_if_needed(content, user_info, has_at)
|
||||
# 使用适配器的 send_by_session 方法,它包含文件上传逻辑
|
||||
from astrbot.core.platform.message_session import MessageSession
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
|
||||
# 根据会话类型选择发送方式
|
||||
if hasattr(self.client, "send_message") and is_valid_user_session_id(
|
||||
self.session_id
|
||||
):
|
||||
user_id = extract_user_id_from_session_id(self.session_id)
|
||||
await self.client.send_message(user_id, content)
|
||||
elif hasattr(self.client, "send_room_message") and is_valid_room_session_id(
|
||||
self.session_id
|
||||
):
|
||||
room_id = extract_room_id_from_session_id(self.session_id)
|
||||
await self.client.send_room_message(room_id, content)
|
||||
elif original_message_id and hasattr(self.client, "create_note"):
|
||||
visibility, visible_user_ids = resolve_visibility_from_raw_message(
|
||||
raw_message
|
||||
)
|
||||
await self.client.create_note(
|
||||
content,
|
||||
reply_id=original_message_id,
|
||||
visibility=visibility,
|
||||
visible_user_ids=visible_user_ids,
|
||||
)
|
||||
elif hasattr(self.client, "create_note"):
|
||||
logger.debug("[MisskeyEvent] 创建新帖子")
|
||||
await self.client.create_note(content)
|
||||
# 根据session_id类型确定消息类型
|
||||
if is_valid_user_session_id(self.session_id):
|
||||
message_type = MessageType.FRIEND_MESSAGE
|
||||
elif is_valid_room_session_id(self.session_id):
|
||||
message_type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
message_type = MessageType.FRIEND_MESSAGE # 默认
|
||||
|
||||
session = MessageSession(
|
||||
platform_name=self.platform_meta.name,
|
||||
message_type=message_type,
|
||||
session_id=self.session_id,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"[MisskeyEvent] 检查适配器方法: hasattr(self.client, 'send_by_session') = {hasattr(self.client, 'send_by_session')}"
|
||||
)
|
||||
|
||||
# 调用适配器的 send_by_session 方法
|
||||
if hasattr(self.client, "send_by_session"):
|
||||
logger.debug("[MisskeyEvent] 调用适配器的 send_by_session 方法")
|
||||
await self.client.send_by_session(session, message)
|
||||
else:
|
||||
# 回退到原来的简化发送逻辑
|
||||
content, has_at = serialize_message_chain(message.chain)
|
||||
|
||||
if not content:
|
||||
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
|
||||
return
|
||||
|
||||
original_message_id = getattr(self.message_obj, "message_id", None)
|
||||
raw_message = getattr(self.message_obj, "raw_message", {})
|
||||
|
||||
if raw_message and not has_at:
|
||||
user_data = raw_message.get("user", {})
|
||||
user_info = {
|
||||
"username": user_data.get("username", ""),
|
||||
"nickname": user_data.get(
|
||||
"name", user_data.get("username", "")
|
||||
),
|
||||
}
|
||||
content = add_at_mention_if_needed(content, user_info, has_at)
|
||||
|
||||
# 根据会话类型选择发送方式
|
||||
if hasattr(self.client, "send_message") and is_valid_user_session_id(
|
||||
self.session_id
|
||||
):
|
||||
user_id = extract_user_id_from_session_id(self.session_id)
|
||||
await self.client.send_message(user_id, content)
|
||||
elif hasattr(
|
||||
self.client, "send_room_message"
|
||||
) and is_valid_room_session_id(self.session_id):
|
||||
room_id = extract_room_id_from_session_id(self.session_id)
|
||||
await self.client.send_room_message(room_id, content)
|
||||
elif original_message_id and hasattr(self.client, "create_note"):
|
||||
visibility, visible_user_ids = resolve_visibility_from_raw_message(
|
||||
raw_message
|
||||
)
|
||||
await self.client.create_note(
|
||||
content,
|
||||
reply_id=original_message_id,
|
||||
visibility=visibility,
|
||||
visible_user_ids=visible_user_ids,
|
||||
)
|
||||
elif hasattr(self.client, "create_note"):
|
||||
logger.debug("[MisskeyEvent] 创建新帖子")
|
||||
await self.client.create_note(content)
|
||||
|
||||
await super().send(message)
|
||||
|
||||
|
||||
@@ -5,6 +5,68 @@ import astrbot.api.message_components as Comp
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
|
||||
|
||||
class FileIDExtractor:
|
||||
"""从 API 响应中提取文件 ID 的帮助类(无状态)。"""
|
||||
|
||||
@staticmethod
|
||||
def extract_file_id(result: Any) -> Optional[str]:
|
||||
if not isinstance(result, dict):
|
||||
return None
|
||||
|
||||
id_paths = [
|
||||
lambda r: r.get("createdFile", {}).get("id"),
|
||||
lambda r: r.get("file", {}).get("id"),
|
||||
lambda r: r.get("id"),
|
||||
]
|
||||
|
||||
for p in id_paths:
|
||||
try:
|
||||
if fid := p(result):
|
||||
return fid
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MessagePayloadBuilder:
|
||||
"""构建不同类型消息负载的帮助类(无状态)。"""
|
||||
|
||||
@staticmethod
|
||||
def build_chat_payload(
|
||||
user_id: str, text: Optional[str], file_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
payload = {"toUserId": user_id}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if file_id:
|
||||
payload["fileId"] = file_id
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def build_room_payload(
|
||||
room_id: str, text: Optional[str], file_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
payload = {"toRoomId": room_id}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if file_id:
|
||||
payload["fileId"] = file_id
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def build_note_payload(
|
||||
text: Optional[str], file_ids: Optional[List[str]] = None, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
payload: Dict[str, Any] = {}
|
||||
if text:
|
||||
payload["text"] = text
|
||||
if file_ids:
|
||||
payload["fileIds"] = file_ids
|
||||
payload |= kwargs
|
||||
return payload
|
||||
|
||||
|
||||
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
|
||||
"""将消息链序列化为文本字符串"""
|
||||
text_parts = []
|
||||
@@ -15,11 +77,19 @@ def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
|
||||
if isinstance(component, Comp.Plain):
|
||||
return component.text
|
||||
elif isinstance(component, Comp.File):
|
||||
file_name = getattr(component, "name", "文件")
|
||||
return f"[文件: {file_name}]"
|
||||
# 为文件组件返回占位符,但适配器仍会处理原组件
|
||||
return "[文件]"
|
||||
elif isinstance(component, Comp.Image):
|
||||
# 为图片组件返回占位符,但适配器仍会处理原组件
|
||||
return "[图片]"
|
||||
elif isinstance(component, Comp.At):
|
||||
has_at = True
|
||||
return f"@{component.qq}"
|
||||
# 优先使用name字段(用户名),如果没有则使用qq字段
|
||||
# 这样可以避免在Misskey中生成 @<user_id> 这样的无效提及
|
||||
if hasattr(component, "name") and component.name:
|
||||
return f"@{component.name}"
|
||||
else:
|
||||
return f"@{component.qq}"
|
||||
elif hasattr(component, "text"):
|
||||
text = getattr(component, "text", "")
|
||||
if "@" in text:
|
||||
@@ -43,15 +113,22 @@ def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
|
||||
|
||||
|
||||
def resolve_message_visibility(
|
||||
user_id: Optional[str],
|
||||
user_cache: Dict[str, Any],
|
||||
self_id: Optional[str],
|
||||
user_id: Optional[str] = None,
|
||||
user_cache: Optional[Dict[str, Any]] = None,
|
||||
self_id: Optional[str] = None,
|
||||
raw_message: Optional[Dict[str, Any]] = None,
|
||||
default_visibility: str = "public",
|
||||
) -> Tuple[str, Optional[List[str]]]:
|
||||
"""解析 Misskey 消息的可见性设置"""
|
||||
"""解析 Misskey 消息的可见性设置
|
||||
|
||||
可以从 user_cache 或 raw_message 中解析,支持两种调用方式:
|
||||
1. 基于 user_cache: resolve_message_visibility(user_id, user_cache, self_id)
|
||||
2. 基于 raw_message: resolve_message_visibility(raw_message=raw_message, self_id=self_id)
|
||||
"""
|
||||
visibility = default_visibility
|
||||
visible_user_ids = None
|
||||
|
||||
# 优先从 user_cache 解析
|
||||
if user_id and user_cache:
|
||||
user_info = user_cache.get(user_id)
|
||||
if user_info:
|
||||
@@ -66,38 +143,36 @@ def resolve_message_visibility(
|
||||
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||
else:
|
||||
visibility = original_visibility
|
||||
return visibility, visible_user_ids
|
||||
|
||||
# 回退到从 raw_message 解析
|
||||
if raw_message:
|
||||
original_visibility = raw_message.get("visibility", default_visibility)
|
||||
if original_visibility == "specified":
|
||||
visibility = "specified"
|
||||
original_visible_users = raw_message.get("visibleUserIds", [])
|
||||
sender_id = raw_message.get("userId", "")
|
||||
|
||||
users_to_include = []
|
||||
if sender_id:
|
||||
users_to_include.append(sender_id)
|
||||
if self_id:
|
||||
users_to_include.append(self_id)
|
||||
|
||||
visible_user_ids = list(set(original_visible_users + users_to_include))
|
||||
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||
else:
|
||||
visibility = original_visibility
|
||||
|
||||
return visibility, visible_user_ids
|
||||
|
||||
|
||||
# 保留旧函数名作为向后兼容的别名
|
||||
def resolve_visibility_from_raw_message(
|
||||
raw_message: Dict[str, Any], self_id: Optional[str] = None
|
||||
) -> Tuple[str, Optional[List[str]]]:
|
||||
"""从原始消息数据中解析可见性设置"""
|
||||
visibility = "public"
|
||||
visible_user_ids = None
|
||||
|
||||
if not raw_message:
|
||||
return visibility, visible_user_ids
|
||||
|
||||
original_visibility = raw_message.get("visibility", "public")
|
||||
if original_visibility == "specified":
|
||||
visibility = "specified"
|
||||
original_visible_users = raw_message.get("visibleUserIds", [])
|
||||
sender_id = raw_message.get("userId", "")
|
||||
|
||||
users_to_include = []
|
||||
if sender_id:
|
||||
users_to_include.append(sender_id)
|
||||
if self_id:
|
||||
users_to_include.append(self_id)
|
||||
|
||||
visible_user_ids = list(set(original_visible_users + users_to_include))
|
||||
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||
else:
|
||||
visibility = original_visibility
|
||||
|
||||
return visibility, visible_user_ids
|
||||
"""从原始消息数据中解析可见性设置(已弃用,使用 resolve_message_visibility 替代)"""
|
||||
return resolve_message_visibility(raw_message=raw_message, self_id=self_id)
|
||||
|
||||
|
||||
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
|
||||
@@ -128,6 +203,20 @@ def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def is_valid_chat_session_id(session_id: Union[str, Any]) -> bool:
|
||||
"""检查 session_id 是否是有效的聊天 session_id (仅限chat%前缀)"""
|
||||
if not isinstance(session_id, str) or "%" not in session_id:
|
||||
return False
|
||||
|
||||
parts = session_id.split("%")
|
||||
return (
|
||||
len(parts) == 2
|
||||
and parts[0] == "chat"
|
||||
and bool(parts[1])
|
||||
and parts[1] != "unknown"
|
||||
)
|
||||
|
||||
|
||||
def extract_user_id_from_session_id(session_id: str) -> str:
|
||||
"""从 session_id 中提取用户 ID"""
|
||||
if "%" in session_id:
|
||||
@@ -149,21 +238,22 @@ def extract_room_id_from_session_id(session_id: str) -> str:
|
||||
def add_at_mention_if_needed(
|
||||
text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
|
||||
) -> str:
|
||||
"""如果需要且没有@用户,则添加@用户"""
|
||||
"""如果需要且没有@用户,则添加@用户
|
||||
|
||||
注意:仅在有有效的username时才添加@提及,避免使用用户ID
|
||||
"""
|
||||
if has_at or not user_info:
|
||||
return text
|
||||
|
||||
username = user_info.get("username")
|
||||
nickname = user_info.get("nickname")
|
||||
# 如果没有username,则不添加@提及,返回原文本
|
||||
# 这样可以避免生成 @<user_id> 这样的无效提及
|
||||
if not username:
|
||||
return text
|
||||
|
||||
if username:
|
||||
mention = f"@{username}"
|
||||
if not text.startswith(mention):
|
||||
text = f"{mention}\n{text}".strip()
|
||||
elif nickname:
|
||||
mention = f"@{nickname}"
|
||||
if not text.startswith(mention):
|
||||
text = f"{mention}\n{text}".strip()
|
||||
mention = f"@{username}"
|
||||
if not text.startswith(mention):
|
||||
text = f"{mention}\n{text}".strip()
|
||||
|
||||
return text
|
||||
|
||||
@@ -197,6 +287,22 @@ def process_files(
|
||||
return file_parts
|
||||
|
||||
|
||||
def format_poll(poll: Dict[str, Any]) -> str:
|
||||
"""将 Misskey 的 poll 对象格式化为可读字符串。"""
|
||||
if not poll or not isinstance(poll, dict):
|
||||
return ""
|
||||
multiple = poll.get("multiple", False)
|
||||
choices = poll.get("choices", [])
|
||||
text_choices = [
|
||||
f"({idx}) {c.get('text', '')} [{c.get('votes', 0)}票]"
|
||||
for idx, c in enumerate(choices, start=1)
|
||||
]
|
||||
parts = ["[投票]", ("允许多选" if multiple else "单选")] + (
|
||||
["选项: " + ", ".join(text_choices)] if text_choices else []
|
||||
)
|
||||
return " ".join(parts)
|
||||
|
||||
|
||||
def extract_sender_info(
|
||||
raw_data: Dict[str, Any], is_chat: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
@@ -248,7 +354,7 @@ def create_base_message(
|
||||
else:
|
||||
session_prefix = "note"
|
||||
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
||||
message.type = MessageType.FRIEND_MESSAGE
|
||||
message.type = MessageType.OTHER_MESSAGE
|
||||
|
||||
message.session_id = (
|
||||
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
|
||||
@@ -303,6 +409,8 @@ def cache_user_info(
|
||||
"nickname": sender_info["nickname"],
|
||||
"visibility": raw_data.get("visibility", "public"),
|
||||
"visible_user_ids": raw_data.get("visibleUserIds", []),
|
||||
# 保存原消息ID,用于回复时作为reply_id
|
||||
"reply_to_note_id": raw_data.get("id"),
|
||||
}
|
||||
|
||||
user_cache[sender_info["sender_id"]] = user_cache_data
|
||||
@@ -325,3 +433,106 @@ def cache_room_info(
|
||||
"visibility": "specified",
|
||||
"visible_user_ids": [client_self_id],
|
||||
}
|
||||
|
||||
|
||||
async def resolve_component_url_or_path(
|
||||
comp: Any,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""尝试从组件解析可上传的远程 URL 或本地路径。
|
||||
|
||||
返回 (url_candidate, local_path)。两者可能都为 None。
|
||||
这个函数尽量不抛异常,调用方可按需处理 None。
|
||||
"""
|
||||
url_candidate = None
|
||||
local_path = None
|
||||
|
||||
async def _get_str_value(coro_or_val):
|
||||
"""辅助函数:统一处理协程或普通值"""
|
||||
try:
|
||||
if hasattr(coro_or_val, "__await__"):
|
||||
result = await coro_or_val
|
||||
else:
|
||||
result = coro_or_val
|
||||
return result if isinstance(result, str) else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1. 尝试异步方法
|
||||
for method in ["convert_to_file_path", "get_file", "register_to_file_service"]:
|
||||
if not hasattr(comp, method):
|
||||
continue
|
||||
try:
|
||||
value = await _get_str_value(getattr(comp, method)())
|
||||
if value:
|
||||
if value.startswith("http"):
|
||||
url_candidate = value
|
||||
break
|
||||
else:
|
||||
local_path = value
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2. 尝试 get_file(True) 获取可直接访问的 URL
|
||||
if not url_candidate and hasattr(comp, "get_file"):
|
||||
try:
|
||||
value = await _get_str_value(comp.get_file(True))
|
||||
if value and value.startswith("http"):
|
||||
url_candidate = value
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 3. 回退到同步属性
|
||||
if not url_candidate and not local_path:
|
||||
for attr in ("file", "url", "path", "src", "source"):
|
||||
try:
|
||||
value = getattr(comp, attr, None)
|
||||
if value and isinstance(value, str):
|
||||
if value.startswith("http"):
|
||||
url_candidate = value
|
||||
break
|
||||
else:
|
||||
local_path = value
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return url_candidate, local_path
|
||||
|
||||
|
||||
def summarize_component_for_log(comp: Any) -> Dict[str, Any]:
|
||||
"""生成适合日志的组件属性字典(尽量不抛异常)。"""
|
||||
attrs = {}
|
||||
for a in ("file", "url", "path", "src", "source", "name"):
|
||||
try:
|
||||
v = getattr(comp, a, None)
|
||||
if v is not None:
|
||||
attrs[a] = v
|
||||
except Exception:
|
||||
continue
|
||||
return attrs
|
||||
|
||||
|
||||
async def upload_local_with_retries(
|
||||
api: Any,
|
||||
local_path: str,
|
||||
preferred_name: Optional[str],
|
||||
folder_id: Optional[str],
|
||||
) -> Optional[str]:
|
||||
"""尝试本地上传,返回 file id 或 None。如果文件类型不允许则直接失败。"""
|
||||
try:
|
||||
res = await api.upload_file(local_path, preferred_name, folder_id)
|
||||
if isinstance(res, dict):
|
||||
fid = res.get("id") or (res.get("raw") or {}).get("createdFile", {}).get(
|
||||
"id"
|
||||
)
|
||||
if fid:
|
||||
return str(fid)
|
||||
except Exception:
|
||||
# 上传失败,直接返回 None,让上层处理错误
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
@@ -15,12 +15,13 @@ class QQOfficialWebhook:
|
||||
self.appid = config["appid"]
|
||||
self.secret = config["secret"]
|
||||
self.port = config.get("port", 6196)
|
||||
self.is_sandbox = config.get("is_sandbox", False)
|
||||
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
|
||||
|
||||
if isinstance(self.port, str):
|
||||
self.port = int(self.port)
|
||||
|
||||
self.http: BotHttp = BotHttp(timeout=300)
|
||||
self.http: BotHttp = BotHttp(timeout=300, is_sandbox=self.is_sandbox)
|
||||
self.api: BotAPI = BotAPI(http=self.http)
|
||||
self.token = Token(self.appid, self.secret)
|
||||
|
||||
|
||||
@@ -499,10 +499,36 @@ class SatoriPlatformAdapter(Platform):
|
||||
}
|
||||
|
||||
return None
|
||||
except ET.ParseError as e:
|
||||
logger.warning(f"XML解析失败,使用正则提取: {e}")
|
||||
return await self._extract_quote_with_regex(content)
|
||||
except Exception as e:
|
||||
logger.error(f"提取<quote>标签时发生错误: {e}")
|
||||
return None
|
||||
|
||||
async def _extract_quote_with_regex(self, content: str) -> Optional[dict]:
|
||||
"""使用正则表达式提取quote标签信息"""
|
||||
import re
|
||||
|
||||
quote_pattern = r"<quote\s+([^>]*)>(.*?)</quote>"
|
||||
match = re.search(quote_pattern, content, re.DOTALL)
|
||||
|
||||
if not match:
|
||||
return None
|
||||
|
||||
attrs_str = match.group(1)
|
||||
inner_content = match.group(2)
|
||||
|
||||
id_match = re.search(r'id\s*=\s*["\']([^"\']*)["\']', attrs_str)
|
||||
quote_id = id_match.group(1) if id_match else ""
|
||||
content_without_quote = content.replace(match.group(0), "")
|
||||
content_without_quote = content_without_quote.strip()
|
||||
|
||||
return {
|
||||
"quote": {"id": quote_id, "content": inner_content},
|
||||
"content_without_quote": content_without_quote,
|
||||
}
|
||||
|
||||
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
|
||||
"""转换引用消息"""
|
||||
try:
|
||||
@@ -574,7 +600,7 @@ class SatoriPlatformAdapter(Platform):
|
||||
root = ET.fromstring(processed_content)
|
||||
await self._parse_xml_node(root, elements)
|
||||
except ET.ParseError as e:
|
||||
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
|
||||
logger.warning(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
|
||||
# 如果解析失败,将整个内容当作纯文本
|
||||
if content.strip():
|
||||
elements.append(Plain(text=content))
|
||||
|
||||
@@ -2,7 +2,18 @@ from typing import TYPE_CHECKING
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, At, File, Record, Video, Reply
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
At,
|
||||
File,
|
||||
Record,
|
||||
Video,
|
||||
Reply,
|
||||
Forward,
|
||||
Node,
|
||||
Nodes,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .satori_adapter import SatoriPlatformAdapter
|
||||
@@ -48,55 +59,24 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
||||
content_parts = []
|
||||
|
||||
for component in message.chain:
|
||||
if isinstance(component, Plain):
|
||||
text = (
|
||||
component.text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
)
|
||||
content_parts.append(text)
|
||||
component_content = await cls._convert_component_to_satori_static(
|
||||
component
|
||||
)
|
||||
if component_content:
|
||||
content_parts.append(component_content)
|
||||
|
||||
elif isinstance(component, At):
|
||||
if component.qq:
|
||||
content_parts.append(f'<at id="{component.qq}"/>')
|
||||
elif component.name:
|
||||
content_parts.append(f'<at name="{component.name}"/>')
|
||||
# 特殊处理 Node 和 Nodes 组件
|
||||
if isinstance(component, Node):
|
||||
# 单个转发节点
|
||||
node_content = await cls._convert_node_to_satori_static(component)
|
||||
if node_content:
|
||||
content_parts.append(node_content)
|
||||
|
||||
elif isinstance(component, Image):
|
||||
try:
|
||||
image_base64 = await component.convert_to_base64()
|
||||
if image_base64:
|
||||
content_parts.append(
|
||||
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, File):
|
||||
content_parts.append(
|
||||
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
||||
)
|
||||
|
||||
elif isinstance(component, Record):
|
||||
try:
|
||||
record_base64 = await component.convert_to_base64()
|
||||
if record_base64:
|
||||
content_parts.append(
|
||||
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"语音转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, Reply):
|
||||
content_parts.append(f'<reply id="{component.id}"/>')
|
||||
|
||||
elif isinstance(component, Video):
|
||||
try:
|
||||
video_path_url = await component.convert_to_file_path()
|
||||
if video_path_url:
|
||||
content_parts.append(f'<video src="{video_path_url}"/>')
|
||||
except Exception as e:
|
||||
logger.error(f"视频文件转换失败: {e}")
|
||||
elif isinstance(component, Nodes):
|
||||
# 合并转发消息
|
||||
node_content = await cls._convert_nodes_to_satori_static(component)
|
||||
if node_content:
|
||||
content_parts.append(node_content)
|
||||
|
||||
content = "".join(content_parts)
|
||||
channel_id = session_id
|
||||
@@ -138,55 +118,22 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
||||
content_parts = []
|
||||
|
||||
for component in message.chain:
|
||||
if isinstance(component, Plain):
|
||||
text = (
|
||||
component.text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
)
|
||||
content_parts.append(text)
|
||||
component_content = await self._convert_component_to_satori(component)
|
||||
if component_content:
|
||||
content_parts.append(component_content)
|
||||
|
||||
elif isinstance(component, At):
|
||||
if component.qq:
|
||||
content_parts.append(f'<at id="{component.qq}"/>')
|
||||
elif component.name:
|
||||
content_parts.append(f'<at name="{component.name}"/>')
|
||||
# 特殊处理 Node 和 Nodes 组件
|
||||
if isinstance(component, Node):
|
||||
# 单个转发节点
|
||||
node_content = await self._convert_node_to_satori(component)
|
||||
if node_content:
|
||||
content_parts.append(node_content)
|
||||
|
||||
elif isinstance(component, Image):
|
||||
try:
|
||||
image_base64 = await component.convert_to_base64()
|
||||
if image_base64:
|
||||
content_parts.append(
|
||||
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, File):
|
||||
content_parts.append(
|
||||
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
||||
)
|
||||
|
||||
elif isinstance(component, Record):
|
||||
try:
|
||||
record_base64 = await component.convert_to_base64()
|
||||
if record_base64:
|
||||
content_parts.append(
|
||||
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"语音转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, Reply):
|
||||
content_parts.append(f'<reply id="{component.id}"/>')
|
||||
|
||||
elif isinstance(component, Video):
|
||||
try:
|
||||
video_path_url = await component.convert_to_file_path()
|
||||
if video_path_url:
|
||||
content_parts.append(f'<video src="{video_path_url}"/>')
|
||||
except Exception as e:
|
||||
logger.error(f"视频文件转换失败: {e}")
|
||||
elif isinstance(component, Nodes):
|
||||
# 合并转发消息
|
||||
node_content = await self._convert_nodes_to_satori(component)
|
||||
if node_content:
|
||||
content_parts.append(node_content)
|
||||
|
||||
content = "".join(content_parts)
|
||||
channel_id = self.session_id
|
||||
@@ -250,3 +197,227 @@ class SatoriPlatformEvent(AstrMessageEvent):
|
||||
logger.error(f"Satori 流式消息发送异常: {e}")
|
||||
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _convert_component_to_satori(self, component) -> str:
|
||||
"""将单个消息组件转换为 Satori 格式"""
|
||||
try:
|
||||
if isinstance(component, Plain):
|
||||
text = (
|
||||
component.text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
)
|
||||
return text
|
||||
|
||||
elif isinstance(component, At):
|
||||
if component.qq:
|
||||
return f'<at id="{component.qq}"/>'
|
||||
elif component.name:
|
||||
return f'<at name="{component.name}"/>'
|
||||
|
||||
elif isinstance(component, Image):
|
||||
try:
|
||||
image_base64 = await component.convert_to_base64()
|
||||
if image_base64:
|
||||
return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, File):
|
||||
return (
|
||||
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
||||
)
|
||||
|
||||
elif isinstance(component, Record):
|
||||
try:
|
||||
record_base64 = await component.convert_to_base64()
|
||||
if record_base64:
|
||||
return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"语音转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, Reply):
|
||||
return f'<reply id="{component.id}"/>'
|
||||
|
||||
elif isinstance(component, Video):
|
||||
try:
|
||||
video_path_url = await component.convert_to_file_path()
|
||||
if video_path_url:
|
||||
return f'<video src="{video_path_url}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"视频文件转换失败: {e}")
|
||||
|
||||
elif isinstance(component, Forward):
|
||||
return f'<message id="{component.id}" forward/>'
|
||||
|
||||
# 对于其他未处理的组件类型,返回空字符串
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换消息组件失败: {e}")
|
||||
return ""
|
||||
|
||||
async def _convert_node_to_satori(self, node: Node) -> str:
|
||||
"""将单个转发节点转换为 Satori 格式"""
|
||||
try:
|
||||
content_parts = []
|
||||
if node.content:
|
||||
for content_component in node.content:
|
||||
component_content = await self._convert_component_to_satori(
|
||||
content_component
|
||||
)
|
||||
if component_content:
|
||||
content_parts.append(component_content)
|
||||
|
||||
content = "".join(content_parts)
|
||||
|
||||
# 如果内容为空,添加默认内容
|
||||
if not content.strip():
|
||||
content = "[转发消息]"
|
||||
|
||||
# 构建 Satori 格式的转发节点
|
||||
author_attrs = []
|
||||
if node.uin:
|
||||
author_attrs.append(f'id="{node.uin}"')
|
||||
if node.name:
|
||||
author_attrs.append(f'name="{node.name}"')
|
||||
|
||||
author_attr_str = " ".join(author_attrs)
|
||||
|
||||
return f"<message><author {author_attr_str}/>{content}</message>"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换转发节点失败: {e}")
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
async def _convert_component_to_satori_static(cls, component) -> str:
|
||||
"""将单个消息组件转换为 Satori 格式"""
|
||||
try:
|
||||
if isinstance(component, Plain):
|
||||
text = (
|
||||
component.text.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
)
|
||||
return text
|
||||
|
||||
elif isinstance(component, At):
|
||||
if component.qq:
|
||||
return f'<at id="{component.qq}"/>'
|
||||
elif component.name:
|
||||
return f'<at name="{component.name}"/>'
|
||||
|
||||
elif isinstance(component, Image):
|
||||
try:
|
||||
image_base64 = await component.convert_to_base64()
|
||||
if image_base64:
|
||||
return f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"图片转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, File):
|
||||
return (
|
||||
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
||||
)
|
||||
|
||||
elif isinstance(component, Record):
|
||||
try:
|
||||
record_base64 = await component.convert_to_base64()
|
||||
if record_base64:
|
||||
return f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"语音转换为base64失败: {e}")
|
||||
|
||||
elif isinstance(component, Reply):
|
||||
return f'<reply id="{component.id}"/>'
|
||||
|
||||
elif isinstance(component, Video):
|
||||
try:
|
||||
video_path_url = await component.convert_to_file_path()
|
||||
if video_path_url:
|
||||
return f'<video src="{video_path_url}"/>'
|
||||
except Exception as e:
|
||||
logger.error(f"视频文件转换失败: {e}")
|
||||
|
||||
elif isinstance(component, Forward):
|
||||
return f'<message id="{component.id}" forward/>'
|
||||
|
||||
# 对于其他未处理的组件类型,返回空字符串
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换消息组件失败: {e}")
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
async def _convert_node_to_satori_static(cls, node: Node) -> str:
|
||||
"""将单个转发节点转换为 Satori 格式"""
|
||||
try:
|
||||
content_parts = []
|
||||
if node.content:
|
||||
for content_component in node.content:
|
||||
component_content = await cls._convert_component_to_satori_static(
|
||||
content_component
|
||||
)
|
||||
if component_content:
|
||||
content_parts.append(component_content)
|
||||
|
||||
content = "".join(content_parts)
|
||||
|
||||
# 如果内容为空,添加默认内容
|
||||
if not content.strip():
|
||||
content = "[转发消息]"
|
||||
|
||||
author_attrs = []
|
||||
if node.uin:
|
||||
author_attrs.append(f'id="{node.uin}"')
|
||||
if node.name:
|
||||
author_attrs.append(f'name="{node.name}"')
|
||||
|
||||
author_attr_str = " ".join(author_attrs)
|
||||
|
||||
return f"<message><author {author_attr_str}/>{content}</message>"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换转发节点失败: {e}")
|
||||
return ""
|
||||
|
||||
async def _convert_nodes_to_satori(self, nodes: Nodes) -> str:
|
||||
"""将多个转发节点转换为 Satori 格式的合并转发"""
|
||||
try:
|
||||
node_parts = []
|
||||
|
||||
for node in nodes.nodes:
|
||||
node_content = await self._convert_node_to_satori(node)
|
||||
if node_content:
|
||||
node_parts.append(node_content)
|
||||
|
||||
if node_parts:
|
||||
return f"<message forward>{''.join(node_parts)}</message>"
|
||||
else:
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换合并转发消息失败: {e}")
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
async def _convert_nodes_to_satori_static(cls, nodes: Nodes) -> str:
|
||||
"""将多个转发节点转换为 Satori 格式的合并转发"""
|
||||
try:
|
||||
node_parts = []
|
||||
|
||||
for node in nodes.nodes:
|
||||
node_content = await cls._convert_node_to_satori_static(node)
|
||||
if node_content:
|
||||
node_parts.append(node_content)
|
||||
|
||||
if node_parts:
|
||||
return f"<message forward>{''.join(node_parts)}</message>"
|
||||
else:
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"转换合并转发消息失败: {e}")
|
||||
return ""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
import asyncio
|
||||
from typing import List
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
@@ -203,6 +204,72 @@ class EmbeddingProvider(AbstractProvider):
|
||||
"""获取向量的维度"""
|
||||
...
|
||||
|
||||
async def get_embeddings_batch(
|
||||
self,
|
||||
texts: list[str],
|
||||
batch_size: int = 16,
|
||||
tasks_limit: int = 3,
|
||||
max_retries: int = 3,
|
||||
progress_callback=None,
|
||||
) -> list[list[float]]:
|
||||
"""批量获取文本的向量,分批处理以节省内存
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
batch_size: 每批处理的文本数量
|
||||
tasks_limit: 并发任务数量限制
|
||||
max_retries: 失败时的最大重试次数
|
||||
progress_callback: 进度回调函数,接收参数 (current, total)
|
||||
|
||||
Returns:
|
||||
向量列表
|
||||
"""
|
||||
semaphore = asyncio.Semaphore(tasks_limit)
|
||||
all_embeddings: list[list[float]] = []
|
||||
failed_batches: list[tuple[int, list[str]]] = []
|
||||
completed_count = 0
|
||||
total_count = len(texts)
|
||||
|
||||
async def process_batch(batch_idx: int, batch_texts: list[str]):
|
||||
nonlocal completed_count
|
||||
async with semaphore:
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
batch_embeddings = await self.get_embeddings(batch_texts)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
completed_count += len(batch_texts)
|
||||
if progress_callback:
|
||||
await progress_callback(completed_count, total_count)
|
||||
return
|
||||
except Exception as e:
|
||||
if attempt == max_retries - 1:
|
||||
# 最后一次重试失败,记录失败的批次
|
||||
failed_batches.append((batch_idx, batch_texts))
|
||||
raise Exception(
|
||||
f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {str(e)}"
|
||||
)
|
||||
# 等待一段时间后重试,使用指数退避
|
||||
await asyncio.sleep(2**attempt)
|
||||
|
||||
tasks = []
|
||||
for i in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[i : i + batch_size]
|
||||
batch_idx = i // batch_size
|
||||
tasks.append(process_batch(batch_idx, batch_texts))
|
||||
|
||||
# 收集所有任务的结果,包括失败的任务
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 检查是否有失败的任务
|
||||
errors = [r for r in results if isinstance(r, Exception)]
|
||||
if errors:
|
||||
error_msg = (
|
||||
f"有 {len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}"
|
||||
)
|
||||
raise Exception(error_msg)
|
||||
|
||||
return all_embeddings
|
||||
|
||||
|
||||
class RerankProvider(AbstractProvider):
|
||||
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
|
||||
|
||||
@@ -10,7 +10,7 @@ from anthropic.types import Message
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from typing import AsyncGenerator
|
||||
@@ -104,7 +104,7 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
payloads["tools"] = tool_list
|
||||
@@ -135,7 +135,7 @@ class ProviderAnthropic(Provider):
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
self, payloads: dict, tools: ToolSet | None
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
@@ -326,7 +326,7 @@ class ProviderAnthropic(Provider):
|
||||
async for llm_response in self._query_stream(payloads, func_tool):
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
async def assemble_context(self, text: str, image_urls: List[str] | None = None):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
if not image_urls:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
import re
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
from astrbot.core import logger, sp
|
||||
from dashscope import Application
|
||||
from dashscope.app.application_response import ApplicationResponse
|
||||
|
||||
|
||||
@register_provider_adapter("dashscope", "Dashscope APP 适配器。")
|
||||
@@ -62,11 +61,11 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
session_id=None,
|
||||
image_urls=[],
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
@@ -122,6 +121,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
|
||||
assert isinstance(response, ApplicationResponse)
|
||||
|
||||
logger.debug(f"dashscope resp: {response}")
|
||||
|
||||
if response.status_code != 200:
|
||||
@@ -135,12 +136,12 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
),
|
||||
)
|
||||
|
||||
output_text = response.output.get("text", "")
|
||||
output_text = response.output.get("text", "") or ""
|
||||
# RAG 引用脚标格式化
|
||||
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
|
||||
if self.output_reference and response.output.get("doc_references", None):
|
||||
ref_str = ""
|
||||
for ref in response.output.get("doc_references", []):
|
||||
for ref in response.output.get("doc_references", []) or []:
|
||||
ref_title = (
|
||||
ref.get("title", "")
|
||||
if ref.get("title")
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import astrbot.core.message.components as Comp
|
||||
import os
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.utils.dify_api_client import DifyAPIClient
|
||||
from astrbot.core.utils.io import download_image_by_url, download_file
|
||||
@@ -55,11 +53,11 @@ class ProviderDify(Provider):
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
**kwargs,
|
||||
@@ -223,7 +221,7 @@ class ProviderDify(Provider):
|
||||
# Chat
|
||||
return MessageChain(chain=[Comp.Plain(chunk)])
|
||||
|
||||
async def parse_file(item: dict) -> Comp:
|
||||
async def parse_file(item: dict):
|
||||
match item["type"]:
|
||||
case "image":
|
||||
return Comp.Image(file=item["url"], url=item["url"])
|
||||
|
||||
@@ -16,7 +16,7 @@ from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from typing import List, AsyncGenerator
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
@@ -49,7 +49,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.client = AsyncAzureOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
api_version=provider_config.get("api_version", None),
|
||||
base_url=provider_config.get("api_base", None),
|
||||
base_url=provider_config.get("api_base", ""),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
else:
|
||||
@@ -79,7 +79,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
except NotFoundError as e:
|
||||
raise Exception(f"获取模型列表失败:{e}")
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
async def _query(self, payloads: dict, tools: ToolSet) -> LLMResponse:
|
||||
if tools:
|
||||
model = payloads.get("model", "").lower()
|
||||
omit_empty_param_field = "gemini" in model
|
||||
@@ -126,7 +126,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
self, payloads: dict, tools: ToolSet
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式查询API,逐步返回结果"""
|
||||
if tools:
|
||||
@@ -183,9 +183,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
yield llm_response
|
||||
|
||||
async def parse_openai_completion(
|
||||
self, completion: ChatCompletion, tools: FuncCall
|
||||
):
|
||||
async def parse_openai_completion(self, completion: ChatCompletion, tools: ToolSet):
|
||||
"""解析 OpenAI 的 ChatCompletion 响应"""
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
@@ -208,7 +206,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# workaround for #1359
|
||||
tool_call = json.loads(tool_call)
|
||||
for tool in tools.func_list:
|
||||
if tool.name == tool_call.function.name:
|
||||
if (
|
||||
tool_call.type == "function"
|
||||
and tool.name == tool_call.function.name
|
||||
):
|
||||
# workaround for #1454
|
||||
if isinstance(tool_call.function.arguments, str):
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
@@ -277,7 +278,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
e: Exception,
|
||||
payloads: dict,
|
||||
context_query: list,
|
||||
func_tool: FuncCall,
|
||||
func_tool: ToolSet,
|
||||
chosen_key: str,
|
||||
available_api_keys: List[str],
|
||||
retry_cnt: int,
|
||||
@@ -420,7 +421,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if success:
|
||||
break
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
if retry_cnt == max_retries - 1 or llm_response is None:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
if last_exception is None:
|
||||
raise Exception("未知错误")
|
||||
@@ -430,10 +431,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
session_id=None,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
@@ -526,7 +527,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
def set_key(self, key):
|
||||
self.client.api_key = key
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
|
||||
async def assemble_context(
|
||||
self, text: str, image_urls: List[str] | None = None
|
||||
) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
if image_urls:
|
||||
user_content = {
|
||||
|
||||
@@ -30,7 +30,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model", None))
|
||||
self.set_model(provider_config.get("model", ""))
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
|
||||
@@ -19,7 +19,7 @@ from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
from astrbot.core.knowledge_base.kb_manager_lifecycle import KnowledgeBaseManager
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
|
||||
|
||||
class UmopConfigRouter:
|
||||
"""UMOP 配置路由器"""
|
||||
|
||||
def __init__(self, sp: SharedPreferences):
|
||||
self.umop_to_conf_id: dict[str, str] = {}
|
||||
"""UMOP 到配置文件 ID 的映射"""
|
||||
self.sp = sp
|
||||
|
||||
self._load_routing_table()
|
||||
|
||||
def _load_routing_table(self):
|
||||
"""加载路由表"""
|
||||
# 从 SharedPreferences 中加载 umop_to_conf_id 映射
|
||||
sp_data = self.sp.get(
|
||||
"umop_config_routing", {}, scope="global", scope_id="global"
|
||||
)
|
||||
self.umop_to_conf_id = sp_data
|
||||
|
||||
def _is_umo_match(self, p1: str, p2: str) -> bool:
|
||||
"""判断 p2 umo 是否逻辑包含于 p1 umo"""
|
||||
p1_ls = p1.split(":")
|
||||
p2_ls = p2.split(":")
|
||||
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
|
||||
def get_conf_id_for_umop(self, umo: str) -> str | None:
|
||||
"""根据 UMO 获取对应的配置文件 ID
|
||||
|
||||
Args:
|
||||
umo (str): UMO 字符串
|
||||
|
||||
Returns:
|
||||
str | None: 配置文件 ID,如果没有找到则返回 None
|
||||
"""
|
||||
for pattern, conf_id in self.umop_to_conf_id.items():
|
||||
if self._is_umo_match(pattern, umo):
|
||||
return conf_id
|
||||
return None
|
||||
|
||||
async def update_routing_data(self, new_routing: dict[str, str]):
|
||||
"""更新路由表
|
||||
|
||||
Args:
|
||||
new_routing (dict[str, str]): 新的 UMOP 到配置文件 ID 的映射。umo 由三个部分组成 [platform_id]:[message_type]:[session_id]。
|
||||
umop 可以是 "::" (代表所有), 可以是 "[platform_id]::" (代表指定平台下的所有类型消息和会话)。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 new_routing 中的 key 格式不正确
|
||||
"""
|
||||
for part in new_routing.keys():
|
||||
if not isinstance(part, str) or len(part.split(":")) != 3:
|
||||
raise ValueError(
|
||||
"umop keys must be strings in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all"
|
||||
)
|
||||
|
||||
self.umop_to_conf_id = new_routing
|
||||
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
|
||||
|
||||
async def update_route(self, umo: str, conf_id: str):
|
||||
"""更新一条路由
|
||||
|
||||
Args:
|
||||
umo (str): UMO 字符串
|
||||
conf_id (str): 配置文件 ID
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 umo 格式不正确
|
||||
"""
|
||||
if not isinstance(umo, str) or len(umo.split(":")) != 3:
|
||||
raise ValueError(
|
||||
"umop must be a string in the format [platform_id]:[message_type]:[session_id], with optional wildcards * or empty for all"
|
||||
)
|
||||
|
||||
self.umop_to_conf_id[umo] = conf_id
|
||||
await self.sp.global_put("umop_config_routing", self.umop_to_conf_id)
|
||||
@@ -6,6 +6,7 @@ from .route import Route, Response, RouteContext
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from quart import request
|
||||
from astrbot.core.config.default import (
|
||||
DEFAULT_CONFIG,
|
||||
CONFIG_METADATA_2,
|
||||
DEFAULT_VALUE_MAP,
|
||||
CONFIG_METADATA_3,
|
||||
@@ -152,13 +153,19 @@ class ConfigRoute(Route):
|
||||
self.config: AstrBotConfig = core_lifecycle.astrbot_config
|
||||
self._logo_token_cache = {} # 缓存logo token,避免重复注册
|
||||
self.acm = core_lifecycle.astrbot_config_mgr
|
||||
self.ucr = core_lifecycle.umop_config_router
|
||||
self.routes = {
|
||||
"/config/abconf/new": ("POST", self.create_abconf),
|
||||
"/config/abconf": ("GET", self.get_abconf),
|
||||
"/config/abconfs": ("GET", self.get_abconf_list),
|
||||
"/config/abconf/delete": ("POST", self.delete_abconf),
|
||||
"/config/abconf/update": ("POST", self.update_abconf),
|
||||
"/config/umo_abconf_routes": ("GET", self.get_uc_table),
|
||||
"/config/umo_abconf_route/update_all": ("POST", self.update_ucr_all),
|
||||
"/config/umo_abconf_route/update": ("POST", self.update_ucr),
|
||||
"/config/umo_abconf_route/delete": ("POST", self.delete_ucr),
|
||||
"/config/get": ("GET", self.get_configs),
|
||||
"/config/default": ("GET", self.get_default_config),
|
||||
"/config/astrbot/update": ("POST", self.post_astrbot_configs),
|
||||
"/config/plugin/update": ("POST", self.post_plugin_configs),
|
||||
"/config/platform/new": ("POST", self.post_new_platform),
|
||||
@@ -174,6 +181,75 @@ class ConfigRoute(Route):
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
async def get_uc_table(self):
|
||||
"""获取 UMOP 配置路由表"""
|
||||
return Response().ok({"routing": self.ucr.umop_to_conf_id}).__dict__
|
||||
|
||||
async def update_ucr_all(self):
|
||||
"""更新 UMOP 配置路由表的全部内容"""
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
|
||||
new_routing = post_data.get("routing", None)
|
||||
|
||||
if not new_routing or not isinstance(new_routing, dict):
|
||||
return Response().error("缺少或错误的路由表数据").__dict__
|
||||
|
||||
try:
|
||||
await self.ucr.update_routing_data(new_routing)
|
||||
return Response().ok(message="更新成功").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"更新路由表失败: {str(e)}").__dict__
|
||||
|
||||
async def update_ucr(self):
|
||||
"""更新 UMOP 配置路由表"""
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
|
||||
umo = post_data.get("umo", None)
|
||||
conf_id = post_data.get("conf_id", None)
|
||||
|
||||
if not umo or not conf_id:
|
||||
return Response().error("缺少 UMO 或配置文件 ID").__dict__
|
||||
|
||||
try:
|
||||
await self.ucr.update_route(umo, conf_id)
|
||||
return Response().ok(message="更新成功").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"更新路由表失败: {str(e)}").__dict__
|
||||
|
||||
async def delete_ucr(self):
|
||||
"""删除 UMOP 配置路由表中的一项"""
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
|
||||
umo = post_data.get("umo", None)
|
||||
|
||||
if not umo:
|
||||
return Response().error("缺少 UMO").__dict__
|
||||
|
||||
try:
|
||||
if umo in self.ucr.umop_to_conf_id:
|
||||
del self.ucr.umop_to_conf_id[umo]
|
||||
await self.ucr.update_routing_data(self.ucr.umop_to_conf_id)
|
||||
return Response().ok(message="删除成功").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除路由表项失败: {str(e)}").__dict__
|
||||
|
||||
async def get_default_config(self):
|
||||
"""获取默认配置文件"""
|
||||
return (
|
||||
Response()
|
||||
.ok({"config": DEFAULT_CONFIG, "metadata": CONFIG_METADATA_3})
|
||||
.__dict__
|
||||
)
|
||||
|
||||
async def get_abconf_list(self):
|
||||
"""获取所有 AstrBot 配置文件的列表"""
|
||||
abconf_list = self.acm.get_conf_list()
|
||||
@@ -184,11 +260,11 @@ class ConfigRoute(Route):
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
umo_parts = post_data["umo_parts"]
|
||||
name = post_data.get("name", None)
|
||||
config = post_data.get("config", DEFAULT_CONFIG)
|
||||
|
||||
try:
|
||||
conf_id = self.acm.create_conf(umo_parts=umo_parts, name=name)
|
||||
conf_id = self.acm.create_conf(name=name, config=config)
|
||||
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
@@ -250,10 +326,9 @@ class ConfigRoute(Route):
|
||||
return Response().error("缺少配置文件 ID").__dict__
|
||||
|
||||
name = post_data.get("name")
|
||||
umo_parts = post_data.get("umo_parts")
|
||||
|
||||
try:
|
||||
success = self.acm.update_conf_info(conf_id, name=name, umo_parts=umo_parts)
|
||||
success = self.acm.update_conf_info(conf_id, name=name)
|
||||
if success:
|
||||
return Response().ok(message="更新成功").__dict__
|
||||
else:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -52,10 +52,17 @@ class UpdateRoute(Route):
|
||||
|
||||
try:
|
||||
dv = await get_dashboard_version()
|
||||
# WebUI 版本独立于核心版本:不再用 dv 与 v{VERSION} 比较,避免误报
|
||||
if type_ == "dashboard":
|
||||
return (
|
||||
Response()
|
||||
.ok({"has_new_version": dv != f"v{VERSION}", "current_version": dv})
|
||||
.ok(
|
||||
{
|
||||
"has_new_version": False,
|
||||
"current_version": dv,
|
||||
"installed": bool(dv),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
else:
|
||||
@@ -67,7 +74,8 @@ class UpdateRoute(Route):
|
||||
"version": f"v{VERSION}",
|
||||
"has_new_version": ret is not None,
|
||||
"dashboard_version": dv,
|
||||
"dashboard_has_new_version": dv and dv != f"v{VERSION}",
|
||||
# dv正常获取则不会提示需要更新
|
||||
"dashboard_has_new_version": not bool(dv),
|
||||
},
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
|
||||
@@ -0,0 +1,161 @@
|
||||
import base64
|
||||
import os
|
||||
import traceback
|
||||
from io import BytesIO
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.knowledge_base.kb_helper import KBHelper
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
|
||||
|
||||
|
||||
async def generate_tsne_visualization(
|
||||
query: str, kb_names: list[str], kb_manager: KnowledgeBaseManager
|
||||
) -> str | None:
|
||||
"""生成 t-SNE 可视化图片
|
||||
|
||||
Args:
|
||||
query: 查询文本
|
||||
kb_names: 知识库名称列表
|
||||
kb_manager: 知识库管理器
|
||||
|
||||
Returns:
|
||||
图片路径或 None
|
||||
"""
|
||||
try:
|
||||
import faiss
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use("Agg") # 使用非交互式后端
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.manifold import TSNE
|
||||
except ImportError as e:
|
||||
raise Exception(
|
||||
"缺少必要的库以生成 t-SNE 可视化。请安装 matplotlib 和 scikit-learn: {e}"
|
||||
) from e
|
||||
|
||||
try:
|
||||
# 获取第一个知识库的向量数据
|
||||
kb_helper: KBHelper | None = None
|
||||
for kb_name in kb_names:
|
||||
kb_helper = await kb_manager.get_kb_by_name(kb_name)
|
||||
if kb_helper:
|
||||
break
|
||||
|
||||
if not kb_helper:
|
||||
logger.warning("未找到知识库")
|
||||
return None
|
||||
|
||||
kb = kb_helper.kb
|
||||
index_path = f"data/knowledge_base/{kb.kb_id}/index.faiss"
|
||||
|
||||
# 读取 FAISS 索引
|
||||
if not os.path.exists(index_path):
|
||||
logger.warning(f"FAISS 索引不存在: {index_path}")
|
||||
return None
|
||||
|
||||
index = faiss.read_index(index_path)
|
||||
|
||||
if index.ntotal == 0:
|
||||
logger.warning("索引为空")
|
||||
return None
|
||||
|
||||
# 提取所有向量
|
||||
logger.info(f"提取 {index.ntotal} 个向量用于可视化...")
|
||||
if isinstance(index, faiss.IndexIDMap):
|
||||
base_index = faiss.downcast_index(index.index)
|
||||
if hasattr(base_index, "reconstruct_n"):
|
||||
vectors = base_index.reconstruct_n(0, index.ntotal)
|
||||
else:
|
||||
vectors = np.zeros((index.ntotal, index.d), dtype=np.float32)
|
||||
for i in range(index.ntotal):
|
||||
base_index.reconstruct(i, vectors[i])
|
||||
elif hasattr(index, "reconstruct_n"):
|
||||
vectors = index.reconstruct_n(0, index.ntotal)
|
||||
else:
|
||||
vectors = np.zeros((index.ntotal, index.d), dtype=np.float32)
|
||||
for i in range(index.ntotal):
|
||||
index.reconstruct(i, vectors[i])
|
||||
|
||||
# 获取查询向量
|
||||
vec_db: FaissVecDB = kb_helper.vec_db # type: ignore
|
||||
embedding_provider = vec_db.embedding_provider
|
||||
query_embedding = await embedding_provider.get_embedding(query)
|
||||
query_vector = np.array([query_embedding], dtype=np.float32)
|
||||
|
||||
# 合并所有向量和查询向量
|
||||
all_vectors = np.vstack([vectors, query_vector])
|
||||
|
||||
# t-SNE 降维
|
||||
logger.info("开始 t-SNE 降维...")
|
||||
perplexity = min(30, all_vectors.shape[0] - 1)
|
||||
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
|
||||
vectors_2d = tsne.fit_transform(all_vectors)
|
||||
|
||||
# 分离知识库向量和查询向量
|
||||
kb_vectors_2d = vectors_2d[:-1]
|
||||
query_vector_2d = vectors_2d[-1]
|
||||
|
||||
# 可视化
|
||||
logger.info("生成可视化图表...")
|
||||
plt.figure(figsize=(14, 10))
|
||||
|
||||
# 绘制知识库向量
|
||||
scatter = plt.scatter(
|
||||
kb_vectors_2d[:, 0],
|
||||
kb_vectors_2d[:, 1],
|
||||
alpha=0.5,
|
||||
s=40,
|
||||
c=range(len(kb_vectors_2d)),
|
||||
cmap="viridis",
|
||||
label="Knowledge Base Vectors",
|
||||
)
|
||||
|
||||
# 绘制查询向量(红色 X)
|
||||
plt.scatter(
|
||||
query_vector_2d[0],
|
||||
query_vector_2d[1],
|
||||
c="red",
|
||||
s=300,
|
||||
marker="X",
|
||||
edgecolors="black",
|
||||
linewidths=2,
|
||||
label="Query",
|
||||
zorder=5,
|
||||
)
|
||||
|
||||
# 添加查询文本标注
|
||||
plt.annotate(
|
||||
"Query",
|
||||
(query_vector_2d[0], query_vector_2d[1]),
|
||||
xytext=(10, 10),
|
||||
textcoords="offset points",
|
||||
fontsize=10,
|
||||
bbox={"boxstyle": "round,pad=0.5", "fc": "yellow", "alpha": 0.7},
|
||||
arrowprops={"arrowstyle": "->", "connectionstyle": "arc3,rad=0"},
|
||||
)
|
||||
|
||||
plt.colorbar(scatter, label="Vector Index")
|
||||
plt.title(
|
||||
f"t-SNE Visualization: Query in Knowledge Base\n"
|
||||
f"({index.ntotal} vectors, {index.d} dimensions, KB: {kb.kb_name})",
|
||||
fontsize=14,
|
||||
pad=20,
|
||||
)
|
||||
plt.xlabel("t-SNE Dimension 1", fontsize=12)
|
||||
plt.ylabel("t-SNE Dimension 2", fontsize=12)
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.legend(fontsize=10, loc="upper right")
|
||||
|
||||
# base64 编码图片返回
|
||||
buffer = BytesIO()
|
||||
plt.savefig(buffer, format="png", dpi=150, bbox_inches="tight")
|
||||
plt.close()
|
||||
buffer.seek(0)
|
||||
img_base64 = base64.b64encode(buffer.read()).decode("utf-8")
|
||||
return img_base64
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成 t-SNE 可视化时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return None
|
||||
@@ -44,6 +44,7 @@
|
||||
"@mdi/font": "7.2.96",
|
||||
"@rushstack/eslint-patch": "1.3.3",
|
||||
"@types/chance": "1.1.3",
|
||||
"@types/markdown-it": "^14.1.2",
|
||||
"@types/node": "^20.5.7",
|
||||
"@vitejs/plugin-vue": "4.3.3",
|
||||
"@vue/eslint-config-prettier": "8.0.0",
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
<template>
|
||||
<div :class="$vuetify.display.mobile ? '' : 'd-flex'">
|
||||
<v-tabs v-model="tab" :direction="$vuetify.display.mobile ? 'horizontal' : 'vertical'"
|
||||
:align-tabs="$vuetify.display.mobile ? 'left' : 'start'" color="deep-purple-accent-4" class="config-tabs">
|
||||
<v-tab v-for="(val, key, index) in metadata" :key="index" :value="index"
|
||||
style="font-weight: 1000; font-size: 15px">
|
||||
{{ metadata[key]['name'] }}
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
<v-tabs-window v-model="tab" class="config-tabs-window" :style="readonly ? 'pointer-events: none; opacity: 0.6;' : ''">
|
||||
<v-tabs-window-item v-for="(val, key, index) in metadata" v-show="index == tab" :key="index">
|
||||
<v-container fluid>
|
||||
<div v-for="(val2, key2, index2) in metadata[key]['metadata']" :key="key2">
|
||||
<!-- Support both traditional and JSON selector metadata -->
|
||||
<AstrBotConfigV4 :metadata="{ [key2]: metadata[key]['metadata'][key2] }" :iterable="config_data"
|
||||
:metadataKey="key2">
|
||||
</AstrBotConfigV4>
|
||||
</div>
|
||||
</v-container>
|
||||
</v-tabs-window-item>
|
||||
|
||||
|
||||
<div style="margin-left: 16px; padding-bottom: 16px">
|
||||
<small>{{ tm('help.helpPrefix') }}
|
||||
<a href="https://astrbot.app/" target="_blank">{{ tm('help.documentation') }}</a>
|
||||
{{ tm('help.helpMiddle') }}
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=EYGsuUTfe00_iOu9JTXS7_TEpMkXOvwv&jump_from=webapi&authKey=uUEMKCROfsseS+8IzqPjzV3y1tzy4AkykwTib2jNkOFdzezF9s9XknqnIaf3CDft"
|
||||
target="_blank">{{ tm('help.support') }}</a>{{ tm('help.helpSuffix') }}
|
||||
</small>
|
||||
</div>
|
||||
|
||||
</v-tabs-window>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import AstrBotConfigV4 from '@/components/shared/AstrBotConfigV4.vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
export default {
|
||||
name: 'AstrBotCoreConfigWrapper',
|
||||
components: {
|
||||
AstrBotConfigV4
|
||||
},
|
||||
props: {
|
||||
metadata: {
|
||||
type: Object,
|
||||
required: true,
|
||||
default: () => ({})
|
||||
},
|
||||
config_data: {
|
||||
type: Object,
|
||||
required: true,
|
||||
default: () => ({})
|
||||
},
|
||||
readonly: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
},
|
||||
setup() {
|
||||
const { tm } = useModuleI18n('features/config');
|
||||
return {
|
||||
tm
|
||||
};
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
tab: 0, // 用于切换配置标签页
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
// 如果需要添加其他方法,可以在这里添加
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style>
|
||||
@media (min-width: 768px) {
|
||||
.config-tabs {
|
||||
display: flex;
|
||||
margin: 16px 16px 0 0;
|
||||
}
|
||||
|
||||
.config-tabs-window {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.config-tabs .v-tab {
|
||||
justify-content: flex-start !important;
|
||||
text-align: left;
|
||||
min-height: 48px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 767px) {
|
||||
.config-tabs {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.config-tabs-window {
|
||||
margin-top: 16px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
File diff suppressed because it is too large
Load Diff
@@ -135,7 +135,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<!-- Regular Property -->
|
||||
<template v-else>
|
||||
<v-row v-if="!metadata[metadataKey].items[key]?.invisible && shouldShowItem(metadata[metadataKey].items[key], key)" class="config-row">
|
||||
<v-col cols="12" sm="6" class="property-info">
|
||||
<v-col cols="12" sm="7" class="property-info">
|
||||
<v-list-item density="compact">
|
||||
<v-list-item-title class="property-name">
|
||||
<span v-if="metadata[metadataKey].items[key]?.description">
|
||||
@@ -153,16 +153,6 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-list-item>
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" sm="1" class="d-flex align-center type-indicator">
|
||||
<v-chip v-if="!metadata[metadataKey].items[key]?.invisible"
|
||||
color="primary"
|
||||
label
|
||||
size="x-small"
|
||||
variant="flat">
|
||||
{{ metadata[metadataKey].items[key]?.type || 'string' }}
|
||||
</v-chip>
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" sm="5" class="config-input">
|
||||
<div v-if="metadata[metadataKey].items[key]" class="w-100">
|
||||
<!-- Special handling for specific metadata types -->
|
||||
@@ -335,7 +325,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<!-- Simple Value Configuration -->
|
||||
<div v-else class="simple-config">
|
||||
<v-row class="config-row">
|
||||
<v-col cols="12" sm="6" class="property-info">
|
||||
<v-col cols="12" sm="7" class="property-info">
|
||||
<v-list-item density="compact">
|
||||
<v-list-item-title class="property-name">
|
||||
{{ metadata[metadataKey]?.description }}
|
||||
@@ -349,16 +339,6 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-list-item>
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" sm="1" class="d-flex align-center type-indicator">
|
||||
<v-chip v-if="!metadata[metadataKey]?.invisible"
|
||||
color="primary"
|
||||
label
|
||||
size="x-small"
|
||||
variant="flat">
|
||||
{{ metadata[metadataKey]?.type }}
|
||||
</v-chip>
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" sm="5" class="config-input">
|
||||
<div class="w-100">
|
||||
<!-- Select input -->
|
||||
@@ -548,8 +528,8 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
}
|
||||
|
||||
.config-divider {
|
||||
border-color: rgba(0, 0, 0, 0.1);
|
||||
margin: 4px 0;
|
||||
border-color: rgba(0, 0, 0, 0.05);
|
||||
margin: 0px 16px;
|
||||
}
|
||||
|
||||
.editor-container {
|
||||
|
||||
@@ -120,7 +120,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
|
||||
|
||||
<v-card style="margin-bottom: 16px; padding-bottom: 8px; background-color: rgb(var(--v-theme-background));" rounded="md" variant="outlined">
|
||||
<v-card-text class="config-section" v-if="metadata[metadataKey]?.type === 'object'">
|
||||
<v-card-text class="config-section" v-if="metadata[metadataKey]?.type === 'object'" style="padding-bottom: 8px;">
|
||||
<v-list-item-title class="config-title">
|
||||
{{ metadata[metadataKey]?.description }}
|
||||
</v-list-item-title>
|
||||
@@ -365,7 +365,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
.config-row {
|
||||
margin: 0;
|
||||
align-items: center;
|
||||
padding: 10px 8px;
|
||||
padding: 8px 8px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, inject } from 'vue';
|
||||
import {useCustomizerStore} from "@/stores/customizer";
|
||||
import { useCustomizerStore } from "@/stores/customizer";
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
const props = defineProps({
|
||||
@@ -84,130 +84,130 @@ const viewReadme = () => {
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<v-card class="mx-auto d-flex flex-column" :elevation="highlight ? 0 : 1"
|
||||
:style="{ height: $vuetify.display.xs ? '250px' : '220px',
|
||||
backgroundColor: useCustomizerStore().uiTheme==='PurpleTheme' ? marketMode ? '#f8f0dd' : '#ffffff' : '#282833',
|
||||
color: useCustomizerStore().uiTheme==='PurpleTheme' ? '#000000dd' : '#ffffff'}">
|
||||
<v-card-text style="padding: 16px; padding-bottom: 0px; display: flex; justify-content: space-between;">
|
||||
<v-card class="mx-auto d-flex flex-column" elevation="2" :style="{
|
||||
position: 'relative',
|
||||
backgroundColor: useCustomizerStore().uiTheme === 'PurpleTheme' ? marketMode ? '#f8f0dd' : '#ffffff' : '#282833',
|
||||
color: useCustomizerStore().uiTheme === 'PurpleTheme' ? '#000000dd' : '#ffffff'
|
||||
}">
|
||||
<v-card-text style="padding: 16px; padding-bottom: 0px; display: flex; gap: 16px; width: 100%;">
|
||||
|
||||
<div class="flex-grow-1">
|
||||
<div>{{ extension.author }} /</div>
|
||||
<div v-if="extension?.icon">
|
||||
<v-avatar size="65">
|
||||
<v-img :src="extension.icon"
|
||||
:alt="extension.name" cover></v-img>
|
||||
</v-avatar>
|
||||
</div>
|
||||
|
||||
<p class="text-h4 font-weight-black" :class="{ 'text-h4': $vuetify.display.xs }">
|
||||
{{ extension.name }}
|
||||
<v-tooltip location="top" v-if="extension?.has_update && !marketMode">
|
||||
<template v-slot:activator="{ props: tooltipProps }">
|
||||
<v-icon v-bind="tooltipProps" color="warning" class="ml-2" icon="mdi-update" size="small"></v-icon>
|
||||
<div style="width: 100%;">
|
||||
<!-- Top-right three-dot menu -->
|
||||
<div style="position: absolute; right: 8px; top: 8px; z-index: 5;">
|
||||
<v-menu offset-y>
|
||||
<template v-slot:activator="{ props: menuProps }">
|
||||
<v-btn v-bind="menuProps" icon variant="text" aria-label="more">
|
||||
<v-icon icon="mdi-dots-vertical"></v-icon>
|
||||
</v-btn>
|
||||
</template>
|
||||
<span>{{ tm("card.status.hasUpdate") }}: {{ extension.online_version }}</span>
|
||||
</v-tooltip>
|
||||
<v-tooltip location="top" v-if="!extension.activated && !marketMode">
|
||||
<template v-slot:activator="{ props: tooltipProps }">
|
||||
<v-icon v-bind="tooltipProps" color="error" class="ml-2" icon="mdi-cancel" size="small"></v-icon>
|
||||
</template>
|
||||
<span>{{ tm("card.status.disabled") }}</span>
|
||||
</v-tooltip>
|
||||
</p>
|
||||
|
||||
<div class="mt-1 d-flex flex-wrap">
|
||||
<v-chip color="primary" label size="small">
|
||||
<v-icon icon="mdi-source-branch" start></v-icon>
|
||||
{{ extension.version }}
|
||||
</v-chip>
|
||||
<v-chip v-if="extension?.has_update " color="warning" label size="small" class="ml-2">
|
||||
<v-icon icon="mdi-arrow-up-bold" start></v-icon>
|
||||
{{ extension.online_version }}
|
||||
</v-chip>
|
||||
<v-chip color="primary" label size="small" class="ml-2" v-if="extension.handlers?.length">
|
||||
<v-icon icon="mdi-cogs" start></v-icon>
|
||||
{{ extension.handlers?.length }}{{ tm("card.status.handlersCount") }}
|
||||
</v-chip>
|
||||
<v-chip v-for="tag in extension.tags" :key="tag" :color="tag === 'danger' ? 'error' : 'primary'" label
|
||||
size="small" class="ml-2">
|
||||
{{ tag === 'danger' ? tm('tags.danger') : tag }}
|
||||
</v-chip>
|
||||
<v-list>
|
||||
<v-list-item @click="viewReadme">
|
||||
<v-list-item-title>📄 {{ tm('buttons.viewDocs') }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item v-if="marketMode && !extension?.installed" @click="installExtension">
|
||||
<v-list-item-title>
|
||||
{{ tm('buttons.install') }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item v-if="marketMode && extension?.installed">
|
||||
<v-list-item-title class="text--disabled">{{ tm('status.installed') }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<!-- Divider between market actions and plugin actions -->
|
||||
<v-divider v-if="!marketMode" />
|
||||
|
||||
<template v-if="!marketMode">
|
||||
<v-list-item @click="configure">
|
||||
<v-list-item-title>
|
||||
{{ tm('card.actions.pluginConfig') }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item @click="uninstallExtension">
|
||||
<v-list-item-title class="text-error">{{ tm('card.actions.uninstallPlugin') }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item @click="reloadExtension">
|
||||
<v-list-item-title>{{ tm('card.actions.reloadPlugin') }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item @click="toggleActivation">
|
||||
<v-list-item-title>
|
||||
{{ extension.activated ? tm('buttons.disable') : tm('buttons.enable') }}{{
|
||||
tm('card.actions.togglePlugin') }}
|
||||
</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item @click="viewHandlers">
|
||||
<v-list-item-title>{{ tm('card.actions.viewHandlers') }} ({{ extension.handlers.length
|
||||
}})</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item @click="updateExtension" :disabled="!extension?.has_update">
|
||||
<v-list-item-title>
|
||||
{{ tm('card.actions.updateTo') }} {{ extension.online_version || extension.version }}
|
||||
</v-list-item-title>
|
||||
</v-list-item>
|
||||
</template>
|
||||
</v-list>
|
||||
</v-menu>
|
||||
</div>
|
||||
|
||||
<div class="mt-2" :class="{ 'text-caption': $vuetify.display.xs }" style="max-height: 65px; overflow-y: auto;">
|
||||
{{ extension.desc }}
|
||||
<div style="width: 100%; margin-bottom: 24px;">
|
||||
<!-- 最多一行 -->
|
||||
<div class="text-caption" style="color: gray; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; margin-right: 36px;">
|
||||
{{ extension.author }} / {{ extension.name }}
|
||||
</div>
|
||||
<p class="text-h3 font-weight-black" :class="{ 'text-h4': $vuetify.display.xs }">
|
||||
{{ extension.name }}
|
||||
<v-tooltip location="top" v-if="extension?.has_update && !marketMode">
|
||||
<template v-slot:activator="{ props: tooltipProps }">
|
||||
<v-icon v-bind="tooltipProps" color="warning" class="ml-2" icon="mdi-update" size="small"></v-icon>
|
||||
</template>
|
||||
<span>{{ tm("card.status.hasUpdate") }}: {{ extension.online_version }}</span>
|
||||
</v-tooltip>
|
||||
<v-tooltip location="top" v-if="!extension.activated && !marketMode">
|
||||
<template v-slot:activator="{ props: tooltipProps }">
|
||||
<v-icon v-bind="tooltipProps" color="error" class="ml-2" icon="mdi-cancel" size="small"></v-icon>
|
||||
</template>
|
||||
<span>{{ tm("card.status.disabled") }}</span>
|
||||
</v-tooltip>
|
||||
</p>
|
||||
|
||||
<div class="mt-1 d-flex flex-wrap">
|
||||
<v-chip color="primary" label size="small">
|
||||
<v-icon icon="mdi-source-branch" start></v-icon>
|
||||
{{ extension.version }}
|
||||
</v-chip>
|
||||
<v-chip v-if="extension?.has_update" color="warning" label size="small" class="ml-2">
|
||||
<v-icon icon="mdi-arrow-up-bold" start></v-icon>
|
||||
{{ extension.online_version }}
|
||||
</v-chip>
|
||||
<v-chip color="primary" label size="small" class="ml-2" v-if="extension.handlers?.length">
|
||||
<v-icon icon="mdi-cogs" start></v-icon>
|
||||
{{ extension.handlers?.length }}{{ tm("card.status.handlersCount") }}
|
||||
</v-chip>
|
||||
<v-chip v-for="tag in extension.tags" :key="tag" :color="tag === 'danger' ? 'error' : 'primary'" label
|
||||
size="small" class="ml-2">
|
||||
{{ tag === 'danger' ? tm('tags.danger') : tag }}
|
||||
</v-chip>
|
||||
</div>
|
||||
|
||||
<div class="mt-2" :class="{ 'text-caption': $vuetify.display.xs }" style="overflow-y: auto; height: 60px;">
|
||||
{{ extension.desc }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="extension-image-container" v-if="extension.logo">
|
||||
<img :src="extension.logo" :style="{
|
||||
height: $vuetify.display.xs ? '75px' : '100px',
|
||||
width: $vuetify.display.xs ? '75px' : '100px',
|
||||
borderRadius: '8px',
|
||||
objectFit: 'cover',
|
||||
objectPosition: 'center'
|
||||
}" :alt="tm('card.alt.logo')" />
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions style="margin-left: 0px; gap: 2px;">
|
||||
<v-btn color="teal-accent-4" :text="tm('buttons.viewDocs')" variant="text" @click="viewReadme"></v-btn>
|
||||
<v-btn v-if="!marketMode" color="teal-accent-4" :text="tm('buttons.actions')" variant="text" @click="reveal = true"></v-btn>
|
||||
<v-btn v-if="marketMode && !extension?.installed" color="teal-accent-4" :text="tm('buttons.install')" variant="text"
|
||||
@click="installExtension"></v-btn>
|
||||
<v-btn v-if="marketMode && extension?.installed" color="teal-accent-4" :text="tm('status.installed')" variant="text" disabled></v-btn>
|
||||
</v-card-actions>
|
||||
|
||||
<v-expand-transition v-if="!marketMode">
|
||||
<v-card v-if="reveal" class="position-absolute w-100" height="100%"
|
||||
style="bottom: 0; display: flex; flex-direction: column;">
|
||||
<v-card-text style="overflow-y: auto;">
|
||||
<div class="d-flex align-center mb-4">
|
||||
<img v-if="extension.logo" :src="extension.logo"
|
||||
style="height: 50px; width: 50px; border-radius: 8px; margin-right: 16px;" :alt="tm('card.alt.extensionIcon')" />
|
||||
<h3>{{ extension.name }}</h3>
|
||||
</div>
|
||||
|
||||
<div class="mt-4" :style="{
|
||||
justifyContent: 'center',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
flexWrap: 'wrap',
|
||||
gap: '8px',
|
||||
flexDirection: $vuetify.display.xs ? 'column' : 'row'
|
||||
}">
|
||||
<v-btn prepend-icon="mdi-cog" color="primary" variant="tonal" @click="configure"
|
||||
:block="$vuetify.display.xs">
|
||||
{{ tm("card.actions.pluginConfig") }}
|
||||
</v-btn>
|
||||
|
||||
<v-btn prepend-icon="mdi-delete" color="error" variant="tonal" @click="uninstallExtension"
|
||||
:block="$vuetify.display.xs">
|
||||
{{ tm("card.actions.uninstallPlugin") }}
|
||||
</v-btn>
|
||||
|
||||
<v-btn prepend-icon="mdi-reload" color="primary" variant="tonal" @click="reloadExtension"
|
||||
:block="$vuetify.display.xs">
|
||||
{{ tm("card.actions.reloadPlugin") }}
|
||||
</v-btn>
|
||||
|
||||
<v-btn :prepend-icon="extension.activated ? 'mdi-cancel' : 'mdi-check-circle'"
|
||||
:color="extension.activated ? 'error' : 'success'" variant="tonal" @click="toggleActivation"
|
||||
:block="$vuetify.display.xs">
|
||||
{{ extension.activated ? tm('buttons.disable') : tm('buttons.enable') }}{{ tm("card.actions.togglePlugin") }}
|
||||
</v-btn>
|
||||
|
||||
<v-btn prepend-icon="mdi-cogs" color="info" variant="tonal" @click="viewHandlers"
|
||||
:block="$vuetify.display.xs">
|
||||
{{ tm("card.actions.viewHandlers") }} ({{ extension.handlers.length }})
|
||||
</v-btn>
|
||||
|
||||
<v-btn prepend-icon="mdi-update" color="primary" variant="tonal" :disabled="!extension?.has_update "
|
||||
@click="updateExtension" :block="$vuetify.display.xs">
|
||||
{{ tm("card.actions.updateTo") }} {{ extension.online_version || extension.version }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="pt-0 d-flex justify-center">
|
||||
<v-btn color="teal-accent-4" :text="tm('buttons.back')" variant="text" @click="reveal = false"></v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-expand-transition>
|
||||
</v-card>
|
||||
|
||||
</template>
|
||||
|
||||
@@ -127,7 +127,6 @@ export default {
|
||||
transition: all 0.3s ease;
|
||||
overflow: hidden;
|
||||
min-height: 220px;
|
||||
margin-bottom: 16px;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
justify-content: space-between;
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
<template>
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<span v-if="!modelValue" style="color: rgb(var(--v-theme-primaryText));">
|
||||
<span v-if="!modelValue || (Array.isArray(modelValue) && modelValue.length === 0)"
|
||||
style="color: rgb(var(--v-theme-primaryText));">
|
||||
未选择
|
||||
</span>
|
||||
<span v-else>
|
||||
{{ modelValue }}
|
||||
</span>
|
||||
<div v-else class="d-flex flex-wrap gap-1">
|
||||
<v-chip
|
||||
v-for="name in modelValue"
|
||||
:key="name"
|
||||
size="small"
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
closable
|
||||
@click:close="removeKnowledgeBase(name)">
|
||||
{{ name }}
|
||||
</v-chip>
|
||||
</div>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ buttonText }}
|
||||
</v-btn>
|
||||
@@ -21,86 +31,57 @@
|
||||
<v-card-text class="pa-0" style="max-height: 400px; overflow-y: auto;">
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<!-- 插件未安装提示 -->
|
||||
<div v-if="!loading && !pluginInstalled" class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-puzzle-outline</v-icon>
|
||||
<p class="text-grey mt-4 mb-4">知识库插件未安装</p>
|
||||
<v-btn color="primary" variant="tonal" @click="goToKnowledgeBasePage">
|
||||
前往知识库页面
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- 知识库列表 -->
|
||||
<v-list v-else-if="!loading && pluginInstalled" density="compact">
|
||||
<!-- 不使用选项 -->
|
||||
<v-list-item
|
||||
:value="''"
|
||||
@click="selectKnowledgeBase({ collection_name: '' })"
|
||||
:active="selectedKnowledgeBase === ''"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<template v-slot:prepend>
|
||||
<v-icon color="grey-lighten-1">mdi-close-circle-outline</v-icon>
|
||||
</template>
|
||||
<v-list-item-title>不使用</v-list-item-title>
|
||||
<v-list-item-subtitle>不使用任何知识库</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-icon v-if="selectedKnowledgeBase === ''" color="primary">mdi-check-circle</v-icon>
|
||||
</template>
|
||||
</v-list-item>
|
||||
|
||||
<v-divider v-if="knowledgeBaseList.length > 0" class="my-2"></v-divider>
|
||||
|
||||
<v-list v-if="!loading" density="compact">
|
||||
<!-- 知识库选项 -->
|
||||
<v-list-item
|
||||
v-for="kb in knowledgeBaseList"
|
||||
:key="kb.collection_name"
|
||||
:value="kb.collection_name"
|
||||
@click="selectKnowledgeBase(kb)"
|
||||
:active="selectedKnowledgeBase === kb.collection_name"
|
||||
:key="kb.kb_id"
|
||||
:value="kb.kb_name"
|
||||
@click="selectKnowledgeBase(kb.kb_name)"
|
||||
:active="isSelected(kb.kb_name)"
|
||||
rounded="md"
|
||||
class="ma-1">
|
||||
<template v-slot:prepend>
|
||||
<span class="emoji-icon">{{ kb.emoji || '🙂' }}</span>
|
||||
<span class="emoji-icon">{{ kb.emoji || '📚' }}</span>
|
||||
</template>
|
||||
<v-list-item-title>{{ kb.collection_name }}</v-list-item-title>
|
||||
<v-list-item-title>{{ kb.kb_name }}</v-list-item-title>
|
||||
<v-list-item-subtitle>
|
||||
{{ kb.description || '无描述' }}
|
||||
<span v-if="kb.count !== undefined"> - {{ kb.count }} 项知识</span>
|
||||
<span v-if="kb.doc_count !== undefined"> - {{ kb.doc_count }} 个文档</span>
|
||||
<span v-if="kb.chunk_count !== undefined"> - {{ kb.chunk_count }} 个块</span>
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-icon v-if="selectedKnowledgeBase === kb.collection_name" color="primary">mdi-check-circle</v-icon>
|
||||
<v-icon v-if="isSelected(kb.kb_name)" color="primary">
|
||||
mdi-checkbox-marked
|
||||
</v-icon>
|
||||
<v-icon v-else color="grey-lighten-1">
|
||||
mdi-checkbox-blank-outline
|
||||
</v-icon>
|
||||
</template>
|
||||
</v-list-item>
|
||||
|
||||
<!-- 当没有知识库时显示创建提示 -->
|
||||
<div v-if="knowledgeBaseList.length === 0" class="text-center py-4">
|
||||
<p class="text-grey mb-4">暂无知识库</p>
|
||||
<v-btn color="primary" variant="tonal" size="small" @click="goToKnowledgeBasePage">
|
||||
<div v-if="knowledgeBaseList.length === 0" class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-database-off</v-icon>
|
||||
<p class="text-grey mt-4 mb-4">暂无知识库</p>
|
||||
<v-btn color="primary" variant="tonal" @click="goToKnowledgeBasePage">
|
||||
创建知识库
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-list>
|
||||
|
||||
<!-- 空状态(插件未安装时保留原有逻辑) -->
|
||||
<div v-else-if="!loading && !pluginInstalled && knowledgeBaseList.length === 0" class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-database-off</v-icon>
|
||||
<p class="text-grey mt-4 mb-4">暂无知识库</p>
|
||||
<v-btn color="primary" variant="tonal" @click="goToKnowledgeBasePage">
|
||||
创建知识库
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<div v-if="selectedKnowledgeBases.length > 0" class="text-caption text-grey">
|
||||
已选择 {{ selectedKnowledgeBases.length }} 个知识库
|
||||
</div>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
@click="confirmSelection"
|
||||
:disabled="selectedKnowledgeBase === null || selectedKnowledgeBase === undefined">
|
||||
@click="confirmSelection">
|
||||
确认选择
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
@@ -115,8 +96,8 @@ import { useRouter } from 'vue-router'
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: String,
|
||||
default: ''
|
||||
type: Array,
|
||||
default: () => []
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
@@ -130,74 +111,87 @@ const router = useRouter()
|
||||
const dialog = ref(false)
|
||||
const knowledgeBaseList = ref([])
|
||||
const loading = ref(false)
|
||||
const selectedKnowledgeBase = ref('')
|
||||
const pluginInstalled = ref(false)
|
||||
const selectedKnowledgeBases = ref([])
|
||||
|
||||
// 监听 modelValue 变化,同步到 selectedKnowledgeBase
|
||||
// 监听 modelValue 变化,同步到 selectedKnowledgeBases
|
||||
watch(() => props.modelValue, (newValue) => {
|
||||
selectedKnowledgeBase.value = newValue || ''
|
||||
selectedKnowledgeBases.value = Array.isArray(newValue) ? [...newValue] : []
|
||||
}, { immediate: true })
|
||||
|
||||
async function openDialog() {
|
||||
selectedKnowledgeBase.value = props.modelValue || ''
|
||||
// 初始化选中状态
|
||||
selectedKnowledgeBases.value = Array.isArray(props.modelValue)
|
||||
? [...props.modelValue]
|
||||
: []
|
||||
|
||||
dialog.value = true
|
||||
await checkPluginAndLoadKnowledgeBases()
|
||||
await loadKnowledgeBases()
|
||||
}
|
||||
|
||||
async function checkPluginAndLoadKnowledgeBases() {
|
||||
async function loadKnowledgeBases() {
|
||||
loading.value = true
|
||||
try {
|
||||
// 首先检查插件是否安装
|
||||
const pluginResponse = await axios.get('/api/plugin/get?name=astrbot_plugin_knowledge_base')
|
||||
const response = await axios.get('/api/kb/list', {
|
||||
params: {
|
||||
page: 1,
|
||||
page_size: 100
|
||||
}
|
||||
})
|
||||
|
||||
if (pluginResponse.data.status === 'ok' && pluginResponse.data.data.length > 0) {
|
||||
pluginInstalled.value = true
|
||||
// 插件已安装,获取知识库列表
|
||||
await loadKnowledgeBases()
|
||||
if (response.data.status === 'ok') {
|
||||
knowledgeBaseList.value = response.data.data.items || []
|
||||
} else {
|
||||
pluginInstalled.value = false
|
||||
console.error('加载知识库列表失败:', response.data.message)
|
||||
knowledgeBaseList.value = []
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('检查知识库插件失败:', error)
|
||||
pluginInstalled.value = false
|
||||
console.error('加载知识库列表失败:', error)
|
||||
knowledgeBaseList.value = []
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
async function loadKnowledgeBases() {
|
||||
try {
|
||||
const response = await axios.get('/api/plug/alkaid/kb/collections')
|
||||
if (response.data.status === 'ok') {
|
||||
knowledgeBaseList.value = response.data.data || []
|
||||
} else {
|
||||
knowledgeBaseList.value = []
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('加载知识库列表失败:', error)
|
||||
knowledgeBaseList.value = []
|
||||
function isSelected(kbName) {
|
||||
return selectedKnowledgeBases.value.includes(kbName)
|
||||
}
|
||||
|
||||
function selectKnowledgeBase(kbName) {
|
||||
// 多选模式:切换选中状态
|
||||
const index = selectedKnowledgeBases.value.indexOf(kbName)
|
||||
if (index > -1) {
|
||||
selectedKnowledgeBases.value.splice(index, 1)
|
||||
} else {
|
||||
selectedKnowledgeBases.value.push(kbName)
|
||||
}
|
||||
}
|
||||
|
||||
function selectKnowledgeBase(kb) {
|
||||
selectedKnowledgeBase.value = kb.collection_name
|
||||
function removeKnowledgeBase(kbName) {
|
||||
const index = selectedKnowledgeBases.value.indexOf(kbName)
|
||||
if (index > -1) {
|
||||
selectedKnowledgeBases.value.splice(index, 1)
|
||||
}
|
||||
|
||||
// 立即更新父组件
|
||||
emit('update:modelValue', [...selectedKnowledgeBases.value])
|
||||
}
|
||||
|
||||
function confirmSelection() {
|
||||
emit('update:modelValue', selectedKnowledgeBase.value)
|
||||
emit('update:modelValue', [...selectedKnowledgeBases.value])
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function cancelSelection() {
|
||||
selectedKnowledgeBase.value = props.modelValue || ''
|
||||
// 恢复到原始值
|
||||
selectedKnowledgeBases.value = Array.isArray(props.modelValue)
|
||||
? [...props.modelValue]
|
||||
: []
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function goToKnowledgeBasePage() {
|
||||
dialog.value = false
|
||||
router.push('/alkaid/knowledge-base')
|
||||
router.push('/knowledge-base')
|
||||
}
|
||||
</script>
|
||||
|
||||
@@ -222,4 +216,8 @@ function goToKnowledgeBasePage() {
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
|
||||
.gap-1 {
|
||||
gap: 4px;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -0,0 +1,536 @@
|
||||
<template>
|
||||
<v-dialog v-model="showDialog" max-width="500px" persistent>
|
||||
<v-card>
|
||||
<v-card-title class="text-h2">
|
||||
{{ editingPersona ? tm('dialog.edit.title') : tm('dialog.create.title') }}
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text>
|
||||
<v-form ref="personaForm" v-model="formValid">
|
||||
<v-text-field v-model="personaForm.persona_id" :label="tm('form.personaId')"
|
||||
:rules="personaIdRules" :disabled="editingPersona" variant="outlined" density="comfortable"
|
||||
class="mb-4" />
|
||||
|
||||
<v-textarea v-model="personaForm.system_prompt" :label="tm('form.systemPrompt')"
|
||||
:rules="systemPromptRules" variant="outlined" rows="6" class="mb-4" />
|
||||
|
||||
<v-expansion-panels v-model="expandedPanels" multiple>
|
||||
<!-- 工具选择面板 -->
|
||||
<v-expansion-panel value="tools">
|
||||
<v-expansion-panel-title>
|
||||
<v-icon class="mr-2">mdi-tools</v-icon>
|
||||
{{ tm('form.tools') }}
|
||||
<v-chip v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
|
||||
size="small" color="primary" variant="tonal" class="ml-2">
|
||||
{{ personaForm.tools.length }}
|
||||
</v-chip>
|
||||
</v-expansion-panel-title>
|
||||
|
||||
<v-expansion-panel-text>
|
||||
<div class="mb-3">
|
||||
<p class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.toolsHelp') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<v-radio-group class="mt-2" v-model="toolSelectValue" hide-details="true">
|
||||
<v-radio label="默认使用全部函数工具" value="0"></v-radio>
|
||||
<v-radio label="选择指定函数工具" value="1">
|
||||
</v-radio>
|
||||
</v-radio-group>
|
||||
|
||||
<div v-if="toolSelectValue === '1'" class="mt-3 ml-8">
|
||||
|
||||
<!-- 工具搜索 -->
|
||||
<v-text-field v-model="toolSearch" :label="tm('form.searchTools')"
|
||||
prepend-inner-icon="mdi-magnify" variant="outlined" density="compact"
|
||||
hide-details clearable class="mb-3" />
|
||||
|
||||
|
||||
<!-- MCP 服务器 -->
|
||||
<div v-if="mcpServers.length > 0" class="mb-4">
|
||||
<h4 class="text-subtitle-2 mb-2">{{ tm('form.mcpServersQuickSelect') }}</h4>
|
||||
<div class="d-flex flex-wrap ga-2">
|
||||
<v-chip v-for="server in mcpServers" :key="server.name"
|
||||
:color="isServerSelected(server) ? 'primary' : 'default'"
|
||||
:variant="isServerSelected(server) ? 'flat' : 'outlined'"
|
||||
size="small" clickable @click="toggleMcpServer(server)"
|
||||
:disabled="!server.tools || server.tools.length === 0">
|
||||
<v-icon start size="small">mdi-server</v-icon>
|
||||
{{ server.name }}
|
||||
<v-chip-text v-if="server.tools" class="ml-1">
|
||||
({{ server.tools.length }})
|
||||
</v-chip-text>
|
||||
</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 工具选择列表 -->
|
||||
<div v-if="filteredTools.length > 0" class="tools-selection">
|
||||
<v-virtual-scroll :items="filteredTools" height="300" item-height="48">
|
||||
<template v-slot:default="{ item }">
|
||||
<v-list-item :key="item.name" density="comfortable"
|
||||
@click="toggleTool(item.name)">
|
||||
<template v-slot:prepend>
|
||||
<v-checkbox-btn :model-value="isToolSelected(item.name)"
|
||||
@click.stop="toggleTool(item.name)" />
|
||||
</template>
|
||||
|
||||
<v-list-item-title>
|
||||
{{ item.name }}
|
||||
<v-chip v-if="item.mcp_server_name" size="x-small"
|
||||
color="secondary" variant="tonal" class="ml-2">
|
||||
{{ item.mcp_server_name }}
|
||||
</v-chip>
|
||||
</v-list-item-title>
|
||||
|
||||
<v-list-item-subtitle v-if="item.description">
|
||||
{{ truncateText(item.description, 100) }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</template>
|
||||
</v-virtual-scroll>
|
||||
</div>
|
||||
|
||||
<div v-else-if="!loadingTools && availableTools.length === 0"
|
||||
class="text-center pa-4">
|
||||
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-tools</v-icon>
|
||||
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsAvailable')
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div v-else-if="!loadingTools && filteredTools.length === 0"
|
||||
class="text-center pa-4">
|
||||
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-magnify</v-icon>
|
||||
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsFound') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 加载状态 -->
|
||||
<div v-if="loadingTools" class="text-center pa-4">
|
||||
<v-progress-circular indeterminate color="primary" />
|
||||
<p class="text-body-2 text-medium-emphasis mt-2">{{ tm('form.loadingTools')
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 已选择的工具 -->
|
||||
<div class="mt-4">
|
||||
<h4 class="text-subtitle-2 mb-2">
|
||||
{{ tm('form.selectedTools') }}
|
||||
<span v-if="personaForm.tools === null" class="text-success">
|
||||
({{ tm('form.allSelected') }})
|
||||
</span>
|
||||
<span v-else-if="Array.isArray(personaForm.tools)">
|
||||
({{ personaForm.tools.length }})
|
||||
</span>
|
||||
</h4>
|
||||
<div v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
|
||||
class="d-flex flex-wrap ga-1" style="max-height: 100px; overflow-y: auto;">
|
||||
<v-chip v-for="toolName in personaForm.tools" :key="toolName"
|
||||
size="small" color="primary" variant="tonal" closable
|
||||
@click:close="removeTool(toolName)">
|
||||
{{ toolName }}
|
||||
</v-chip>
|
||||
</div>
|
||||
<div v-else class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.noToolsSelected') }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</v-expansion-panel-text>
|
||||
</v-expansion-panel>
|
||||
|
||||
<!-- 预设对话面板 -->
|
||||
<v-expansion-panel value="dialogs">
|
||||
<v-expansion-panel-title>
|
||||
<v-icon class="mr-2">mdi-chat</v-icon>
|
||||
{{ tm('form.presetDialogs') }}
|
||||
<v-chip v-if="personaForm.begin_dialogs.length > 0" size="small" color="primary"
|
||||
variant="tonal" class="ml-2">
|
||||
{{ personaForm.begin_dialogs.length / 2 }}
|
||||
</v-chip>
|
||||
</v-expansion-panel-title>
|
||||
|
||||
<v-expansion-panel-text>
|
||||
<div class="mb-3">
|
||||
<p class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.presetDialogsHelp') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div v-for="(dialog, index) in personaForm.begin_dialogs" :key="index" class="mb-3">
|
||||
<v-textarea v-model="personaForm.begin_dialogs[index]"
|
||||
:label="index % 2 === 0 ? tm('form.userMessage') : tm('form.assistantMessage')"
|
||||
:rules="getDialogRules(index)" variant="outlined" rows="2"
|
||||
density="comfortable">
|
||||
<template v-slot:append>
|
||||
<v-btn icon="mdi-delete" variant="text" size="small" color="error"
|
||||
@click="removeDialog(index)" />
|
||||
</template>
|
||||
</v-textarea>
|
||||
</div>
|
||||
|
||||
<v-btn variant="outlined" prepend-icon="mdi-plus" @click="addDialogPair" block>
|
||||
{{ tm('buttons.addDialogPair') }}
|
||||
</v-btn>
|
||||
</v-expansion-panel-text>
|
||||
</v-expansion-panel>
|
||||
</v-expansion-panels>
|
||||
</v-form>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions>
|
||||
<v-spacer />
|
||||
<v-btn color="grey" variant="text" @click="closeDialog">
|
||||
{{ tm('buttons.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" variant="flat" @click="savePersona" :loading="saving" :disabled="!formValid">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
export default {
|
||||
name: 'PersonaForm',
|
||||
props: {
|
||||
modelValue: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
editingPersona: {
|
||||
type: Object,
|
||||
default: null
|
||||
}
|
||||
},
|
||||
emits: ['update:modelValue', 'saved', 'error'],
|
||||
setup() {
|
||||
const { tm } = useModuleI18n('features/persona');
|
||||
return { tm };
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
toolSelectValue: '0', // 默认选择全部工具
|
||||
saving: false,
|
||||
expandedPanels: [],
|
||||
formValid: false,
|
||||
mcpServers: [],
|
||||
availableTools: [],
|
||||
loadingTools: false,
|
||||
personaForm: {
|
||||
persona_id: '',
|
||||
system_prompt: '',
|
||||
begin_dialogs: [],
|
||||
tools: []
|
||||
},
|
||||
personaIdRules: [
|
||||
v => !!v || this.tm('validation.required'),
|
||||
v => (v && v.length >= 0) || this.tm('validation.minLength', { min: 2 }),
|
||||
],
|
||||
systemPromptRules: [
|
||||
v => !!v || this.tm('validation.required'),
|
||||
v => (v && v.length >= 10) || this.tm('validation.minLength', { min: 10 })
|
||||
],
|
||||
toolSearch: ''
|
||||
}
|
||||
},
|
||||
|
||||
computed: {
|
||||
showDialog: {
|
||||
get() {
|
||||
return this.modelValue;
|
||||
},
|
||||
set(value) {
|
||||
this.$emit('update:modelValue', value);
|
||||
}
|
||||
},
|
||||
filteredTools() {
|
||||
if (!this.toolSearch) {
|
||||
return this.availableTools;
|
||||
}
|
||||
const search = this.toolSearch.toLowerCase();
|
||||
return this.availableTools.filter(tool =>
|
||||
tool.name.toLowerCase().includes(search) ||
|
||||
(tool.description && tool.description.toLowerCase().includes(search)) ||
|
||||
(tool.mcp_server_name && tool.mcp_server_name.toLowerCase().includes(search))
|
||||
);
|
||||
}
|
||||
},
|
||||
|
||||
watch: {
|
||||
modelValue(newValue) {
|
||||
if (newValue) {
|
||||
// 只有在不是编辑状态时才初始化空表单
|
||||
if (this.editingPersona) {
|
||||
this.initFormWithPersona(this.editingPersona);
|
||||
} else {
|
||||
this.initForm();
|
||||
}
|
||||
this.loadMcpServers();
|
||||
this.loadTools();
|
||||
}
|
||||
},
|
||||
editingPersona: {
|
||||
immediate: true,
|
||||
handler(newPersona) {
|
||||
// 只有在对话框打开时才处理editingPersona的变化
|
||||
if (this.modelValue) {
|
||||
if (newPersona) {
|
||||
this.initFormWithPersona(newPersona);
|
||||
} else {
|
||||
this.initForm();
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
toolSelectValue(newValue) {
|
||||
if (newValue === '0') {
|
||||
// 选择全部工具
|
||||
this.personaForm.tools = null;
|
||||
} else if (newValue === '1') {
|
||||
// 选择指定工具,如果当前是null,则转换为空数组
|
||||
if (this.personaForm.tools === null) {
|
||||
this.personaForm.tools = [];
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
methods: {
|
||||
initForm() {
|
||||
this.personaForm = {
|
||||
persona_id: '',
|
||||
system_prompt: '',
|
||||
begin_dialogs: [],
|
||||
tools: []
|
||||
};
|
||||
this.toolSelectValue = '0';
|
||||
this.expandedPanels = [];
|
||||
},
|
||||
|
||||
initFormWithPersona(persona) {
|
||||
this.personaForm = {
|
||||
persona_id: persona.persona_id,
|
||||
system_prompt: persona.system_prompt,
|
||||
begin_dialogs: [...(persona.begin_dialogs || [])],
|
||||
tools: persona.tools === null ? null : [...(persona.tools || [])]
|
||||
};
|
||||
// 根据 tools 的值设置 toolSelectValue
|
||||
this.toolSelectValue = persona.tools === null ? '0' : '1';
|
||||
this.expandedPanels = [];
|
||||
},
|
||||
|
||||
closeDialog() {
|
||||
this.showDialog = false;
|
||||
},
|
||||
|
||||
async loadMcpServers() {
|
||||
try {
|
||||
const response = await axios.get('/api/tools/mcp/servers');
|
||||
if (response.data.status === 'ok') {
|
||||
this.mcpServers = response.data.data || [];
|
||||
} else {
|
||||
this.$emit('error', response.data.message || 'Failed to load MCP servers');
|
||||
}
|
||||
} catch (error) {
|
||||
this.$emit('error', error.response?.data?.message || 'Failed to load MCP servers');
|
||||
this.mcpServers = [];
|
||||
}
|
||||
},
|
||||
|
||||
async loadTools() {
|
||||
this.loadingTools = true;
|
||||
try {
|
||||
const response = await axios.get('/api/tools/list');
|
||||
if (response.data.status === 'ok') {
|
||||
this.availableTools = response.data.data || [];
|
||||
} else {
|
||||
this.$emit('error', response.data.message || 'Failed to load tools');
|
||||
}
|
||||
} catch (error) {
|
||||
this.$emit('error', error.response?.data?.message || 'Failed to load tools');
|
||||
this.availableTools = [];
|
||||
} finally {
|
||||
this.loadingTools = false;
|
||||
}
|
||||
},
|
||||
|
||||
async savePersona() {
|
||||
if (!this.formValid) return;
|
||||
|
||||
// 验证预设对话不能为空
|
||||
if (this.personaForm.begin_dialogs.length > 0) {
|
||||
for (let i = 0; i < this.personaForm.begin_dialogs.length; i++) {
|
||||
if (!this.personaForm.begin_dialogs[i] || this.personaForm.begin_dialogs[i].trim() === '') {
|
||||
const dialogType = i % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
|
||||
this.$emit('error', this.tm('validation.dialogRequired', { type: dialogType }));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.saving = true;
|
||||
try {
|
||||
const url = this.editingPersona ? '/api/persona/update' : '/api/persona/create';
|
||||
const response = await axios.post(url, this.personaForm);
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.$emit('saved', response.data.message || this.tm('messages.saveSuccess'));
|
||||
this.closeDialog();
|
||||
} else {
|
||||
this.$emit('error', response.data.message || this.tm('messages.saveError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.$emit('error', error.response?.data?.message || this.tm('messages.saveError'));
|
||||
}
|
||||
this.saving = false;
|
||||
},
|
||||
|
||||
addDialogPair() {
|
||||
this.personaForm.begin_dialogs.push('', '');
|
||||
// 自动展开预设对话面板
|
||||
if (!this.expandedPanels.includes('dialogs')) {
|
||||
this.expandedPanels.push('dialogs');
|
||||
}
|
||||
},
|
||||
|
||||
removeDialog(index) {
|
||||
// 如果是偶数索引(用户消息),删除用户消息和对应的助手消息
|
||||
if (index % 2 === 0 && index + 1 < this.personaForm.begin_dialogs.length) {
|
||||
this.personaForm.begin_dialogs.splice(index, 2);
|
||||
}
|
||||
// 如果是奇数索引(助手消息),删除助手消息和对应的用户消息
|
||||
else if (index % 2 === 1 && index - 1 >= 0) {
|
||||
this.personaForm.begin_dialogs.splice(index - 1, 2);
|
||||
}
|
||||
},
|
||||
|
||||
toggleMcpServer(server) {
|
||||
if (!server.tools || server.tools.length === 0) return;
|
||||
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 从全选状态转换为去除该服务器工具的状态
|
||||
this.personaForm.tools = this.availableTools.map(tool => tool.name)
|
||||
.filter(toolName => !server.tools.includes(toolName));
|
||||
this.toolSelectValue = '1'; // 切换到指定工具模式
|
||||
return;
|
||||
}
|
||||
|
||||
// 确保tools是数组
|
||||
if (!Array.isArray(this.personaForm.tools)) {
|
||||
this.personaForm.tools = [];
|
||||
this.toolSelectValue = '1';
|
||||
}
|
||||
|
||||
// 检查是否所有服务器的工具都已选中
|
||||
const serverTools = server.tools;
|
||||
const allSelected = serverTools.every(toolName => this.personaForm.tools.includes(toolName));
|
||||
|
||||
if (allSelected) {
|
||||
// 移除所有服务器工具
|
||||
this.personaForm.tools = this.personaForm.tools.filter(
|
||||
toolName => !serverTools.includes(toolName)
|
||||
);
|
||||
} else {
|
||||
// 添加所有服务器工具
|
||||
serverTools.forEach(toolName => {
|
||||
if (!this.personaForm.tools.includes(toolName)) {
|
||||
this.personaForm.tools.push(toolName);
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
toggleTool(toolName) {
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 如果是全选状态,点击某个工具表示要取消选择该工具
|
||||
// 所以创建一个包含所有其他工具的数组
|
||||
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
|
||||
this.toolSelectValue = '1'; // 切换到指定工具模式
|
||||
} else if (Array.isArray(this.personaForm.tools)) {
|
||||
const index = this.personaForm.tools.indexOf(toolName);
|
||||
if (index !== -1) {
|
||||
// 如果工具已选择,移除工具
|
||||
this.personaForm.tools.splice(index, 1);
|
||||
} else {
|
||||
// 如果工具未选择,添加工具
|
||||
this.personaForm.tools.push(toolName);
|
||||
}
|
||||
} else {
|
||||
// 如果tools不是数组也不是null,初始化为数组
|
||||
this.personaForm.tools = [toolName];
|
||||
this.toolSelectValue = '1';
|
||||
}
|
||||
},
|
||||
|
||||
removeTool(toolName) {
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 创建一个包含所有工具的数组,然后移除指定工具
|
||||
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
|
||||
this.toolSelectValue = '1'; // 切换到指定工具模式
|
||||
} else if (Array.isArray(this.personaForm.tools)) {
|
||||
const index = this.personaForm.tools.indexOf(toolName);
|
||||
if (index !== -1) {
|
||||
this.personaForm.tools.splice(index, 1);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
truncateText(text, maxLength) {
|
||||
if (!text) return '';
|
||||
return text.length > maxLength ? text.substring(0, maxLength) + '...' : text;
|
||||
},
|
||||
|
||||
getDialogRules(index) {
|
||||
const dialogType = index % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
|
||||
return [
|
||||
v => !!v || this.tm('validation.dialogRequired', { type: dialogType }),
|
||||
v => (v && v.trim().length > 0) || this.tm('validation.dialogRequired', { type: dialogType })
|
||||
];
|
||||
},
|
||||
|
||||
isToolSelected(toolName) {
|
||||
// 如果是全选状态,所有工具都被选中
|
||||
if (this.personaForm.tools === null) {
|
||||
return true;
|
||||
}
|
||||
return Array.isArray(this.personaForm.tools) && this.personaForm.tools.includes(toolName);
|
||||
},
|
||||
|
||||
isServerSelected(server) {
|
||||
if (!server.tools || server.tools.length === 0) return false;
|
||||
|
||||
// 如果是全选状态,所有服务器都被选中
|
||||
if (this.personaForm.tools === null) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 检查服务器的所有工具是否都已选中
|
||||
return Array.isArray(this.personaForm.tools) &&
|
||||
server.tools.every(toolName => this.personaForm.tools.includes(toolName));
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.tools-selection {
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.v-virtual-scroll {
|
||||
padding-bottom: 16px;
|
||||
}
|
||||
</style>
|
||||
@@ -18,7 +18,7 @@
|
||||
选择人格
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-0" style="max-height: 400px; overflow-y: auto;">
|
||||
<v-card-text class="pa-2" style="max-height: 400px; overflow-y: auto;">
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<v-list v-if="!loading && personaList.length > 0" density="compact">
|
||||
@@ -48,6 +48,9 @@
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-btn variant="text" color="primary" prepend-icon="mdi-plus" @click="openCreatePersona">
|
||||
创建新人格
|
||||
</v-btn>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="cancelSelection">取消</v-btn>
|
||||
<v-btn
|
||||
@@ -59,11 +62,22 @@
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 创建人格对话框 -->
|
||||
<PersonaForm
|
||||
v-model="showCreateDialog"
|
||||
:editing-persona="null"
|
||||
:mcp-servers="mcpServers"
|
||||
:available-tools="availableTools"
|
||||
:loading-tools="loadingTools"
|
||||
@saved="handlePersonaCreated"
|
||||
@error="handleError" />
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { ref, watch } from 'vue'
|
||||
import axios from 'axios'
|
||||
import PersonaForm from './PersonaForm.vue'
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
@@ -82,6 +96,7 @@ const dialog = ref(false)
|
||||
const personaList = ref([])
|
||||
const loading = ref(false)
|
||||
const selectedPersona = ref('')
|
||||
const showCreateDialog = ref(false)
|
||||
|
||||
// 监听 modelValue 变化,同步到 selectedPersona
|
||||
watch(() => props.modelValue, (newValue) => {
|
||||
@@ -135,6 +150,21 @@ function cancelSelection() {
|
||||
selectedPersona.value = props.modelValue || ''
|
||||
dialog.value = false
|
||||
}
|
||||
|
||||
function openCreatePersona() {
|
||||
showCreateDialog.value = true
|
||||
}
|
||||
|
||||
async function handlePersonaCreated(message) {
|
||||
console.log('人格创建成功:', message)
|
||||
showCreateDialog.value = false
|
||||
// 刷新人格列表
|
||||
await loadPersonas()
|
||||
}
|
||||
|
||||
function handleError(error) {
|
||||
console.error('创建人格失败:', error)
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"dashboard": "Dashboard",
|
||||
"dashboard": "Dashboard",
|
||||
"platforms": "Platforms",
|
||||
"providers": "Providers",
|
||||
"persona": "Persona",
|
||||
@@ -12,10 +12,12 @@
|
||||
"console": "Console",
|
||||
"alkaid": "Alkaid Lab",
|
||||
"knowledgeBase": "Knowledge Base",
|
||||
"knowledgeBaseBeta": "Knowledge Base (Beta)",
|
||||
"about": "About",
|
||||
"settings": "Settings",
|
||||
"documentation": "Documentation",
|
||||
"github": "GitHub",
|
||||
"drag": "Drag"
|
||||
}
|
||||
"drag": "Drag",
|
||||
"groups": {
|
||||
"more": "More Features"
|
||||
}
|
||||
}
|
||||
@@ -24,7 +24,9 @@
|
||||
"title": "Conversation Title",
|
||||
"platform": "Platform",
|
||||
"type": "Type",
|
||||
"sessionId": "ID",
|
||||
"cid": "Conversation ID",
|
||||
"umo": "Unified Message Origin",
|
||||
"sessionId": "Session ID",
|
||||
"createdAt": "Created At",
|
||||
"updatedAt": "Updated At",
|
||||
"actions": "Actions"
|
||||
|
||||
@@ -46,36 +46,23 @@
|
||||
"selectFile": "Select File",
|
||||
"dropzone": "Drop files here or click to select",
|
||||
"supportedFormats": "Supported formats: TXT, PDF, Markdown",
|
||||
"maxSize": "Max file size: 50MB",
|
||||
"maxSize": "Max file size: 128MB",
|
||||
"chunkSettings": "Chunk Settings",
|
||||
"batchSettings": "Batch Settings",
|
||||
"chunkSize": "Chunk Size",
|
||||
"chunkSizeHint": "Number of characters per chunk (default: 512)",
|
||||
"chunkOverlap": "Chunk Overlap",
|
||||
"chunkOverlapHint": "Overlapping characters between chunks (default: 50)",
|
||||
"batchSize": "Batch Size",
|
||||
"batchSizeHint": "Number of chunks to process in each batch (default: 32)",
|
||||
"tasksLimit": "Concurrent Tasks Limit",
|
||||
"tasksLimitHint": "Maximum number of concurrent upload tasks (default: 3)",
|
||||
"maxRetries": "Max Retries",
|
||||
"maxRetriesHint": "Number of times to retry a failed upload task (default: 3)",
|
||||
"cancel": "Cancel",
|
||||
"submit": "Upload",
|
||||
"fileRequired": "Please select a file to upload"
|
||||
},
|
||||
"sessions": {
|
||||
"title": "Session Configuration",
|
||||
"subtitle": "Configure which sessions can use this knowledge base",
|
||||
"empty": "No session configurations",
|
||||
"add": "Add Configuration",
|
||||
"scope": "Scope",
|
||||
"scopeId": "Identifier",
|
||||
"topK": "Top K Results",
|
||||
"enableRerank": "Enable Rerank",
|
||||
"actions": "Actions",
|
||||
"edit": "Edit",
|
||||
"delete": "Delete",
|
||||
"scopeSession": "Session Level",
|
||||
"scopePlatform": "Platform Level",
|
||||
"deleteConfirm": "Are you sure you want to delete this configuration?",
|
||||
"addSuccess": "Configuration added successfully",
|
||||
"addFailed": "Failed to add configuration",
|
||||
"deleteSuccess": "Configuration deleted successfully",
|
||||
"deleteFailed": "Failed to delete configuration"
|
||||
},
|
||||
"settings": {
|
||||
"title": "Knowledge Base Settings",
|
||||
"basic": "Basic Settings",
|
||||
|
||||
@@ -21,7 +21,8 @@
|
||||
"delete": "Delete",
|
||||
"preview": "Preview",
|
||||
"search": "Search Chunks",
|
||||
"searchPlaceholder": "Enter keywords to search chunks..."
|
||||
"searchPlaceholder": "Enter keywords to search chunks...",
|
||||
"showing": "Showing"
|
||||
},
|
||||
"edit": {
|
||||
"title": "Edit Chunk",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"dashboard": "统计",
|
||||
"platforms": "消息平台",
|
||||
"providers": "服务提供商",
|
||||
"dashboard": "数据统计",
|
||||
"platforms": "机器人",
|
||||
"providers": "模型提供商",
|
||||
"persona": "人格设定",
|
||||
"toolUse": "MCP",
|
||||
"extension": "插件",
|
||||
@@ -12,10 +12,12 @@
|
||||
"console": "控制台",
|
||||
"alkaid": "Alkaid",
|
||||
"knowledgeBase": "知识库",
|
||||
"knowledgeBaseBeta": "知识库 (Beta)",
|
||||
"about": "关于",
|
||||
"settings": "设置",
|
||||
"documentation": "官方文档",
|
||||
"github": "GitHub",
|
||||
"drag": "拖拽"
|
||||
}
|
||||
"drag": "拖拽",
|
||||
"groups": {
|
||||
"more": "更多功能"
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
"subtitle": "管理和查看用户对话历史记录",
|
||||
"filters": {
|
||||
"title": "筛选条件",
|
||||
"platform": "消息平台 ID",
|
||||
"platform": "机器人 ID",
|
||||
"type": "类型",
|
||||
"search": "搜索关键词",
|
||||
"reset": "重置"
|
||||
@@ -22,9 +22,11 @@
|
||||
"table": {
|
||||
"headers": {
|
||||
"title": "对话标题",
|
||||
"platform": "消息平台 ID",
|
||||
"type": "类型",
|
||||
"sessionId": "ID (UMO)",
|
||||
"platform": "机器人 ID",
|
||||
"type": "消息类型",
|
||||
"cid": "对话 ID",
|
||||
"umo": "消息会话来源",
|
||||
"sessionId": "会话 ID",
|
||||
"createdAt": "创建时间",
|
||||
"updatedAt": "更新时间",
|
||||
"actions": "操作"
|
||||
|
||||
@@ -47,32 +47,23 @@
|
||||
"selectFile": "选择文件",
|
||||
"dropzone": "拖放文件到这里或点击选择",
|
||||
"supportedFormats": "支持的格式: TXT, PDF, Markdown",
|
||||
"maxSize": "最大文件大小: 50MB",
|
||||
"maxSize": "最大文件大小: 128MB",
|
||||
"chunkSettings": "分块设置",
|
||||
"batchSettings": "批处理设置",
|
||||
"chunkSize": "分块大小",
|
||||
"chunkSizeHint": "每个文本块的字符数 (默认: 512)",
|
||||
"chunkOverlap": "分块重叠",
|
||||
"chunkOverlapHint": "相邻文本块之间的重叠字符数 (默认: 50)",
|
||||
"batchSize": "批处理大小",
|
||||
"batchSizeHint": "每批处理的文本块数量 (默认: 32)",
|
||||
"tasksLimit": "并发任务限制",
|
||||
"tasksLimitHint": "最大并发上传任务数 (默认: 3)",
|
||||
"maxRetries": "最大重试次数",
|
||||
"maxRetriesHint": "上传失败任务的重试次数 (默认: 3)",
|
||||
"cancel": "取消",
|
||||
"submit": "上传",
|
||||
"fileRequired": "请选择要上传的文件"
|
||||
},
|
||||
"sessions": {
|
||||
"title": "使用该知识库的会话",
|
||||
"subtitle": "以下会话正在使用此知识库",
|
||||
"empty": "暂无会话使用此知识库",
|
||||
"refresh": "刷新",
|
||||
"scope": "范围",
|
||||
"scopeId": "会话标识",
|
||||
"topK": "返回结果数",
|
||||
"enableRerank": "启用重排序",
|
||||
"actions": "操作",
|
||||
"scopeSession": "会话级别",
|
||||
"scopePlatform": "平台级别",
|
||||
"viewInSessionManagement": "在会话管理中查看",
|
||||
"goToSessionManagement": "前往会话管理",
|
||||
"loadFailed": "加载会话列表失败"
|
||||
},
|
||||
"retrieval": {
|
||||
"title": "知识库检索",
|
||||
"subtitle": "使用稠密检索和稀疏检索测试知识库内容",
|
||||
|
||||
@@ -21,7 +21,11 @@
|
||||
"delete": "删除",
|
||||
"preview": "预览",
|
||||
"search": "搜索分块",
|
||||
"searchPlaceholder": "输入关键词搜索分块内容..."
|
||||
"searchPlaceholder": "输入关键词搜索分块内容...",
|
||||
"showing": "显示",
|
||||
"deleteConfirm": "确定要删除该文本块吗?",
|
||||
"deleteSuccess": "文本块删除成功",
|
||||
"deleteFailed": "文本块删除失败"
|
||||
},
|
||||
"edit": {
|
||||
"title": "编辑分块",
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
{
|
||||
"title": "平台适配器管理",
|
||||
"subtitle": "管理机器人的平台适配器,连接到不同的聊天平台",
|
||||
"title": "机器人",
|
||||
"subtitle": "管理平台适配器实例,连接到不同的消息平台",
|
||||
"adapters": "平台适配器",
|
||||
"addAdapter": "新增适配器",
|
||||
"emptyText": "暂无平台适配器,点击 新增适配器 添加",
|
||||
"addAdapter": "创建机器人",
|
||||
"emptyText": "暂无平台适配器,点击 创建机器人 添加",
|
||||
"details": {
|
||||
"adapterType": "适配器类型",
|
||||
"token": "Token",
|
||||
@@ -17,11 +17,11 @@
|
||||
"dialog": {
|
||||
"add": "新增",
|
||||
"edit": "编辑",
|
||||
"adapter": "平台适配器",
|
||||
"adapter": "机器人",
|
||||
"refresh": "刷新",
|
||||
"cancel": "取消",
|
||||
"save": "保存",
|
||||
"addPlatform": "添加平台适配器",
|
||||
"addPlatform": "创建机器人",
|
||||
"connectTitle": "接入 {name}",
|
||||
"viewTutorial": "查看接入教程",
|
||||
"noTemplates": "暂无平台模板",
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
{
|
||||
"title": "服务提供商管理",
|
||||
"subtitle": "管理模型服务提供商",
|
||||
"title": "模型提供商",
|
||||
"subtitle": "管理模型提供商",
|
||||
"providers": {
|
||||
"title": "服务提供商",
|
||||
"title": "模型提供商",
|
||||
"settings": "设置",
|
||||
"addProvider": "新增服务提供商",
|
||||
"addProvider": "新增模型提供商",
|
||||
"providerType": "提供商类型",
|
||||
"tabs": {
|
||||
"all": "全部",
|
||||
@@ -15,8 +15,8 @@
|
||||
"rerank": "重排序(Rerank)"
|
||||
},
|
||||
"empty": {
|
||||
"all": "暂无服务提供商,点击 新增服务提供商 添加",
|
||||
"typed": "暂无{type}类型的服务提供商,点击 新增服务提供商 添加"
|
||||
"all": "暂无模型提供商,点击 新增模型提供商 添加",
|
||||
"typed": "暂无{type}类型的模型提供商,点击 新增模型提供商 添加"
|
||||
},
|
||||
"description": {
|
||||
"openai": "也支持所有兼容 OpenAI API 的模型提供商。",
|
||||
@@ -25,10 +25,10 @@
|
||||
}
|
||||
},
|
||||
"availability": {
|
||||
"title": "服务提供商可用性",
|
||||
"title": "模型提供商可用性",
|
||||
"subtitle": "通过测试模型对话可用性判断,可能产生API费用",
|
||||
"refresh": "刷新状态",
|
||||
"noData": "点击\"刷新状态\"按钮获取服务提供商可用性",
|
||||
"noData": "点击\"刷新状态\"按钮获取模型提供商可用性",
|
||||
"available": "可用",
|
||||
"unavailable": "不可用",
|
||||
"pending": "检查中...",
|
||||
@@ -42,7 +42,7 @@
|
||||
},
|
||||
"dialogs": {
|
||||
"addProvider": {
|
||||
"title": "服务提供商",
|
||||
"title": "模型提供商",
|
||||
"tabs": {
|
||||
"basic": "基本",
|
||||
"speechToText": "语音转文字",
|
||||
@@ -55,15 +55,15 @@
|
||||
"config": {
|
||||
"addTitle": "新增",
|
||||
"editTitle": "编辑",
|
||||
"provider": "服务提供商",
|
||||
"provider": "模型提供商",
|
||||
"cancel": "取消",
|
||||
"save": "保存"
|
||||
},
|
||||
"settings": {
|
||||
"title": "服务提供商设置",
|
||||
"title": "模型提供商设置",
|
||||
"sessionSeparation": {
|
||||
"title": "启用提供商会话隔离",
|
||||
"description": "不同会话将可独立选择文本生成、TTS、STT 等服务提供商。"
|
||||
"description": "不同会话将可独立选择文本生成、TTS、STT 等模型提供商。"
|
||||
},
|
||||
"close": "关闭"
|
||||
}
|
||||
@@ -78,11 +78,11 @@
|
||||
},
|
||||
"error": {
|
||||
"sessionSeparation": "获取会话隔离配置失败",
|
||||
"fetchStatus": "获取服务提供商状态失败",
|
||||
"fetchStatus": "获取模型提供商状态失败",
|
||||
"testError": "测试 {id} 失败: {error}"
|
||||
},
|
||||
"confirm": {
|
||||
"delete": "确定要删除服务提供商 {id} 吗?"
|
||||
"delete": "确定要删除模型提供商 {id} 吗?"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -23,7 +23,7 @@
|
||||
"table": {
|
||||
"headers": {
|
||||
"sessionStatus": "会话状态",
|
||||
"sessionInfo": "ID (UMO)",
|
||||
"sessionInfo": "消息会话来源",
|
||||
"persona": "人格",
|
||||
"chatProvider": "聊天模型",
|
||||
"sttProvider": "语音识别模型",
|
||||
|
||||
@@ -311,7 +311,7 @@ commonStore.getStartTime();
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
|
||||
<div class="logo-container" :class="{ 'mobile-logo': $vuetify.display.xs }" @click="$router.push('/about')">
|
||||
<div class="logo-container" :class="{ 'mobile-logo': $vuetify.display.xs }" @click="router.push('/about')">
|
||||
<span class="logo-text">Astr<span class="logo-text-light">Bot</span></span>
|
||||
<span class="version-text hidden-xs">{{ botCurrVersion }}</span>
|
||||
</div>
|
||||
|
||||
@@ -1,20 +1,38 @@
|
||||
<script setup>
|
||||
import { useI18n } from '@/i18n/composables';
|
||||
import { useCustomizerStore } from '@/stores/customizer';
|
||||
import { computed } from 'vue';
|
||||
|
||||
const props = defineProps({ item: Object, level: Number });
|
||||
const { t } = useI18n();
|
||||
const customizer = useCustomizerStore();
|
||||
|
||||
const itemStyle = computed(() => {
|
||||
const lvl = props.level ?? 0;
|
||||
const indent = customizer.mini_sidebar ? '0px' : `${lvl * 24}px`;
|
||||
return { '--indent-padding': indent };
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<v-list-item
|
||||
:to="item.type === 'external' ? '' : item.to"
|
||||
:href="item.type === 'external' ? item.to : ''"
|
||||
rounded
|
||||
class="mb-1"
|
||||
color="secondary"
|
||||
:disabled="item.disabled"
|
||||
:target="item.type === 'external' ? '_blank' : ''"
|
||||
>
|
||||
<v-list-group v-if="item.children" :value="item.title" :class="{ 'group-bordered': customizer.mini_sidebar }">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-list-item v-bind="props" rounded class="mb-1" color="secondary" :prepend-icon="item.icon"
|
||||
:style="{ '--indent-padding': '0px' }">
|
||||
<v-list-item-title style="font-size: 14px; font-weight: 500; line-height: 1.2; word-break: break-word;">
|
||||
{{ t(item.title) }}
|
||||
</v-list-item-title>
|
||||
</v-list-item>
|
||||
</template>
|
||||
|
||||
<!-- children -->
|
||||
<template v-for="(child, index) in item.children" :key="index">
|
||||
<NavItem :item="child" :level="(level || 0) + 1" />
|
||||
</template>
|
||||
</v-list-group>
|
||||
|
||||
<v-list-item v-else :to="item.type === 'external' ? '' : item.to" :href="item.type === 'external' ? item.to : ''" rounded
|
||||
class="mb-1" color="secondary" :disabled="item.disabled" :target="item.type === 'external' ? '_blank' : ''" :style="itemStyle">
|
||||
<template v-slot:prepend>
|
||||
<v-icon v-if="item.icon" :size="item.iconSize" class="hide-menu" :icon="item.icon"></v-icon>
|
||||
</template>
|
||||
@@ -23,15 +41,19 @@ const { t } = useI18n();
|
||||
{{ item.subCaption }}
|
||||
</v-list-item-subtitle>
|
||||
<template v-slot:append v-if="item.chip">
|
||||
<v-chip
|
||||
:color="item.chipColor"
|
||||
class="sidebarchip hide-menu"
|
||||
:size="item.chipIcon ? 'small' : 'default'"
|
||||
:variant="item.chipVariant"
|
||||
:prepend-icon="item.chipIcon"
|
||||
>
|
||||
<v-chip :color="item.chipColor" class="sidebarchip hide-menu" :size="item.chipIcon ? 'small' : 'default'"
|
||||
:variant="item.chipVariant" :prepend-icon="item.chipIcon">
|
||||
{{ item.chip }}
|
||||
</v-chip>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</template>
|
||||
|
||||
<style>
|
||||
/* 在折叠(mini)状态下,分组展开时给整个分组(母项+子项)加边框以便区分 */
|
||||
.group-bordered.v-list-group--open {
|
||||
border: 2px solid rgba(var(--v-theme-borderLight), 0.35);
|
||||
border-radius: 8px;
|
||||
background: rgba(var(--v-theme-borderLight), 0.04);
|
||||
}
|
||||
</style>
|
||||
@@ -11,8 +11,13 @@ const customizer = useCustomizerStore();
|
||||
const sidebarMenu = shallowRef(sidebarItems);
|
||||
|
||||
const showIframe = ref(false);
|
||||
const starCount = ref(null);
|
||||
|
||||
const sidebarWidth = ref(235);
|
||||
const minSidebarWidth = 200;
|
||||
const maxSidebarWidth = 300;
|
||||
const isResizing = ref(false);
|
||||
|
||||
// 默认桌面端 iframe 样式
|
||||
const iframeStyle = ref({
|
||||
position: 'fixed',
|
||||
bottom: '16px',
|
||||
@@ -29,14 +34,13 @@ const iframeStyle = ref({
|
||||
boxShadow: '0px 4px 12px rgba(0, 0, 0, 0.1)',
|
||||
});
|
||||
|
||||
// 如果为移动端,则采用百分比尺寸,并设置初始位置
|
||||
if (window.innerWidth < 768) {
|
||||
iframeStyle.value = {
|
||||
position: 'fixed',
|
||||
top: '10%',
|
||||
left: '0%',
|
||||
width: '100%',
|
||||
height: '50%',
|
||||
height: '80%',
|
||||
minWidth: '300px',
|
||||
minHeight: '200px',
|
||||
background: 'white',
|
||||
@@ -46,7 +50,6 @@ if (window.innerWidth < 768) {
|
||||
borderRadius: '12px',
|
||||
boxShadow: '0px 4px 12px rgba(0, 0, 0, 0.1)',
|
||||
};
|
||||
// 移动端默认关闭侧边栏
|
||||
customizer.Sidebar_drawer = false;
|
||||
}
|
||||
|
||||
@@ -74,12 +77,10 @@ function openIframeLink(url) {
|
||||
}
|
||||
}
|
||||
|
||||
// 拖拽相关变量与函数
|
||||
let offsetX = 0;
|
||||
let offsetY = 0;
|
||||
let isDragging = false;
|
||||
|
||||
// 辅助函数:限制数值在一定范围内
|
||||
function clamp(value, min, max) {
|
||||
return Math.min(Math.max(value, min), max);
|
||||
}
|
||||
@@ -91,7 +92,6 @@ function startDrag(clientX, clientY) {
|
||||
offsetX = clientX - rect.left;
|
||||
offsetY = clientY - rect.top;
|
||||
document.body.style.userSelect = 'none';
|
||||
// 绑定全局鼠标和触摸事件
|
||||
document.addEventListener('mousemove', onMouseMove);
|
||||
document.addEventListener('mouseup', onMouseUp);
|
||||
document.addEventListener('touchmove', onTouchMove, { passive: false });
|
||||
@@ -149,6 +149,53 @@ function endDrag() {
|
||||
document.removeEventListener('touchend', onTouchEnd);
|
||||
}
|
||||
|
||||
function startSidebarResize(event) {
|
||||
isResizing.value = true;
|
||||
document.body.style.userSelect = 'none';
|
||||
document.body.style.cursor = 'ew-resize';
|
||||
|
||||
const startX = event.clientX;
|
||||
const startWidth = sidebarWidth.value;
|
||||
|
||||
function onMouseMoveResize(event) {
|
||||
if (!isResizing.value) return;
|
||||
|
||||
const deltaX = event.clientX - startX;
|
||||
const newWidth = Math.max(minSidebarWidth, Math.min(maxSidebarWidth, startWidth + deltaX));
|
||||
sidebarWidth.value = newWidth;
|
||||
}
|
||||
|
||||
function onMouseUpResize() {
|
||||
isResizing.value = false;
|
||||
document.body.style.userSelect = '';
|
||||
document.body.style.cursor = '';
|
||||
document.removeEventListener('mousemove', onMouseMoveResize);
|
||||
document.removeEventListener('mouseup', onMouseUpResize);
|
||||
}
|
||||
|
||||
document.addEventListener('mousemove', onMouseMoveResize);
|
||||
document.addEventListener('mouseup', onMouseUpResize);
|
||||
}
|
||||
|
||||
function formatNumber(num) {
|
||||
return num.toString().replace(/\B(?=(\d{3})+(?!\d))/g, ',');
|
||||
}
|
||||
|
||||
async function fetchStarCount() {
|
||||
try {
|
||||
const response = await fetch('https://cloud.astrbot.app/api/v1/github/repo-info');
|
||||
const data = await response.json();
|
||||
if (data.data && data.data.stargazers_count) {
|
||||
starCount.value = data.data.stargazers_count;
|
||||
console.debug('Fetched star count:', starCount.value);
|
||||
}
|
||||
} catch (error) {
|
||||
console.debug('Failed to fetch star count:', error);
|
||||
}
|
||||
}
|
||||
|
||||
fetchStarCount();
|
||||
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -159,7 +206,7 @@ function endDrag() {
|
||||
rail-width="80"
|
||||
app
|
||||
class="leftSidebar"
|
||||
width="220"
|
||||
:width="sidebarWidth"
|
||||
:rail="customizer.mini_sidebar"
|
||||
>
|
||||
<div class="sidebar-container">
|
||||
@@ -177,9 +224,24 @@ function endDrag() {
|
||||
</v-btn>
|
||||
<v-btn style="margin-bottom: 8px;" size="small" variant="plain" @click="openIframeLink('https://github.com/AstrBotDevs/AstrBot')">
|
||||
{{ t('core.navigation.github') }}
|
||||
<v-chip
|
||||
v-if="starCount"
|
||||
size="x-small"
|
||||
variant="outlined"
|
||||
class="ml-2"
|
||||
style="font-weight: normal;"
|
||||
>{{ formatNumber(starCount) }}</v-chip>
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div
|
||||
v-if="!customizer.mini_sidebar && customizer.Sidebar_drawer"
|
||||
class="sidebar-resize-handle"
|
||||
@mousedown="startSidebarResize"
|
||||
:class="{ 'resizing': isResizing }"
|
||||
>
|
||||
</div>
|
||||
</v-navigation-drawer>
|
||||
|
||||
<div
|
||||
@@ -187,14 +249,13 @@ function endDrag() {
|
||||
id="draggable-iframe"
|
||||
:style="iframeStyle"
|
||||
>
|
||||
<!-- 拖拽头部:支持鼠标和触摸 -->
|
||||
|
||||
<div :style="dragHeaderStyle" @mousedown="onMouseDown" @touchstart="onTouchStart">
|
||||
<div style="display: flex; align-items: center;">
|
||||
<v-icon icon="mdi-cursor-move" />
|
||||
<span style="margin-left: 8px;">{{ t('core.navigation.drag') }}</span>
|
||||
</div>
|
||||
<div style="display: flex; gap: 8px;">
|
||||
<!-- 跳转按钮 -->
|
||||
<v-btn
|
||||
icon
|
||||
@click.stop="openIframeLink('https://astrbot.app')"
|
||||
@@ -203,7 +264,6 @@ function endDrag() {
|
||||
>
|
||||
<v-icon icon="mdi-open-in-new" />
|
||||
</v-btn>
|
||||
<!-- 关闭按钮 -->
|
||||
<v-btn
|
||||
icon
|
||||
@click.stop="toggleIframe"
|
||||
@@ -214,10 +274,53 @@ function endDrag() {
|
||||
</v-btn>
|
||||
</div>
|
||||
</div>
|
||||
<!-- iframe 区域 -->
|
||||
<iframe
|
||||
src="https://astrbot.app"
|
||||
style="width: 100%; height: calc(100% - 56px); border: none; border-bottom-left-radius: 12px; border-bottom-right-radius: 12px;"
|
||||
></iframe>
|
||||
</div>
|
||||
</template>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
.sidebar-resize-handle {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
right: 0;
|
||||
width: 4px;
|
||||
height: 100%;
|
||||
background: transparent;
|
||||
cursor: ew-resize;
|
||||
user-select: none;
|
||||
z-index: 1000;
|
||||
transition: background-color 0.2s ease;
|
||||
}
|
||||
|
||||
.sidebar-resize-handle:hover,
|
||||
.sidebar-resize-handle.resizing {
|
||||
background: rgba(var(--v-theme-primary), 0.3);
|
||||
}
|
||||
|
||||
.sidebar-resize-handle::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
width: 2px;
|
||||
height: 30px;
|
||||
background: rgba(var(--v-theme-on-surface), 0.3);
|
||||
border-radius: 1px;
|
||||
opacity: 0;
|
||||
transition: opacity 0.2s ease;
|
||||
}
|
||||
|
||||
.sidebar-resize-handle:hover::before,
|
||||
.sidebar-resize-handle.resizing::before {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
/* 确保侧边栏容器支持相对定位 */
|
||||
.leftSidebar .v-navigation-drawer__content {
|
||||
position: relative;
|
||||
}
|
||||
</style>
|
||||
@@ -18,31 +18,26 @@ export interface menu {
|
||||
// 在组件中使用时需要通过t()函数进行翻译
|
||||
// 所有键名都使用 core.navigation.* 格式
|
||||
const sidebarItem: menu[] = [
|
||||
{
|
||||
title: 'core.navigation.dashboard',
|
||||
icon: 'mdi-view-dashboard',
|
||||
to: '/dashboard/default'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.platforms',
|
||||
icon: 'mdi-message-processing',
|
||||
to: '/platforms',
|
||||
icon: 'mdi-robot',
|
||||
to: '/',
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.providers',
|
||||
icon: 'mdi-creation',
|
||||
to: '/providers',
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.config',
|
||||
icon: 'mdi-cog',
|
||||
to: '/config',
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.toolUse',
|
||||
icon: 'mdi-function-variant',
|
||||
to: '/tool-use'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.persona',
|
||||
icon: 'mdi-heart',
|
||||
to: '/persona'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.extension',
|
||||
icon: 'mdi-puzzle',
|
||||
@@ -50,21 +45,8 @@ const sidebarItem: menu[] = [
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.knowledgeBase',
|
||||
icon: 'mdi-text-box-search',
|
||||
to: '/alkaid/knowledge-base',
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.knowledgeBaseBeta',
|
||||
icon: 'mdi-book-open-variant',
|
||||
to: '/knowledge-base',
|
||||
chip: 'Beta',
|
||||
chipColor: 'primary',
|
||||
chipVariant: 'tonal',
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.config',
|
||||
icon: 'mdi-cog',
|
||||
to: '/config',
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.chat',
|
||||
@@ -72,20 +54,36 @@ const sidebarItem: menu[] = [
|
||||
to: '/chat'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.conversation',
|
||||
icon: 'mdi-database',
|
||||
to: '/conversation'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.sessionManagement',
|
||||
icon: 'mdi-account-group',
|
||||
to: '/session-management'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.console',
|
||||
icon: 'mdi-console',
|
||||
to: '/console'
|
||||
},
|
||||
title: 'core.navigation.groups.more',
|
||||
icon: 'mdi-dots-horizontal',
|
||||
children: [
|
||||
{
|
||||
title: 'core.navigation.persona',
|
||||
icon: 'mdi-heart',
|
||||
to: '/persona'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.conversation',
|
||||
icon: 'mdi-database',
|
||||
to: '/conversation'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.sessionManagement',
|
||||
icon: 'mdi-account-group',
|
||||
to: '/session-management'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.dashboard',
|
||||
icon: 'mdi-view-dashboard',
|
||||
to: '/dashboard/default'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.console',
|
||||
icon: 'mdi-console',
|
||||
to: '/console'
|
||||
},
|
||||
]
|
||||
}
|
||||
// {
|
||||
// title: 'Project ATRI',
|
||||
// icon: 'mdi-grain',
|
||||
|
||||
@@ -49,6 +49,6 @@ axios.interceptors.request.use((config) => {
|
||||
|
||||
loader.config({
|
||||
paths: {
|
||||
vs: 'https://cdn.jsdelivr.net/npm/monaco-editor@0.43.0/min/vs',
|
||||
vs: 'https://cdn.jsdelivr.net/npm/monaco-editor@0.54.0/min/vs',
|
||||
},
|
||||
})
|
||||
@@ -3,13 +3,13 @@ const MainRoutes = {
|
||||
meta: {
|
||||
requiresAuth: true
|
||||
},
|
||||
redirect: '/main/dashboard/default',
|
||||
redirect: '/main/platforms',
|
||||
component: () => import('@/layouts/full/FullLayout.vue'),
|
||||
children: [
|
||||
{
|
||||
name: 'Dashboard',
|
||||
name: 'MainPage',
|
||||
path: '/',
|
||||
component: () => import('@/views/dashboards/default/DefaultDashboard.vue')
|
||||
component: () => import('@/views/PlatformPage.vue')
|
||||
},
|
||||
{
|
||||
name: 'Extensions',
|
||||
@@ -90,6 +90,13 @@ const MainRoutes = {
|
||||
}
|
||||
]
|
||||
},
|
||||
|
||||
// 旧版本的知识库路由
|
||||
{
|
||||
name: 'KnowledgeBase',
|
||||
path: '/alkaid/knowledge-base',
|
||||
component: () => import('@/views/alkaid/KnowledgeBase.vue'),
|
||||
},
|
||||
// {
|
||||
// name: 'Alkaid',
|
||||
// path: '/alkaid',
|
||||
@@ -112,28 +119,6 @@ const MainRoutes = {
|
||||
// }
|
||||
// ]
|
||||
// },
|
||||
{
|
||||
name: 'Alkaid',
|
||||
path: '/alkaid',
|
||||
component: () => import('@/views/AlkaidPage.vue'),
|
||||
children: [
|
||||
{
|
||||
path: 'knowledge-base',
|
||||
name: 'KnowledgeBase',
|
||||
component: () => import('@/views/alkaid/KnowledgeBase.vue')
|
||||
},
|
||||
{
|
||||
path: 'long-term-memory',
|
||||
name: 'LongTermMemory',
|
||||
component: () => import('@/views/alkaid/LongTermMemory.vue')
|
||||
},
|
||||
{
|
||||
path: 'other',
|
||||
name: 'OtherFeatures',
|
||||
component: () => import('@/views/alkaid/Other.vue')
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
name: 'Chat',
|
||||
path: '/chat',
|
||||
|
||||
@@ -11,12 +11,6 @@
|
||||
<v-select style="min-width: 130px;" v-model="selectedConfigID" :items="configSelectItems" item-title="name"
|
||||
v-if="!isSystemConfig" item-value="id" label="选择配置文件" hide-details density="compact" rounded="md"
|
||||
variant="outlined" @update:model-value="onConfigSelect">
|
||||
<template v-slot:item="{ props: itemProps, item }">
|
||||
<v-list-item v-bind="itemProps"
|
||||
:subtitle="item.raw.id === '_%manage%_' ? '管理所有配置文件' : formatUmop(item.raw.umop)"
|
||||
:class="item.raw.id === '_%manage%_' ? 'text-primary' : ''">
|
||||
</v-list-item>
|
||||
</template>
|
||||
</v-select>
|
||||
<a style="color: inherit;" href="https://blog.astrbot.app/posts/what-is-changed-in-4.0.0/#%E5%A4%9A%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6" target="_blank"><v-btn icon="mdi-help-circle" size="small" variant="plain"></v-btn></a>
|
||||
|
||||
@@ -37,38 +31,10 @@
|
||||
|
||||
<div v-if="(selectedConfigID || isSystemConfig) && fetched" style="width: 100%;">
|
||||
<!-- 可视化编辑 -->
|
||||
<div :class="$vuetify.display.mobile ? '' : 'd-flex'">
|
||||
<v-tabs v-model="tab" :direction="$vuetify.display.mobile ? 'horizontal' : 'vertical'"
|
||||
:align-tabs="$vuetify.display.mobile ? 'left' : 'start'" color="deep-purple-accent-4" class="config-tabs">
|
||||
<v-tab v-for="(val, key, index) in metadata" :key="index" :value="index"
|
||||
style="font-weight: 1000; font-size: 15px">
|
||||
{{ metadata[key]['name'] }}
|
||||
</v-tab>
|
||||
</v-tabs>
|
||||
<v-tabs-window v-model="tab" class="config-tabs-window">
|
||||
<v-tabs-window-item v-for="(val, key, index) in metadata" v-show="index == tab" :key="index">
|
||||
<v-container fluid>
|
||||
<div v-for="(val2, key2, index2) in metadata[key]['metadata']" :key="key2">
|
||||
<!-- Support both traditional and JSON selector metadata -->
|
||||
<AstrBotConfigV4 :metadata="{ [key2]: metadata[key]['metadata'][key2] }" :iterable="config_data"
|
||||
:metadataKey="key2">
|
||||
</AstrBotConfigV4>
|
||||
</div>
|
||||
</v-container>
|
||||
</v-tabs-window-item>
|
||||
|
||||
|
||||
<div style="margin-left: 16px; padding-bottom: 16px">
|
||||
<small>{{ tm('help.helpPrefix') }}
|
||||
<a href="https://astrbot.app/" target="_blank">{{ tm('help.documentation') }}</a>
|
||||
{{ tm('help.helpMiddle') }}
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=EYGsuUTfe00_iOu9JTXS7_TEpMkXOvwv&jump_from=webapi&authKey=uUEMKCROfsseS+8IzqPjzV3y1tzy4AkykwTib2jNkOFdzezF9s9XknqnIaf3CDft"
|
||||
target="_blank">{{ tm('help.support') }}</a>{{ tm('help.helpSuffix') }}
|
||||
</small>
|
||||
</div>
|
||||
|
||||
</v-tabs-window>
|
||||
</div>
|
||||
<AstrBotCoreConfigWrapper
|
||||
:metadata="metadata"
|
||||
:config_data="config_data"
|
||||
/>
|
||||
|
||||
<v-btn icon="mdi-content-save" size="x-large" style="position: fixed; right: 52px; bottom: 52px;"
|
||||
color="darkprimary" @click="updateConfig">
|
||||
@@ -118,7 +84,7 @@
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text>
|
||||
<small>AstrBot 支持针对不同消息平台实例分别设置配置文件。默认会使用 `default` 配置。</small>
|
||||
<small>AstrBot 支持针对不同机器人分别设置配置文件。默认会使用 `default` 配置。</small>
|
||||
<div class="mt-6 mb-4">
|
||||
<v-btn prepend-icon="mdi-plus" @click="startCreateConfig" variant="tonal" color="primary">
|
||||
新建配置文件
|
||||
@@ -128,8 +94,6 @@
|
||||
<!-- Config List -->
|
||||
<v-list lines="two">
|
||||
<v-list-item v-for="config in configInfoList" :key="config.id" :title="config.name">
|
||||
<v-list-item-subtitle>当前应用于: {{ formatUmop(config.umop) }} </v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append v-if="config.id !== 'default'">
|
||||
<div class="d-flex align-center" style="gap: 8px;">
|
||||
<v-btn icon="mdi-pencil" size="small" variant="text" color="warning"
|
||||
@@ -147,149 +111,15 @@
|
||||
<div v-if="showConfigForm">
|
||||
<h3 class="mb-4">{{ isEditingConfig ? '编辑配置文件' : '新建配置文件' }}</h3>
|
||||
|
||||
<div class="mb-4">
|
||||
<div v-if="conflictMessage" class="text-warning">
|
||||
<div v-html="conflictMessage" style="font-size: 0.875rem; line-height: 1.4;"></div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<h4>名称</h4>
|
||||
|
||||
<v-text-field v-model="configFormData.name" label="填写配置文件名称" variant="outlined" class="mt-4 mb-4"
|
||||
hide-details></v-text-field>
|
||||
|
||||
<h4>应用于</h4>
|
||||
|
||||
<v-radio-group class="mt-2" v-model="appliedToRadioValue" hide-details="true">
|
||||
<v-radio value="0">
|
||||
<template v-slot:label>
|
||||
<span>指定消息平台...</span>
|
||||
</template>
|
||||
</v-radio>
|
||||
<v-select v-if="appliedToRadioValue === '0'" v-model="configFormData.umop" :items="platformList" item-title="id" item-value="id"
|
||||
label="选择已配置的消息平台(可多选)" variant="outlined" hide-details multiple class="ma-2"
|
||||
@update:model-value="checkPlatformConflictOnForm">
|
||||
<template v-slot:item="{ props: itemProps, item }">
|
||||
<v-list-item v-bind="itemProps" :subtitle="item.raw.type"></v-list-item>
|
||||
</template>
|
||||
</v-select>
|
||||
<v-radio value="1" label="自定义规则(实验性)">
|
||||
</v-radio>
|
||||
|
||||
<!-- 自定义规则界面 -->
|
||||
<div v-if="appliedToRadioValue === '1'" class="ma-2">
|
||||
<small class="text-medium-emphasis mb-4 d-block">UMO 格式: [platform_id]:[message_type]:[session_id]。通配符 * 或留空表示全部。使用 /sid 查看某个聊天的 UMO。</small>
|
||||
|
||||
<!-- 输入方式切换 -->
|
||||
<v-btn-toggle v-model="customRuleInputMode" mandatory color="primary" variant="outlined" density="compact"
|
||||
rounded="md" class="mb-4">
|
||||
<v-btn value="builder" prepend-icon="mdi-tune" size="x-small">
|
||||
可视化
|
||||
</v-btn>
|
||||
<v-btn value="manual" prepend-icon="mdi-code-tags" size="x-small">
|
||||
手动编辑
|
||||
</v-btn>
|
||||
</v-btn-toggle>
|
||||
|
||||
<!-- 快速规则构建 -->
|
||||
<div v-if="customRuleInputMode === 'builder'" class="mb-4">
|
||||
<div v-for="(rule, index) in customRules" :key="index" class="d-flex align-center mb-2" style="gap: 8px;">
|
||||
<v-select
|
||||
v-model="rule.platform"
|
||||
:items="[{ id: '*', type: '所有平台' }, ...platformList]"
|
||||
item-title="id"
|
||||
item-value="id"
|
||||
label="平台"
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
style="min-width: 120px;"
|
||||
@update:model-value="updateCustomRule(index)">
|
||||
<template v-slot:item="{ props: itemProps, item }">
|
||||
<v-list-item v-bind="itemProps" :subtitle="item.raw.type"></v-list-item>
|
||||
</template>
|
||||
</v-select>
|
||||
|
||||
<v-select
|
||||
v-model="rule.messageType"
|
||||
:items="messageTypeOptions"
|
||||
item-title="label"
|
||||
item-value="value"
|
||||
label="消息类型"
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
style="min-width: 130px;"
|
||||
@update:model-value="updateCustomRule(index)">
|
||||
</v-select>
|
||||
|
||||
<v-text-field
|
||||
v-model="rule.sessionId"
|
||||
label="会话ID"
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
placeholder="* 或留空表示全部"
|
||||
style="min-width: 120px;"
|
||||
@update:model-value="updateCustomRule(index)">
|
||||
</v-text-field>
|
||||
|
||||
<v-btn
|
||||
icon="mdi-delete"
|
||||
size="small"
|
||||
variant="text"
|
||||
color="error"
|
||||
@click="removeCustomRule(index)"
|
||||
:disabled="customRules.length === 1">
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<v-btn
|
||||
prepend-icon="mdi-plus"
|
||||
size="small"
|
||||
variant="tonal"
|
||||
color="primary"
|
||||
@click="addCustomRule">
|
||||
添加规则
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- 手动输入 -->
|
||||
<div v-if="customRuleInputMode === 'manual'" class="mb-4">
|
||||
<v-textarea
|
||||
v-model="manualRulesText"
|
||||
label="手动输入规则(每行一个)"
|
||||
variant="outlined"
|
||||
rows="4"
|
||||
placeholder="每行一个规则,例如: platform1:GroupMessage:* *:FriendMessage:session123 *:*:*"
|
||||
@update:model-value="updateManualRules">
|
||||
</v-textarea>
|
||||
</div>
|
||||
|
||||
<!-- 规则预览 -->
|
||||
<div class="mb-2">
|
||||
<small class="text-medium-emphasis">
|
||||
<strong>预览:</strong>
|
||||
<span v-if="!configFormData.umop.length" class="text-error">未配置任何规则</span>
|
||||
<div v-else class="mt-1">
|
||||
<v-chip
|
||||
v-for="(rule, index) in configFormData.umop"
|
||||
:key="index"
|
||||
size="x-small"
|
||||
rounded="sm"
|
||||
class="mr-1">
|
||||
{{ rule }}
|
||||
</v-chip>
|
||||
</div>
|
||||
<small>这些规则对应的会话将使用此配置文件。</small>
|
||||
</small>
|
||||
</div>
|
||||
</div>
|
||||
</v-radio-group>
|
||||
|
||||
|
||||
|
||||
<div class="d-flex justify-end mt-4" style="gap: 8px;">
|
||||
<v-btn variant="text" @click="cancelConfigForm">取消</v-btn>
|
||||
<v-btn color="primary" @click="saveConfigForm"
|
||||
:disabled="!configFormData.name || !configFormData.umop.length">
|
||||
:disabled="!configFormData.name">
|
||||
{{ isEditingConfig ? '更新' : '创建' }}
|
||||
</v-btn>
|
||||
</div>
|
||||
@@ -308,7 +138,7 @@
|
||||
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import AstrBotConfigV4 from '@/components/shared/AstrBotConfigV4.vue';
|
||||
import AstrBotCoreConfigWrapper from '@/components/config/AstrBotCoreConfigWrapper.vue';
|
||||
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
@@ -316,7 +146,7 @@ import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
export default {
|
||||
name: 'ConfigPage',
|
||||
components: {
|
||||
AstrBotConfigV4,
|
||||
AstrBotCoreConfigWrapper,
|
||||
VueMonacoEditor,
|
||||
WaitingForRestart
|
||||
},
|
||||
@@ -359,15 +189,6 @@ export default {
|
||||
watch: {
|
||||
config_data_str: function (val) {
|
||||
this.config_data_has_changed = true;
|
||||
},
|
||||
customRuleInputMode: function (newVal) {
|
||||
if (newVal === 'builder') {
|
||||
// 切换到快速构建,从手动输入同步数据
|
||||
this.syncCustomRulesFromManual();
|
||||
} else if (newVal === 'manual') {
|
||||
// 切换到手动输入,从快速构建同步数据
|
||||
this.syncManualRulesText();
|
||||
}
|
||||
}
|
||||
},
|
||||
data() {
|
||||
@@ -387,8 +208,6 @@ export default {
|
||||
save_message: "",
|
||||
save_message_success: "",
|
||||
|
||||
tab: 0, // 用于切换配置标签页
|
||||
|
||||
// 配置类型切换
|
||||
configType: 'normal', // 'normal' 或 'system'
|
||||
|
||||
@@ -396,32 +215,12 @@ export default {
|
||||
isSystemConfig: false,
|
||||
|
||||
// 多配置文件管理
|
||||
appliedToRadioValue: '0',
|
||||
selectedConfigID: null, // 用于存储当前选中的配置项信息
|
||||
configInfoList: [],
|
||||
platformList: [],
|
||||
configFormData: {
|
||||
name: '',
|
||||
umop: [],
|
||||
},
|
||||
editingConfigId: null,
|
||||
conflictMessage: '', // 冲突提示信息
|
||||
|
||||
// 自定义规则相关
|
||||
customRuleInputMode: 'builder', // 'builder' 或 'manual'
|
||||
customRules: [
|
||||
{
|
||||
platform: '*',
|
||||
messageType: '*',
|
||||
sessionId: '*'
|
||||
}
|
||||
],
|
||||
manualRulesText: '',
|
||||
messageTypeOptions: [
|
||||
{ label: '所有消息类型', value: '*' },
|
||||
{ label: '群组消息', value: 'GroupMessage' },
|
||||
{ label: '私聊消息', value: 'FriendMessage' }
|
||||
],
|
||||
}
|
||||
},
|
||||
mounted() {
|
||||
@@ -450,13 +249,6 @@ export default {
|
||||
this.save_message_success = "error";
|
||||
});
|
||||
},
|
||||
getPlatformList() {
|
||||
axios.get('/api/config/platform/list').then((res) => {
|
||||
this.platformList = res.data.data.platforms;
|
||||
}).catch((err) => {
|
||||
console.error(this.t('status.dataError'), err);
|
||||
});
|
||||
},
|
||||
getConfig(abconf_id) {
|
||||
this.fetched = false
|
||||
const params = {};
|
||||
@@ -532,18 +324,7 @@ export default {
|
||||
}
|
||||
},
|
||||
createNewConfig() {
|
||||
let umo_parts = [];
|
||||
|
||||
if (this.appliedToRadioValue === '0') {
|
||||
// 修正为 umo part 形式 - 指定平台
|
||||
umo_parts = this.configFormData.umop.map(platform => platform + "::");
|
||||
} else if (this.appliedToRadioValue === '1') {
|
||||
// 自定义规则
|
||||
umo_parts = [...this.configFormData.umop]; // 直接使用 umop,它已经包含完整的规则
|
||||
}
|
||||
|
||||
axios.post('/api/config/abconf/new', {
|
||||
umo_parts: umo_parts,
|
||||
name: this.configFormData.name
|
||||
}).then((res) => {
|
||||
if (res.data.status === "ok") {
|
||||
@@ -564,210 +345,9 @@ export default {
|
||||
this.save_message_success = "error";
|
||||
});
|
||||
},
|
||||
checkPlatformConflict(newRules) {
|
||||
const conflictConfigs = [];
|
||||
|
||||
// 遍历现有的配置文件,排除名为 "default" 的配置
|
||||
for (const config of this.configInfoList) {
|
||||
if (config.name === 'default') {
|
||||
continue; // 跳过 default 配置
|
||||
}
|
||||
|
||||
if (config.umop && config.umop.length > 0) {
|
||||
// 检查是否有冲突
|
||||
const hasConflict = this.hasUmoConflict(newRules, config.umop);
|
||||
|
||||
if (hasConflict) {
|
||||
conflictConfigs.push(config);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return conflictConfigs;
|
||||
},
|
||||
|
||||
hasUmoConflict(newRules, existingRules) {
|
||||
// 检查新规则与现有规则是否有冲突
|
||||
for (const newRule of newRules) {
|
||||
for (const existingRule of existingRules) {
|
||||
if (this.isUmoMatch(newRule, existingRule) || this.isUmoMatch(existingRule, newRule)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
},
|
||||
|
||||
isUmoMatch(p1, p2) {
|
||||
// 判断 p2 umo 是否逻辑包含于 p1 umo
|
||||
// 基于后端的 _is_umo_match 逻辑
|
||||
|
||||
// 先标准化规则格式
|
||||
const p1_normalized = this.normalizeUmoRule(p1);
|
||||
const p2_normalized = this.normalizeUmoRule(p2);
|
||||
|
||||
const p1_parts = p1_normalized.split(":");
|
||||
const p2_parts = p2_normalized.split(":");
|
||||
|
||||
if (p1_parts.length !== 3 || p2_parts.length !== 3) {
|
||||
return false; // 非法格式
|
||||
}
|
||||
|
||||
// 检查每个部分是否匹配
|
||||
return p1_parts.every((p, index) => {
|
||||
const t = p2_parts[index];
|
||||
return p === "" || p === "*" || p === t;
|
||||
});
|
||||
},
|
||||
|
||||
normalizeUmoRule(rule) {
|
||||
// 标准化规则格式
|
||||
if (typeof rule !== 'string') {
|
||||
return "*:*:*";
|
||||
}
|
||||
|
||||
const parts = rule.split(":");
|
||||
|
||||
if (parts.length === 2 && parts[1] === "") {
|
||||
// 传统格式 "platform::" -> "platform:*:*"
|
||||
return `${parts[0] || "*"}:*:*`;
|
||||
} else if (parts.length === 3) {
|
||||
// 已经是完整格式,只需要处理空字符串
|
||||
return parts.map(part => part === "" ? "*" : part).join(":");
|
||||
} else if (parts.length === 1) {
|
||||
// 只有平台 "platform" -> "platform:*:*"
|
||||
return `${parts[0] || "*"}:*:*`;
|
||||
}
|
||||
|
||||
// 默认返回通配符
|
||||
return "*:*:*";
|
||||
},
|
||||
|
||||
getDetailedConflictInfo(newRules) {
|
||||
const conflictDetails = [];
|
||||
|
||||
// 获取所有配置文件及其优先级(按创建时间排序,早创建的优先级高)
|
||||
const sortedConfigs = [...this.configInfoList]
|
||||
.filter(config => config.name !== 'default')
|
||||
.sort((a, b) => {
|
||||
// 假设按字母顺序排序作为优先级(实际应该按创建时间)
|
||||
return a.id.localeCompare(b.id);
|
||||
});
|
||||
|
||||
for (const config of sortedConfigs) {
|
||||
if (!config.umop || config.umop.length === 0) continue;
|
||||
|
||||
const conflictingRules = [];
|
||||
|
||||
for (const newRule of newRules) {
|
||||
for (const existingRule of config.umop) {
|
||||
if (this.isUmoMatch(newRule, existingRule) || this.isUmoMatch(existingRule, newRule)) {
|
||||
conflictingRules.push({
|
||||
newRule: newRule,
|
||||
existingRule: existingRule,
|
||||
matchType: this.getMatchType(newRule, existingRule)
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (conflictingRules.length > 0) {
|
||||
conflictDetails.push({
|
||||
config: config,
|
||||
conflicts: conflictingRules
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return conflictDetails;
|
||||
},
|
||||
|
||||
getMatchType(rule1, rule2) {
|
||||
const r1_normalized = this.normalizeUmoRule(rule1);
|
||||
const r2_normalized = this.normalizeUmoRule(rule2);
|
||||
|
||||
const isR1MatchR2 = this.isUmoMatch(rule1, rule2);
|
||||
const isR2MatchR1 = this.isUmoMatch(rule2, rule1);
|
||||
|
||||
if (isR1MatchR2 && isR2MatchR1) {
|
||||
return 'exact'; // 完全匹配
|
||||
} else if (isR1MatchR2) {
|
||||
return 'new_covers_existing'; // 新规则覆盖现有规则
|
||||
} else if (isR2MatchR1) {
|
||||
return 'existing_covers_new'; // 现有规则覆盖新规则
|
||||
}
|
||||
|
||||
return 'overlap'; // 部分重叠
|
||||
},
|
||||
|
||||
formatConflictMessage(conflictDetails) {
|
||||
if (conflictDetails.length === 0) return '';
|
||||
|
||||
let message = '⚠️ <strong>规则冲突警告:</strong><br><br>';
|
||||
|
||||
// 按优先级排序(最先创建的配置文件优先级最高)
|
||||
const sortedDetails = [...conflictDetails].sort((a, b) =>
|
||||
a.config.id.localeCompare(b.config.id)
|
||||
);
|
||||
|
||||
sortedDetails.forEach((detail, index) => {
|
||||
const configName = detail.config.name || detail.config.id;
|
||||
message += `<strong>${index + 1}. 与配置文件 "${configName}" 冲突:</strong><br>`;
|
||||
|
||||
detail.conflicts.forEach(conflict => {
|
||||
const newRuleFormatted = this.formatRuleForDisplay(conflict.newRule);
|
||||
const existingRuleFormatted = this.formatRuleForDisplay(conflict.existingRule);
|
||||
|
||||
switch (conflict.matchType) {
|
||||
case 'exact':
|
||||
message += `规则完全相同: <code>${newRuleFormatted}</code><br>`;
|
||||
message += `<span style="color: orange;">"${configName}" 将覆盖当前配置</span><br>`;
|
||||
break;
|
||||
case 'new_covers_existing':
|
||||
message += `当前规则 <code>${newRuleFormatted}</code> 包含现有规则 <code>${existingRuleFormatted}</code><br>`;
|
||||
message += `<span style="color: red;">"${configName}" 的规则将优先匹配</span><br>`;
|
||||
break;
|
||||
case 'existing_covers_new':
|
||||
message += `现有规则 <code>${existingRuleFormatted}</code> 包含当前规则 <code>${newRuleFormatted}</code><br>`;
|
||||
message += `<span style="color: red;">"${configName}" 的规则将优先匹配</span><br>`;
|
||||
break;
|
||||
case 'overlap':
|
||||
message += `规则重叠: <code>${newRuleFormatted}</code> ↔ <code>${existingRuleFormatted}</code><br>`;
|
||||
message += `<span style="color: orange;">"${configName}" 在匹配范围内优先</span><br>`;
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
if (index < sortedDetails.length - 1) {
|
||||
message += '<br>';
|
||||
}
|
||||
});
|
||||
|
||||
message += '<br><small><strong>💡 说明:</strong> 您仍可创建此配置文件。AstrBot 按配置文件创建顺序匹配规则,先创建的配置文件优先级更高。当多个配置文件的规则匹配同一个消息会话来源时,优先级最高的配置文件会生效(default 配置文件除外)。</small>';
|
||||
|
||||
return message;
|
||||
},
|
||||
|
||||
formatRuleForDisplay(rule) {
|
||||
const parts = this.normalizeUmoRule(rule).split(':');
|
||||
const platform = parts[0] === '*' || parts[0] === '' ? '任意平台' : parts[0];
|
||||
const messageType = parts[1] === '*' || parts[1] === '' ? '任意消息' : this.getMessageTypeLabel(parts[1]);
|
||||
const sessionId = parts[2] === '*' || parts[2] === '' ? '任意会话' : parts[2];
|
||||
|
||||
return `${platform}:${messageType}:${sessionId}`;
|
||||
},
|
||||
|
||||
getMessageTypeLabel(messageType) {
|
||||
const typeMap = {
|
||||
'GroupMessage': '群组消息',
|
||||
'FriendMessage': '私聊消息',
|
||||
};
|
||||
return typeMap[messageType] || messageType;
|
||||
},
|
||||
onConfigSelect(value) {
|
||||
if (value === '_%manage%_') {
|
||||
this.configManageDialog = true;
|
||||
this.getPlatformList();
|
||||
// 重置选择到之前的值
|
||||
this.$nextTick(() => {
|
||||
this.selectedConfigID = this.selectedConfigInfo.id || 'default';
|
||||
@@ -781,26 +361,17 @@ export default {
|
||||
this.isEditingConfig = false;
|
||||
this.configFormData = {
|
||||
name: '',
|
||||
umop: [],
|
||||
};
|
||||
this.editingConfigId = null;
|
||||
this.conflictMessage = '';
|
||||
this.resetCustomRules();
|
||||
},
|
||||
startEditConfig(config) {
|
||||
this.appliedToRadioValue = "1";
|
||||
this.showConfigForm = true;
|
||||
this.isEditingConfig = true;
|
||||
this.editingConfigId = config.id;
|
||||
|
||||
this.parseExistingCustomRules(config.umop || []);
|
||||
|
||||
this.configFormData = {
|
||||
name: config.name || '',
|
||||
umop: [...(config.umop || [])],
|
||||
};
|
||||
|
||||
this.conflictMessage = '';
|
||||
},
|
||||
cancelConfigForm() {
|
||||
this.showConfigForm = false;
|
||||
@@ -808,14 +379,11 @@ export default {
|
||||
this.editingConfigId = null;
|
||||
this.configFormData = {
|
||||
name: '',
|
||||
umop: [],
|
||||
};
|
||||
this.conflictMessage = '';
|
||||
this.resetCustomRules();
|
||||
},
|
||||
saveConfigForm() {
|
||||
if (!this.configFormData.name || !this.configFormData.umop.length) {
|
||||
this.save_message = "请填写配置名称和选择应用平台";
|
||||
if (!this.configFormData.name) {
|
||||
this.save_message = "请填写配置名称";
|
||||
this.save_message_snack = true;
|
||||
this.save_message_success = "error";
|
||||
return;
|
||||
@@ -827,106 +395,6 @@ export default {
|
||||
this.createNewConfig();
|
||||
}
|
||||
},
|
||||
|
||||
// 自定义规则相关方法
|
||||
addCustomRule() {
|
||||
this.customRules.push({
|
||||
platform: '*',
|
||||
messageType: '*',
|
||||
sessionId: '*'
|
||||
});
|
||||
this.updateCustomRulesFromBuilder();
|
||||
},
|
||||
|
||||
removeCustomRule(index) {
|
||||
if (this.customRules.length > 1) {
|
||||
this.customRules.splice(index, 1);
|
||||
this.updateCustomRulesFromBuilder();
|
||||
}
|
||||
},
|
||||
|
||||
updateCustomRule(index) {
|
||||
this.updateCustomRulesFromBuilder();
|
||||
},
|
||||
|
||||
updateCustomRulesFromBuilder() {
|
||||
// 从规则构建器更新 umop
|
||||
const rules = this.customRules.map(rule => {
|
||||
const platform = rule.platform === '*' ? '' : rule.platform;
|
||||
const messageType = rule.messageType === '*' ? '' : rule.messageType;
|
||||
const sessionId = rule.sessionId === '*' ? '' : (rule.sessionId || '');
|
||||
return `${platform}:${messageType}:${sessionId}`;
|
||||
});
|
||||
|
||||
this.configFormData.umop = rules;
|
||||
this.syncManualRulesText();
|
||||
// 触发冲突检测
|
||||
this.checkPlatformConflictOnForm();
|
||||
},
|
||||
|
||||
updateManualRules() {
|
||||
// 从手动输入更新 umop
|
||||
const rules = this.manualRulesText
|
||||
.split('\n')
|
||||
.map(rule => rule.trim())
|
||||
.filter(rule => rule);
|
||||
|
||||
this.configFormData.umop = rules;
|
||||
this.syncCustomRulesFromManual();
|
||||
// 触发冲突检测
|
||||
this.checkPlatformConflictOnForm();
|
||||
},
|
||||
|
||||
syncManualRulesText() {
|
||||
// 同步到手动输入文本区域
|
||||
this.manualRulesText = this.configFormData.umop.join('\n');
|
||||
},
|
||||
|
||||
syncCustomRulesFromManual() {
|
||||
// 从手动输入同步到规则构建器
|
||||
this.customRules = this.configFormData.umop.map(rule => {
|
||||
const parts = rule.split(':');
|
||||
return {
|
||||
platform: parts[0] || '*',
|
||||
messageType: parts[1] || '*',
|
||||
sessionId: parts[2] || '*'
|
||||
};
|
||||
});
|
||||
},
|
||||
|
||||
resetCustomRules() {
|
||||
this.customRuleInputMode = 'builder'; // 重置为快速构建模式
|
||||
this.customRules = [
|
||||
{
|
||||
platform: '*',
|
||||
messageType: '*',
|
||||
sessionId: '*'
|
||||
}
|
||||
];
|
||||
this.manualRulesText = '';
|
||||
if (this.appliedToRadioValue === '1') {
|
||||
this.updateCustomRulesFromBuilder();
|
||||
}
|
||||
},
|
||||
|
||||
parseExistingCustomRules(umop) {
|
||||
// 解析现有的自定义规则
|
||||
if (!umop || umop.length === 0) {
|
||||
this.resetCustomRules();
|
||||
return;
|
||||
}
|
||||
|
||||
this.customRules = umop.map(rule => {
|
||||
const parts = rule.split(':');
|
||||
return {
|
||||
platform: parts[0] || '*',
|
||||
messageType: parts[1] || '*',
|
||||
sessionId: parts[2] || '*'
|
||||
};
|
||||
});
|
||||
|
||||
this.syncManualRulesText();
|
||||
},
|
||||
confirmDeleteConfig(config) {
|
||||
if (confirm(`确定要删除配置文件 "${config.name}" 吗?此操作不可恢复。`)) {
|
||||
this.deleteConfig(config.id);
|
||||
@@ -940,6 +408,7 @@ export default {
|
||||
this.save_message = res.data.message;
|
||||
this.save_message_snack = true;
|
||||
this.save_message_success = "success";
|
||||
this.cancelConfigForm();
|
||||
// 删除成功后,更新配置列表
|
||||
this.getConfigInfoList("default");
|
||||
} else {
|
||||
@@ -954,52 +423,10 @@ export default {
|
||||
this.save_message_success = "error";
|
||||
});
|
||||
},
|
||||
checkPlatformConflictOnForm() {
|
||||
if (!this.configFormData.umop || this.configFormData.umop.length === 0) {
|
||||
this.conflictMessage = '';
|
||||
return;
|
||||
}
|
||||
|
||||
// 准备用于冲突检测的规则列表
|
||||
let rulesToCheck = [];
|
||||
|
||||
if (this.appliedToRadioValue === '0') {
|
||||
// 平台模式:转换为标准UMO格式
|
||||
rulesToCheck = this.configFormData.umop.map(platform => `${platform}:*:*`);
|
||||
} else {
|
||||
// 自定义模式:直接使用规则
|
||||
rulesToCheck = [...this.configFormData.umop];
|
||||
}
|
||||
|
||||
// 检查与其他配置文件的冲突
|
||||
let conflictDetails = this.getDetailedConflictInfo(rulesToCheck);
|
||||
|
||||
// 如果是编辑模式,排除当前编辑的配置文件
|
||||
if (this.isEditingConfig && this.editingConfigId) {
|
||||
conflictDetails = conflictDetails.filter(detail => detail.config.id !== this.editingConfigId);
|
||||
}
|
||||
|
||||
if (conflictDetails.length > 0) {
|
||||
this.conflictMessage = this.formatConflictMessage(conflictDetails);
|
||||
} else {
|
||||
this.conflictMessage = '';
|
||||
}
|
||||
},
|
||||
updateConfigInfo() {
|
||||
let umo_parts = [];
|
||||
|
||||
if (this.appliedToRadioValue === '0') {
|
||||
// 修正为 umo part 形式 - 指定平台
|
||||
umo_parts = this.configFormData.umop.map(platform => platform + "::");
|
||||
} else if (this.appliedToRadioValue === '1') {
|
||||
// 自定义规则
|
||||
umo_parts = [...this.configFormData.umop]; // 直接使用 umop,它已经包含完整的规则
|
||||
}
|
||||
|
||||
axios.post('/api/config/abconf/update', {
|
||||
id: this.editingConfigId,
|
||||
name: this.configFormData.name,
|
||||
umo_parts: umo_parts
|
||||
name: this.configFormData.name
|
||||
}).then((res) => {
|
||||
if (res.data.status === "ok") {
|
||||
this.save_message = res.data.message;
|
||||
@@ -1019,38 +446,8 @@ export default {
|
||||
this.save_message_success = "error";
|
||||
});
|
||||
},
|
||||
formatUmop(umop) {
|
||||
if (!umop) {
|
||||
return
|
||||
}
|
||||
let ret = ""
|
||||
for (let i = 0; i < umop.length; i++) {
|
||||
const parts = umop[i].split(":");
|
||||
if (parts.length === 3) {
|
||||
// 自定义规则格式 platform:messageType:sessionId
|
||||
const platform = parts[0] || "*";
|
||||
const messageType = parts[1] || "*";
|
||||
const sessionId = parts[2] || "*";
|
||||
if (platform === "*" && messageType === "*" && sessionId === "*") {
|
||||
return "所有平台";
|
||||
}
|
||||
ret += `${platform}:${messageType}:${sessionId},`;
|
||||
} else {
|
||||
// 传统平台格式
|
||||
let platformPart = umop[i].split(":")[0];
|
||||
if (platformPart === "") {
|
||||
return "所有平台";
|
||||
} else {
|
||||
ret += platformPart + ",";
|
||||
}
|
||||
}
|
||||
}
|
||||
ret = ret.slice(0, -1);
|
||||
return ret;
|
||||
},
|
||||
onConfigTypeToggle() {
|
||||
this.isSystemConfig = this.configType === 'system';
|
||||
this.tab = 0; // 重置标签页
|
||||
this.fetched = false; // 重置加载状态
|
||||
|
||||
if (this.isSystemConfig) {
|
||||
@@ -1069,7 +466,6 @@ export default {
|
||||
// 保持向后兼容性,更新 configType
|
||||
this.configType = this.isSystemConfig ? 'system' : 'normal';
|
||||
|
||||
this.tab = 0; // 重置标签页
|
||||
this.fetched = false; // 重置加载状态
|
||||
|
||||
if (this.isSystemConfig) {
|
||||
@@ -1128,31 +524,12 @@ export default {
|
||||
}
|
||||
|
||||
@media (min-width: 768px) {
|
||||
.config-tabs {
|
||||
display: flex;
|
||||
margin: 16px 16px 0 0;
|
||||
}
|
||||
|
||||
.config-panel {
|
||||
width: 750px;
|
||||
}
|
||||
|
||||
.config-tabs-window {
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
.config-tabs .v-tab {
|
||||
justify-content: flex-start !important;
|
||||
text-align: left;
|
||||
min-height: 48px;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 767px) {
|
||||
.config-tabs {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.v-container {
|
||||
padding: 4px;
|
||||
}
|
||||
@@ -1160,9 +537,5 @@ export default {
|
||||
.config-panel {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.config-tabs-window {
|
||||
margin-top: 16px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -433,14 +433,14 @@ export default {
|
||||
tableHeaders() {
|
||||
return [
|
||||
{ title: this.tm('table.headers.title'), key: 'title', sortable: true },
|
||||
{ title: '会话 ID', key: 'cid', sortable: true, width: '100px' },
|
||||
{ title: this.tm('table.headers.cid'), key: 'cid', sortable: true, width: '100px' },
|
||||
{
|
||||
title: this.tm('table.headers.sessionId'),
|
||||
title: this.tm('table.headers.umo'),
|
||||
align: 'center',
|
||||
children: [
|
||||
{ title: this.tm('table.headers.platform'), key: 'platform', sortable: true, width: '120px' },
|
||||
{ title: this.tm('table.headers.type'), key: 'messageType', sortable: true, width: '100px' },
|
||||
{ title: '用户 ID', key: 'sessionId', sortable: true, width: '100px' },
|
||||
{ title: this.tm('table.headers.sessionId'), key: 'sessionId', sortable: true, width: '100px' },
|
||||
],
|
||||
},
|
||||
{ title: this.tm('table.headers.createdAt'), key: 'created_at', sortable: true, width: '180px' },
|
||||
|
||||
@@ -749,9 +749,9 @@ onMounted(async () => {
|
||||
</v-row>
|
||||
|
||||
<v-row>
|
||||
<v-col cols="12" md="6" lg="4" v-for="extension in filteredPlugins" :key="extension.name"
|
||||
<v-col cols="12" md="6" lg="6" v-for="extension in filteredPlugins" :key="extension.name"
|
||||
class="pb-4">
|
||||
<ExtensionCard :extension="extension" class="h-120 rounded-lg"
|
||||
<ExtensionCard :extension="extension" class="rounded-lg"
|
||||
@configure="openExtensionConfig(extension.name)" @uninstall="uninstallExtension(extension.name)"
|
||||
@update="updateExtension(extension.name)" @reload="reloadPlugin(extension.name)"
|
||||
@toggle-activation="extension.activated ? pluginOff(extension) : pluginOn(extension)"
|
||||
|
||||
@@ -90,200 +90,11 @@
|
||||
</v-container>
|
||||
|
||||
<!-- 创建/编辑人格对话框 -->
|
||||
<v-dialog v-model="showPersonaDialog" max-width="800px" persistent>
|
||||
<v-card>
|
||||
<v-card-title class="text-h2">
|
||||
{{ editingPersona ? tm('dialog.edit.title') : tm('dialog.create.title') }}
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text>
|
||||
<v-form ref="personaForm" v-model="formValid">
|
||||
<v-text-field v-model="personaForm.persona_id" :label="tm('form.personaId')"
|
||||
:rules="personaIdRules" :disabled="editingPersona" variant="outlined" density="comfortable"
|
||||
class="mb-4" />
|
||||
|
||||
<v-textarea v-model="personaForm.system_prompt" :label="tm('form.systemPrompt')"
|
||||
:rules="systemPromptRules" variant="outlined" rows="6" class="mb-4" />
|
||||
|
||||
<v-expansion-panels v-model="expandedPanels" multiple>
|
||||
<!-- 工具选择面板 -->
|
||||
<v-expansion-panel value="tools">
|
||||
<v-expansion-panel-title>
|
||||
<v-icon class="mr-2">mdi-tools</v-icon>
|
||||
{{ tm('form.tools') }}
|
||||
<v-chip v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
|
||||
size="small" color="primary" variant="tonal" class="ml-2">
|
||||
{{ personaForm.tools.length }}
|
||||
</v-chip>
|
||||
</v-expansion-panel-title>
|
||||
|
||||
<v-expansion-panel-text>
|
||||
<div class="mb-3">
|
||||
<p class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.toolsHelp') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<v-radio-group class="mt-2" v-model="toolSelectValue" hide-details="true">
|
||||
<v-radio label="默认使用全部函数工具" value="0"></v-radio>
|
||||
<v-radio label="选择指定函数工具" value="1">
|
||||
</v-radio>
|
||||
</v-radio-group>
|
||||
|
||||
<div v-if="toolSelectValue === '1'" class="mt-3 ml-8">
|
||||
|
||||
<!-- 工具搜索 -->
|
||||
<v-text-field v-model="toolSearch" :label="tm('form.searchTools')"
|
||||
prepend-inner-icon="mdi-magnify" variant="outlined" density="compact"
|
||||
hide-details clearable class="mb-3" />
|
||||
|
||||
|
||||
<!-- MCP 服务器 -->
|
||||
<div v-if="mcpServers.length > 0" class="mb-4">
|
||||
<h4 class="text-subtitle-2 mb-2">{{ tm('form.mcpServersQuickSelect') }}</h4>
|
||||
<div class="d-flex flex-wrap ga-2">
|
||||
<v-chip v-for="server in mcpServers" :key="server.name"
|
||||
:color="isServerSelected(server) ? 'primary' : 'default'"
|
||||
:variant="isServerSelected(server) ? 'flat' : 'outlined'"
|
||||
size="small" clickable @click="toggleMcpServer(server)"
|
||||
:disabled="!server.tools || server.tools.length === 0">
|
||||
<v-icon start size="small">mdi-server</v-icon>
|
||||
{{ server.name }}
|
||||
<v-chip-text v-if="server.tools" class="ml-1">
|
||||
({{ server.tools.length }})
|
||||
</v-chip-text>
|
||||
</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 工具选择列表 -->
|
||||
<div v-if="filteredTools.length > 0" class="tools-selection">
|
||||
<v-virtual-scroll :items="filteredTools" height="300" item-height="48">
|
||||
<template v-slot:default="{ item }">
|
||||
<v-list-item :key="item.name" density="comfortable"
|
||||
@click="toggleTool(item.name)">
|
||||
<template v-slot:prepend>
|
||||
<v-checkbox-btn :model-value="isToolSelected(item.name)"
|
||||
@click.stop="toggleTool(item.name)" />
|
||||
</template>
|
||||
|
||||
<v-list-item-title>
|
||||
{{ item.name }}
|
||||
<v-chip v-if="item.mcp_server_name" size="x-small"
|
||||
color="secondary" variant="tonal" class="ml-2">
|
||||
{{ item.mcp_server_name }}
|
||||
</v-chip>
|
||||
</v-list-item-title>
|
||||
|
||||
<v-list-item-subtitle v-if="item.description">
|
||||
{{ truncateText(item.description, 100) }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</template>
|
||||
</v-virtual-scroll>
|
||||
</div>
|
||||
|
||||
<div v-else-if="!loadingTools && availableTools.length === 0"
|
||||
class="text-center pa-4">
|
||||
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-tools</v-icon>
|
||||
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsAvailable')
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div v-else-if="!loadingTools && filteredTools.length === 0"
|
||||
class="text-center pa-4">
|
||||
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-magnify</v-icon>
|
||||
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsFound') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 加载状态 -->
|
||||
<div v-if="loadingTools" class="text-center pa-4">
|
||||
<v-progress-circular indeterminate color="primary" />
|
||||
<p class="text-body-2 text-medium-emphasis mt-2">{{ tm('form.loadingTools')
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 已选择的工具 -->
|
||||
<div class="mt-4">
|
||||
<h4 class="text-subtitle-2 mb-2">
|
||||
{{ tm('form.selectedTools') }}
|
||||
<span v-if="personaForm.tools === null" class="text-success">
|
||||
({{ tm('form.allSelected') }})
|
||||
</span>
|
||||
<span v-else-if="Array.isArray(personaForm.tools)">
|
||||
({{ personaForm.tools.length }})
|
||||
</span>
|
||||
</h4>
|
||||
<div v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
|
||||
class="d-flex flex-wrap ga-1" style="max-height: 100px; overflow-y: auto;">
|
||||
<v-chip v-for="toolName in personaForm.tools" :key="toolName"
|
||||
size="small" color="primary" variant="tonal" closable
|
||||
@click:close="removeTool(toolName)">
|
||||
{{ toolName }}
|
||||
</v-chip>
|
||||
</div>
|
||||
<div v-else class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.noToolsSelected') }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</v-expansion-panel-text>
|
||||
</v-expansion-panel>
|
||||
|
||||
<!-- 预设对话面板 -->
|
||||
<v-expansion-panel value="dialogs">
|
||||
<v-expansion-panel-title>
|
||||
<v-icon class="mr-2">mdi-chat</v-icon>
|
||||
{{ tm('form.presetDialogs') }}
|
||||
<v-chip v-if="personaForm.begin_dialogs.length > 0" size="small" color="primary"
|
||||
variant="tonal" class="ml-2">
|
||||
{{ personaForm.begin_dialogs.length / 2 }}
|
||||
</v-chip>
|
||||
</v-expansion-panel-title>
|
||||
|
||||
<v-expansion-panel-text>
|
||||
<div class="mb-3">
|
||||
<p class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.presetDialogsHelp') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div v-for="(dialog, index) in personaForm.begin_dialogs" :key="index" class="mb-3">
|
||||
<v-textarea v-model="personaForm.begin_dialogs[index]"
|
||||
:label="index % 2 === 0 ? tm('form.userMessage') : tm('form.assistantMessage')"
|
||||
:rules="getDialogRules(index)" variant="outlined" rows="2"
|
||||
density="comfortable">
|
||||
<template v-slot:append>
|
||||
<v-btn icon="mdi-delete" variant="text" size="small" color="error"
|
||||
@click="removeDialog(index)" />
|
||||
</template>
|
||||
</v-textarea>
|
||||
</div>
|
||||
|
||||
<v-btn variant="outlined" prepend-icon="mdi-plus" @click="addDialogPair" block>
|
||||
{{ tm('buttons.addDialogPair') }}
|
||||
</v-btn>
|
||||
</v-expansion-panel-text>
|
||||
</v-expansion-panel>
|
||||
</v-expansion-panels>
|
||||
</v-form>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions>
|
||||
<v-spacer />
|
||||
<v-btn color="grey" variant="text" @click="closePersonaDialog">
|
||||
{{ tm('buttons.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" variant="flat" @click="savePersona" :loading="saving" :disabled="!formValid">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
<PersonaForm
|
||||
v-model="showPersonaDialog"
|
||||
:editing-persona="editingPersona"
|
||||
@saved="handlePersonaSaved"
|
||||
@error="showError" />
|
||||
|
||||
<!-- 查看人格详情对话框 -->
|
||||
<v-dialog v-model="showViewDialog" max-width="700px">
|
||||
@@ -352,9 +163,13 @@
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
import PersonaForm from '@/components/shared/PersonaForm.vue';
|
||||
|
||||
export default {
|
||||
name: 'PersonaPage',
|
||||
components: {
|
||||
PersonaForm
|
||||
},
|
||||
setup() {
|
||||
const { t } = useI18n();
|
||||
const { tm } = useModuleI18n('features/persona');
|
||||
@@ -362,76 +177,20 @@ export default {
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
toolSelectValue: '0', // 默认选择全部工具
|
||||
personas: [],
|
||||
loading: false,
|
||||
saving: false,
|
||||
showPersonaDialog: false,
|
||||
showViewDialog: false,
|
||||
editingPersona: null,
|
||||
viewingPersona: null,
|
||||
expandedPanels: [],
|
||||
formValid: false,
|
||||
personaForm: {
|
||||
persona_id: '',
|
||||
system_prompt: '',
|
||||
begin_dialogs: [],
|
||||
tools: []
|
||||
},
|
||||
showMessage: false,
|
||||
message: '',
|
||||
messageType: 'success',
|
||||
personaIdRules: [
|
||||
v => !!v || this.tm('validation.required'),
|
||||
v => (v && v.length >= 0) || this.tm('validation.minLength', { min: 2 }),
|
||||
],
|
||||
systemPromptRules: [
|
||||
v => !!v || this.tm('validation.required'),
|
||||
v => (v && v.length >= 10) || this.tm('validation.minLength', { min: 10 })
|
||||
],
|
||||
mcpServers: [],
|
||||
availableTools: [],
|
||||
loadingTools: false,
|
||||
toolSearch: ''
|
||||
}
|
||||
},
|
||||
|
||||
computed: {
|
||||
filteredTools() {
|
||||
if (!this.toolSearch) {
|
||||
return this.availableTools;
|
||||
}
|
||||
const search = this.toolSearch.toLowerCase();
|
||||
return this.availableTools.filter(tool =>
|
||||
tool.name.toLowerCase().includes(search) ||
|
||||
(tool.description && tool.description.toLowerCase().includes(search)) ||
|
||||
(tool.mcp_server_name && tool.mcp_server_name.toLowerCase().includes(search))
|
||||
);
|
||||
}
|
||||
},
|
||||
|
||||
watch: {
|
||||
toolSearch() {
|
||||
// 响应式搜索,无需额外处理
|
||||
},
|
||||
|
||||
toolSelectValue(newValue) {
|
||||
if (newValue === '0') {
|
||||
// 选择全部工具
|
||||
this.personaForm.tools = null;
|
||||
} else if (newValue === '1') {
|
||||
// 选择指定工具,如果当前是null,则转换为空数组
|
||||
if (this.personaForm.tools === null) {
|
||||
this.personaForm.tools = [];
|
||||
}
|
||||
}
|
||||
messageType: 'success'
|
||||
}
|
||||
},
|
||||
|
||||
mounted() {
|
||||
this.loadPersonas();
|
||||
this.loadMcpServers();
|
||||
this.loadTools();
|
||||
},
|
||||
|
||||
methods: {
|
||||
@@ -450,58 +209,13 @@ export default {
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadMcpServers() {
|
||||
try {
|
||||
const response = await axios.get('/api/tools/mcp/servers');
|
||||
if (response.data.status === 'ok') {
|
||||
this.mcpServers = response.data.data;
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.loadError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.loadError'));
|
||||
}
|
||||
},
|
||||
|
||||
async loadTools() {
|
||||
this.loadingTools = true;
|
||||
try {
|
||||
const response = await axios.get('/api/tools/list');
|
||||
if (response.data.status === 'ok') {
|
||||
this.availableTools = response.data.data;
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.loadError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.loadError'));
|
||||
}
|
||||
this.loadingTools = false;
|
||||
},
|
||||
|
||||
openCreateDialog() {
|
||||
this.editingPersona = null;
|
||||
this.personaForm = {
|
||||
persona_id: '',
|
||||
system_prompt: '',
|
||||
begin_dialogs: [],
|
||||
tools: []
|
||||
};
|
||||
this.toolSelectValue = '0';
|
||||
this.expandedPanels = [];
|
||||
this.showPersonaDialog = true;
|
||||
},
|
||||
|
||||
editPersona(persona) {
|
||||
this.editingPersona = persona;
|
||||
this.personaForm = {
|
||||
persona_id: persona.persona_id,
|
||||
system_prompt: persona.system_prompt,
|
||||
begin_dialogs: [...(persona.begin_dialogs || [])],
|
||||
tools: persona.tools === null ? null : [...(persona.tools || [])]
|
||||
};
|
||||
// 根据 tools 的值设置 toolSelectValue
|
||||
this.toolSelectValue = persona.tools === null ? '0' : '1';
|
||||
this.expandedPanels = [];
|
||||
this.showPersonaDialog = true;
|
||||
},
|
||||
|
||||
@@ -510,48 +224,9 @@ export default {
|
||||
this.showViewDialog = true;
|
||||
},
|
||||
|
||||
closePersonaDialog() {
|
||||
this.showPersonaDialog = false;
|
||||
this.editingPersona = null;
|
||||
this.personaForm = {
|
||||
persona_id: '',
|
||||
system_prompt: '',
|
||||
begin_dialogs: [],
|
||||
tools: []
|
||||
};
|
||||
this.toolSelectValue = '1'; // 重置为默认值
|
||||
},
|
||||
|
||||
async savePersona() {
|
||||
if (!this.formValid) return;
|
||||
|
||||
// 验证预设对话不能为空
|
||||
if (this.personaForm.begin_dialogs.length > 0) {
|
||||
for (let i = 0; i < this.personaForm.begin_dialogs.length; i++) {
|
||||
if (!this.personaForm.begin_dialogs[i] || this.personaForm.begin_dialogs[i].trim() === '') {
|
||||
const dialogType = i % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
|
||||
this.showError(this.tm('validation.dialogRequired', { type: dialogType }));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.saving = true;
|
||||
try {
|
||||
const url = this.editingPersona ? '/api/persona/update' : '/api/persona/create';
|
||||
const response = await axios.post(url, this.personaForm);
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(response.data.message || this.tm('messages.saveSuccess'));
|
||||
this.closePersonaDialog();
|
||||
await this.loadPersonas();
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.saveError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.saveError'));
|
||||
}
|
||||
this.saving = false;
|
||||
handlePersonaSaved(message) {
|
||||
this.showSuccess(message);
|
||||
this.loadPersonas();
|
||||
},
|
||||
|
||||
async deletePersona(persona) {
|
||||
@@ -575,124 +250,6 @@ export default {
|
||||
}
|
||||
},
|
||||
|
||||
addDialogPair() {
|
||||
this.personaForm.begin_dialogs.push('', '');
|
||||
// 自动展开预设对话面板
|
||||
if (!this.expandedPanels.includes('dialogs')) {
|
||||
this.expandedPanels.push('dialogs');
|
||||
}
|
||||
},
|
||||
|
||||
removeDialog(index) {
|
||||
// 如果是偶数索引(用户消息),删除用户消息和对应的助手消息
|
||||
if (index % 2 === 0 && index + 1 < this.personaForm.begin_dialogs.length) {
|
||||
this.personaForm.begin_dialogs.splice(index, 2);
|
||||
}
|
||||
// 如果是奇数索引(助手消息),删除助手消息和对应的用户消息
|
||||
else if (index % 2 === 1 && index - 1 >= 0) {
|
||||
this.personaForm.begin_dialogs.splice(index - 1, 2);
|
||||
}
|
||||
},
|
||||
|
||||
toggleMcpServer(server) {
|
||||
if (!server.tools || server.tools.length === 0) return;
|
||||
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 从全选状态转换为去除该服务器工具的状态
|
||||
this.personaForm.tools = this.availableTools.map(tool => tool.name)
|
||||
.filter(toolName => !server.tools.includes(toolName));
|
||||
this.toolSelectValue = '1'; // 切换到指定工具模式
|
||||
return;
|
||||
}
|
||||
|
||||
// 确保tools是数组
|
||||
if (!Array.isArray(this.personaForm.tools)) {
|
||||
this.personaForm.tools = [];
|
||||
this.toolSelectValue = '1';
|
||||
}
|
||||
|
||||
// 检查是否所有服务器的工具都已选中
|
||||
const serverTools = server.tools;
|
||||
const allSelected = serverTools.every(toolName => this.personaForm.tools.includes(toolName));
|
||||
|
||||
if (allSelected) {
|
||||
// 移除所有服务器工具
|
||||
this.personaForm.tools = this.personaForm.tools.filter(
|
||||
toolName => !serverTools.includes(toolName)
|
||||
);
|
||||
} else {
|
||||
// 添加所有服务器工具
|
||||
serverTools.forEach(toolName => {
|
||||
if (!this.personaForm.tools.includes(toolName)) {
|
||||
this.personaForm.tools.push(toolName);
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
toggleTool(toolName) {
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 如果是全选状态,点击某个工具表示要取消选择该工具
|
||||
// 所以创建一个包含所有其他工具的数组
|
||||
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
|
||||
this.toolSelectValue = '1'; // 切换到指定工具模式
|
||||
} else if (Array.isArray(this.personaForm.tools)) {
|
||||
const index = this.personaForm.tools.indexOf(toolName);
|
||||
if (index !== -1) {
|
||||
// 如果工具已选择,移除工具
|
||||
this.personaForm.tools.splice(index, 1);
|
||||
} else {
|
||||
// 如果工具未选择,添加工具
|
||||
this.personaForm.tools.push(toolName);
|
||||
}
|
||||
} else {
|
||||
// 如果tools不是数组也不是null,初始化为数组
|
||||
this.personaForm.tools = [toolName];
|
||||
this.toolSelectValue = '1';
|
||||
}
|
||||
},
|
||||
|
||||
toggleAllTools() {
|
||||
// 如果当前是全选状态,则清空选择
|
||||
if (this.isAllToolsSelected()) {
|
||||
this.personaForm.tools = [];
|
||||
} else {
|
||||
// 否则设置为全选(null表示所有工具)
|
||||
this.personaForm.tools = null;
|
||||
}
|
||||
},
|
||||
|
||||
clearAllTools() {
|
||||
// 清空所有工具选择
|
||||
this.personaForm.tools = [];
|
||||
},
|
||||
|
||||
isAllToolsSelected() {
|
||||
// 检查是否为全选状态(tools为null)
|
||||
return this.personaForm.tools === null;
|
||||
},
|
||||
|
||||
isNoToolsSelected() {
|
||||
// 检查是否没有选择任何工具
|
||||
return Array.isArray(this.personaForm.tools) && this.personaForm.tools.length === 0;
|
||||
},
|
||||
|
||||
removeTool(toolName) {
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 创建一个包含所有工具的数组,然后移除指定工具
|
||||
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
|
||||
this.toolSelectValue = '1'; // 切换到指定工具模式
|
||||
} else if (Array.isArray(this.personaForm.tools)) {
|
||||
const index = this.personaForm.tools.indexOf(toolName);
|
||||
if (index !== -1) {
|
||||
this.personaForm.tools.splice(index, 1);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
truncateText(text, maxLength) {
|
||||
if (!text) return '';
|
||||
return text.length > maxLength ? text.substring(0, maxLength) + '...' : text;
|
||||
@@ -713,35 +270,6 @@ export default {
|
||||
this.message = message;
|
||||
this.messageType = 'error';
|
||||
this.showMessage = true;
|
||||
},
|
||||
|
||||
getDialogRules(index) {
|
||||
const dialogType = index % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
|
||||
return [
|
||||
v => !!v || this.tm('validation.dialogRequired', { type: dialogType }),
|
||||
v => (v && v.trim().length > 0) || this.tm('validation.dialogRequired', { type: dialogType })
|
||||
];
|
||||
},
|
||||
|
||||
isToolSelected(toolName) {
|
||||
// 如果是全选状态,所有工具都被选中
|
||||
if (this.personaForm.tools === null) {
|
||||
return true;
|
||||
}
|
||||
return Array.isArray(this.personaForm.tools) && this.personaForm.tools.includes(toolName);
|
||||
},
|
||||
|
||||
isServerSelected(server) {
|
||||
if (!server.tools || server.tools.length === 0) return false;
|
||||
|
||||
// 如果是全选状态,所有服务器都被选中
|
||||
if (this.personaForm.tools === null) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 检查服务器的所有工具是否都已选中
|
||||
return Array.isArray(this.personaForm.tools) &&
|
||||
server.tools.every(toolName => this.personaForm.tools.includes(toolName));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -791,13 +319,4 @@ export default {
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.tools-selection {
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.v-virtual-scroll {
|
||||
padding-bottom: 16px;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -3,14 +3,14 @@
|
||||
<v-container fluid class="pa-0">
|
||||
<v-row class="d-flex justify-space-between align-center px-4 py-3 pb-8">
|
||||
<div>
|
||||
<h1 class="text-h1 font-weight-bold mb-2">
|
||||
<v-icon color="black" class="me-2">mdi-connection</v-icon>{{ tm('title') }}
|
||||
<h1 class="text-h1 font-weight-bold mb-2 d-flex align-center">
|
||||
<v-icon color="black" class="me-2">mdi-robot</v-icon>{{ tm('title') }}
|
||||
</h1>
|
||||
<p class="text-subtitle-1 text-medium-emphasis mb-4">
|
||||
{{ tm('subtitle') }}
|
||||
</p>
|
||||
</div>
|
||||
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="showAddPlatformDialog = true"
|
||||
<v-btn color="primary" prepend-icon="mdi-plus" variant="tonal" @click="updatingMode = false; showAddPlatformDialog = true"
|
||||
rounded="xl" size="x-large">
|
||||
{{ tm('addAdapter') }}
|
||||
</v-btn>
|
||||
@@ -46,7 +46,6 @@
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-expand-transition>
|
||||
<v-card-text class="pa-0" v-if="showConsole">
|
||||
@@ -57,91 +56,15 @@
|
||||
</v-container>
|
||||
|
||||
<!-- 添加平台适配器对话框 -->
|
||||
<AddNewPlatform v-model:show="showAddPlatformDialog" :metadata="metadata"
|
||||
@select-template="selectPlatformTemplate" />
|
||||
|
||||
<!-- 配置对话框 -->
|
||||
<v-dialog v-model="showPlatformCfg" persistent width="900px" max-width="90%">
|
||||
<v-card
|
||||
:title="updatingMode ? tm('dialog.edit') : tm('dialog.add') + ` ${newSelectedPlatformName} ` + tm('dialog.adapter')">
|
||||
<v-card-text class="py-4">
|
||||
<v-row>
|
||||
<v-col cols="12">
|
||||
<AstrBotConfig :iterable="newSelectedPlatformConfig" :metadata="metadata['platform_group']?.metadata"
|
||||
metadataKey="platform" />
|
||||
</v-col>
|
||||
</v-row>
|
||||
<v-row class="mt-2">
|
||||
<v-col cols="12" class="text-center">
|
||||
<v-btn color="info" variant="outlined" @click="openTutorial">
|
||||
<v-icon start>mdi-book-open-variant</v-icon>
|
||||
{{ tm('dialog.viewTutorial') }}
|
||||
</v-btn>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-card-text>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="showPlatformCfg = false" :disabled="loading">
|
||||
{{ tm('dialog.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" @click="newPlatform" :loading="loading">
|
||||
{{ tm('dialog.save') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
<AddNewPlatform v-model:show="showAddPlatformDialog" :metadata="metadata" :config_data="config_data" ref="addPlatformDialog"
|
||||
:updating-mode="updatingMode" :updating-platform-config="updatingPlatformConfig" @update="getConfig"
|
||||
@show-toast="showToast" @refresh-config="getConfig"/>
|
||||
|
||||
<!-- 消息提示 -->
|
||||
<v-snackbar :timeout="3000" elevation="24" :color="save_message_success" v-model="save_message_snack"
|
||||
location="top">
|
||||
{{ save_message }}
|
||||
</v-snackbar>
|
||||
|
||||
<!-- ID冲突确认对话框 -->
|
||||
<v-dialog v-model="showIdConflictDialog" max-width="450" persistent>
|
||||
<v-card>
|
||||
<v-card-title class="text-h6 bg-warning d-flex align-center">
|
||||
<v-icon start class="me-2">mdi-alert-circle-outline</v-icon>
|
||||
{{ tm('dialog.idConflict.title') }}
|
||||
</v-card-title>
|
||||
<v-card-text class="py-4 text-body-1 text-medium-emphasis">
|
||||
{{ tm('dialog.idConflict.message', { id: conflictId }) }}
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" variant="text" @click="handleIdConflictConfirm(false)">{{ tm('dialog.idConflict.confirm')
|
||||
}}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 安全警告对话框 -->
|
||||
<v-dialog v-model="showOneBotEmptyTokenWarnDialog" max-width="600" persistent>
|
||||
<v-card>
|
||||
<v-card-title>
|
||||
{{ tm('dialog.securityWarning.title') }}
|
||||
</v-card-title>
|
||||
<v-card-text class="py-4">
|
||||
<p>{{ tm('dialog.securityWarning.aiocqhttpTokenMissing') }}</p>
|
||||
<span><a
|
||||
href="https://docs.astrbot.app/deploy/platform/aiocqhttp/napcat.html#%E9%99%84%E5%BD%95-%E5%A2%9E%E5%BC%BA%E8%BF%9E%E6%8E%A5%E5%AE%89%E5%85%A8%E6%80%A7"
|
||||
target="_blank">{{ tm('dialog.securityWarning.learnMore') }}</a></span>
|
||||
</v-card-text>
|
||||
<v-card-actions class="px-4 pb-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="error" @click="handleOneBotEmptyTokenWarningDismiss(true)">
|
||||
无视警告并继续创建
|
||||
</v-btn>
|
||||
<v-btn color="primary" @click="handleOneBotEmptyTokenWarningDismiss(false)">
|
||||
重新修改
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -191,30 +114,17 @@ export default {
|
||||
config_data: {},
|
||||
fetched: false,
|
||||
metadata: {},
|
||||
showPlatformCfg: false,
|
||||
showAddPlatformDialog: false,
|
||||
|
||||
newSelectedPlatformName: '',
|
||||
newSelectedPlatformConfig: {},
|
||||
updatingPlatformConfig: {},
|
||||
updatingMode: false,
|
||||
|
||||
loading: false,
|
||||
|
||||
save_message_snack: false,
|
||||
save_message: "",
|
||||
save_message_success: "success",
|
||||
|
||||
showConsole: false,
|
||||
|
||||
// ID冲突确认对话框
|
||||
showIdConflictDialog: false,
|
||||
conflictId: '',
|
||||
idConflictResolve: null,
|
||||
|
||||
// OneBot Empty Token Warning #2639
|
||||
showOneBotEmptyTokenWarnDialog: false,
|
||||
oneBotEmptyTokenWarningResolve: null,
|
||||
|
||||
store: useCommonStore()
|
||||
}
|
||||
},
|
||||
@@ -251,11 +161,6 @@ export default {
|
||||
return getPlatformIcon(platform_id);
|
||||
},
|
||||
|
||||
openTutorial() {
|
||||
const tutorialUrl = getTutorialLink(this.newSelectedPlatformConfig.type);
|
||||
window.open(tutorialUrl, '_blank');
|
||||
},
|
||||
|
||||
getConfig() {
|
||||
axios.get('/api/config/get').then((res) => {
|
||||
this.config_data = res.data.data.config;
|
||||
@@ -266,134 +171,13 @@ export default {
|
||||
});
|
||||
},
|
||||
|
||||
// 选择平台模板
|
||||
selectPlatformTemplate(name) {
|
||||
this.newSelectedPlatformName = name;
|
||||
this.showPlatformCfg = true;
|
||||
this.updatingMode = false;
|
||||
this.newSelectedPlatformConfig = JSON.parse(JSON.stringify(
|
||||
this.metadata['platform_group']?.metadata?.platform?.config_template[name] || {}
|
||||
));
|
||||
},
|
||||
|
||||
addFromDefaultConfigTmpl(index) {
|
||||
this.newSelectedPlatformName = index[0];
|
||||
this.showPlatformCfg = true;
|
||||
this.updatingMode = false;
|
||||
this.newSelectedPlatformConfig = JSON.parse(JSON.stringify(
|
||||
this.metadata['platform_group']?.metadata?.platform?.config_template[index[0]] || {}
|
||||
));
|
||||
},
|
||||
|
||||
editPlatform(platform) {
|
||||
this.newSelectedPlatformName = platform.id;
|
||||
this.newSelectedPlatformConfig = JSON.parse(JSON.stringify(platform));
|
||||
this.updatingPlatformConfig = JSON.parse(JSON.stringify(platform));
|
||||
this.updatingMode = true;
|
||||
this.showPlatformCfg = true;
|
||||
},
|
||||
|
||||
newPlatform() {
|
||||
this.loading = true;
|
||||
if (this.updatingMode) {
|
||||
if (this.newSelectedPlatformConfig.type === 'aiocqhttp') {
|
||||
const token = this.newSelectedPlatformConfig.ws_reverse_token;
|
||||
if (!token || token.trim() === '') {
|
||||
this.showOneBotEmptyTokenWarning().then((continueWithWarning) => {
|
||||
if (continueWithWarning) {
|
||||
this.updatePlatform();
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
this.updatePlatform();
|
||||
} else {
|
||||
this.savePlatform();
|
||||
}
|
||||
},
|
||||
|
||||
updatePlatform() {
|
||||
axios.post('/api/config/platform/update', {
|
||||
id: this.newSelectedPlatformName,
|
||||
config: this.newSelectedPlatformConfig
|
||||
}).then((res) => {
|
||||
this.loading = false;
|
||||
this.showPlatformCfg = false;
|
||||
this.getConfig();
|
||||
this.showSuccess(res.data.message || this.messages.updateSuccess);
|
||||
}).catch((err) => {
|
||||
this.loading = false;
|
||||
this.showError(err.response?.data?.message || err.message);
|
||||
this.showAddPlatformDialog = true;
|
||||
this.$nextTick(() => {
|
||||
this.$refs.addPlatformDialog.toggleShowConfigSection();
|
||||
});
|
||||
this.updatingMode = false;
|
||||
},
|
||||
|
||||
async savePlatform() {
|
||||
// 检查 ID 是否已存在
|
||||
const existingPlatform = this.config_data.platform?.find(p => p.id === this.newSelectedPlatformConfig.id);
|
||||
if (existingPlatform) {
|
||||
const confirmed = await this.confirmIdConflict(this.newSelectedPlatformConfig.id);
|
||||
if (!confirmed) {
|
||||
this.loading = false;
|
||||
return; // 如果用户取消,则中止保存
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 aiocqhttp 适配器的安全设置
|
||||
if (this.newSelectedPlatformConfig.type === 'aiocqhttp') {
|
||||
const token = this.newSelectedPlatformConfig.ws_reverse_token;
|
||||
if (!token || token.trim() === '') {
|
||||
const continueWithWarning = await this.showOneBotEmptyTokenWarning();
|
||||
if (!continueWithWarning) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const res = await axios.post('/api/config/platform/new', this.newSelectedPlatformConfig);
|
||||
this.loading = false;
|
||||
this.showPlatformCfg = false;
|
||||
this.getConfig();
|
||||
this.showSuccess(res.data.message || this.messages.addSuccess);
|
||||
} catch (err) {
|
||||
this.loading = false;
|
||||
this.showError(err.response?.data?.message || err.message);
|
||||
}
|
||||
},
|
||||
|
||||
confirmIdConflict(id) {
|
||||
this.conflictId = id;
|
||||
this.showIdConflictDialog = true;
|
||||
return new Promise((resolve) => {
|
||||
this.idConflictResolve = resolve;
|
||||
});
|
||||
},
|
||||
|
||||
handleIdConflictConfirm(confirmed) {
|
||||
if (this.idConflictResolve) {
|
||||
this.idConflictResolve(confirmed);
|
||||
}
|
||||
this.showIdConflictDialog = false;
|
||||
},
|
||||
|
||||
showOneBotEmptyTokenWarning() {
|
||||
this.showOneBotEmptyTokenWarnDialog = true;
|
||||
return new Promise((resolve) => {
|
||||
this.oneBotEmptyTokenWarningResolve = resolve;
|
||||
});
|
||||
},
|
||||
|
||||
handleOneBotEmptyTokenWarningDismiss(continueWithWarning) {
|
||||
this.showOneBotEmptyTokenWarnDialog = false;
|
||||
if (this.oneBotEmptyTokenWarningResolve) {
|
||||
this.oneBotEmptyTokenWarningResolve(continueWithWarning);
|
||||
this.oneBotEmptyTokenWarningResolve = null;
|
||||
}
|
||||
|
||||
if (!continueWithWarning) {
|
||||
this.loading = false;
|
||||
}
|
||||
},
|
||||
|
||||
deletePlatform(platform) {
|
||||
@@ -422,6 +206,14 @@ export default {
|
||||
});
|
||||
},
|
||||
|
||||
showToast({ message, type }) {
|
||||
if (type === 'success') {
|
||||
this.showSuccess(message);
|
||||
} else if (type === 'error') {
|
||||
this.showError(message);
|
||||
}
|
||||
},
|
||||
|
||||
showSuccess(message) {
|
||||
this.save_message = message;
|
||||
this.save_message_success = "success";
|
||||
|
||||
@@ -103,8 +103,6 @@
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-expand-transition>
|
||||
<v-card-text class="pa-0" v-if="showStatus">
|
||||
<v-card-text class="px-4 py-3">
|
||||
@@ -158,8 +156,6 @@
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-divider></v-divider>
|
||||
|
||||
<v-expand-transition>
|
||||
<v-card-text class="pa-0" v-if="showConsole">
|
||||
<ConsoleDisplayer style="background-color: #1e1e1e; height: 300px; border-radius: 0"></ConsoleDisplayer>
|
||||
@@ -234,7 +230,7 @@
|
||||
确认保存
|
||||
</v-card-title>
|
||||
<v-card-text class="py-4 text-body-1 text-medium-emphasis">
|
||||
您没有填写 API Key,确定要保存吗?这可能会导致该服务提供商无法正常工作。
|
||||
您没有填写 API Key,确定要保存吗?这可能会导致该模型无法正常工作。
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
|
||||
@@ -60,10 +60,10 @@
|
||||
<p>使用 /sid 指令可查看会话 ID。</p>
|
||||
<p>会话信息:</p>
|
||||
<ul>
|
||||
<li>平台: {{ item.platform }}</li>
|
||||
<li v-if="item.user_name">用户: {{ item.user_name }}</li>
|
||||
<li>机器人 ID: {{ item.platform }}</li>
|
||||
<li v-if="item.message_type">消息类型: {{ item.message_type }}</li>
|
||||
<li v-if="item.session_raw_name">会话 ID: {{ item.session_raw_name }}</li>
|
||||
<li v-if="item.user_name">用户: {{ item.user_name }}</li>
|
||||
</ul>
|
||||
</div>
|
||||
</v-tooltip>
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
<template>
|
||||
<div class="flex-grow-1" style="display: flex; flex-direction: column; height: 100%;">
|
||||
<div style="flex-grow: 1; width: 100%; border: 1px solid #eee; border-radius: 8px; padding: 16px">
|
||||
<v-banner lines="one">
|
||||
<template v-slot:text>
|
||||
建议您更换使用新版知识库功能。
|
||||
</template>
|
||||
</v-banner>
|
||||
<!-- knowledge card -->
|
||||
<div v-if="!installed" class="d-flex align-center justify-center flex-column"
|
||||
style="flex-grow: 1; width: 100%; height: 100%;">
|
||||
@@ -105,9 +110,9 @@
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="error" variant="text" @click="showCreateDialog = false">{{ tm('createDialog.cancel')
|
||||
}}</v-btn>
|
||||
}}</v-btn>
|
||||
<v-btn color="primary" variant="text" @click="submitCreateForm">{{ tm('createDialog.create')
|
||||
}}</v-btn>
|
||||
}}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
@@ -132,7 +137,7 @@
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="primary" variant="text" @click="showEmojiPicker = false">{{ tm('emojiPicker.close')
|
||||
}}</v-btn>
|
||||
}}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
@@ -159,8 +164,8 @@
|
||||
<v-chip v-if="currentKB.rerank_provider_id" color="tertiary" variant="tonal" size="small"
|
||||
rounded="sm">
|
||||
<v-icon start size="small">mdi-sort-variant</v-icon>
|
||||
重排序模型: {{ rerankProviderConfigs.
|
||||
find(provider => provider.id === currentKB.rerank_provider_id)?.rerank_model || '未设置' }}
|
||||
重排序模型: {{rerankProviderConfigs.
|
||||
find(provider => provider.id === currentKB.rerank_provider_id)?.rerank_model || '未设置'}}
|
||||
</v-chip>
|
||||
<small style="margin-left: 8px;">💡 使用方式: 在聊天页中输入 "/kb use {{ currentKB.collection_name }}"</small>
|
||||
</div>
|
||||
@@ -411,7 +416,7 @@
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey-darken-1" variant="text" @click="showDeleteDialog = false">{{
|
||||
tm('deleteDialog.cancel')
|
||||
}}</v-btn>
|
||||
}}</v-btn>
|
||||
<v-btn color="error" variant="text" @click="deleteKnowledgeBase" :loading="deleting">{{
|
||||
tm('deleteDialog.delete') }}</v-btn>
|
||||
</v-card-actions>
|
||||
@@ -603,6 +608,7 @@ export default {
|
||||
.then(response => {
|
||||
if (response.data.status !== 'ok' || response.data.data.length === 0) {
|
||||
this.showSnackbar(this.tm('messages.pluginNotAvailable'), 'error');
|
||||
this.installed = false;
|
||||
return
|
||||
}
|
||||
if (!response.data.data[0].activated) {
|
||||
|
||||
@@ -82,10 +82,10 @@
|
||||
<v-card-title class="d-flex align-center pa-4">
|
||||
<span>{{ t('chunks.title') }}</span>
|
||||
<v-chip class="ml-2" size="small" variant="tonal">
|
||||
{{ chunks.length }} {{ t('chunks.title') }}
|
||||
{{ totalChunks }} {{ t('chunks.title') }}
|
||||
</v-chip>
|
||||
<v-spacer />
|
||||
<v-text-field
|
||||
<!-- <v-text-field
|
||||
v-model="searchQuery"
|
||||
prepend-inner-icon="mdi-magnify"
|
||||
:placeholder="t('chunks.searchPlaceholder')"
|
||||
@@ -94,7 +94,7 @@
|
||||
hide-details
|
||||
clearable
|
||||
style="max-width: 300px"
|
||||
/>
|
||||
/> -->
|
||||
</v-card-title>
|
||||
|
||||
<v-divider />
|
||||
@@ -104,7 +104,8 @@
|
||||
:headers="headers"
|
||||
:items="filteredChunks"
|
||||
:loading="loadingChunks"
|
||||
:items-per-page="10"
|
||||
:items-per-page="pageSize"
|
||||
hide-default-footer
|
||||
>
|
||||
<template #item.chunk_index="{ item }">
|
||||
<v-chip size="small" variant="tonal" color="primary">
|
||||
@@ -132,19 +133,13 @@
|
||||
color="info"
|
||||
@click="viewChunk(item)"
|
||||
/>
|
||||
<v-btn
|
||||
icon="mdi-pencil"
|
||||
variant="text"
|
||||
size="small"
|
||||
color="primary"
|
||||
@click="editChunk(item)"
|
||||
/>
|
||||
<!-- 删除 -->
|
||||
<v-btn
|
||||
icon="mdi-delete"
|
||||
variant="text"
|
||||
size="small"
|
||||
color="error"
|
||||
@click="confirmDeleteChunk(item)"
|
||||
@click="deleteChunk(item)"
|
||||
/>
|
||||
</template>
|
||||
|
||||
@@ -155,6 +150,31 @@
|
||||
</div>
|
||||
</template>
|
||||
</v-data-table>
|
||||
|
||||
|
||||
<!-- 自定义分页器 -->
|
||||
<div v-if="!searchQuery && totalChunks > 0" class="pa-4 d-flex align-center justify-space-between">
|
||||
<div class="text-caption text-medium-emphasis">
|
||||
{{ t('chunks.showing') }} {{ (page - 1) * pageSize + 1 }} - {{ Math.min(page * pageSize, totalChunks) }} / {{ totalChunks }}
|
||||
</div>
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-select
|
||||
v-model="pageSize"
|
||||
:items="[10, 25, 50, 100]"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
style="width: 100px"
|
||||
@update:model-value="handlePageSizeChange"
|
||||
/>
|
||||
<v-pagination
|
||||
v-model="page"
|
||||
:length="Math.ceil(totalChunks / pageSize)"
|
||||
:total-visible="5"
|
||||
@update:model-value="handlePageChange"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</div>
|
||||
@@ -191,7 +211,7 @@
|
||||
<v-icon>mdi-key</v-icon>
|
||||
</template>
|
||||
<v-list-item-title>{{ t('view.vecDocId') }}</v-list-item-title>
|
||||
<v-list-item-subtitle>{{ selectedChunk?.vec_doc_id || '-' }}</v-list-item-subtitle>
|
||||
<v-list-item-subtitle>{{ selectedChunk?.chunk_id || '-' }}</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
|
||||
@@ -212,71 +232,6 @@
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 编辑分块对话框 -->
|
||||
<v-dialog v-model="showEditDialog" max-width="800px" persistent scrollable>
|
||||
<v-card>
|
||||
<v-card-title class="pa-4">
|
||||
<span>{{ t('edit.title') }}</span>
|
||||
<v-spacer />
|
||||
<v-btn icon="mdi-close" variant="text" @click="closeEditDialog" />
|
||||
</v-card-title>
|
||||
<v-divider />
|
||||
<v-card-text class="pa-6">
|
||||
<v-textarea
|
||||
v-model="editForm.content"
|
||||
:label="t('edit.content')"
|
||||
variant="outlined"
|
||||
rows="15"
|
||||
auto-grow
|
||||
/>
|
||||
</v-card-text>
|
||||
<v-divider />
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer />
|
||||
<v-btn variant="text" @click="closeEditDialog">
|
||||
{{ t('edit.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="elevated"
|
||||
@click="saveChunk"
|
||||
:loading="saving"
|
||||
>
|
||||
{{ t('edit.save') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 删除确认对话框 -->
|
||||
<v-dialog v-model="showDeleteDialog" max-width="450px">
|
||||
<v-card>
|
||||
<v-card-title class="pa-4 text-h6">{{ t('delete.title') }}</v-card-title>
|
||||
<v-divider />
|
||||
<v-card-text class="pa-6">
|
||||
<p>{{ t('delete.confirmText') }}</p>
|
||||
<v-alert type="warning" variant="tonal" density="compact" class="mt-4">
|
||||
{{ t('delete.warning') }}
|
||||
</v-alert>
|
||||
</v-card-text>
|
||||
<v-divider />
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer />
|
||||
<v-btn variant="text" @click="showDeleteDialog = false">
|
||||
{{ t('delete.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="error"
|
||||
variant="elevated"
|
||||
@click="deleteChunk"
|
||||
:loading="deleting"
|
||||
>
|
||||
{{ t('delete.confirm') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 消息提示 -->
|
||||
<v-snackbar v-model="snackbar.show" :color="snackbar.color">
|
||||
{{ snackbar.text }}
|
||||
@@ -299,16 +254,16 @@ const docId = ref(route.params.docId as string)
|
||||
// 状态
|
||||
const loading = ref(true)
|
||||
const loadingChunks = ref(false)
|
||||
const saving = ref(false)
|
||||
const deleting = ref(false)
|
||||
const document = ref<any>({})
|
||||
const chunks = ref<any[]>([])
|
||||
const searchQuery = ref('')
|
||||
const showViewDialog = ref(false)
|
||||
const showEditDialog = ref(false)
|
||||
const showDeleteDialog = ref(false)
|
||||
const selectedChunk = ref<any>(null)
|
||||
const deleteTarget = ref<any>(null)
|
||||
|
||||
// 分页状态
|
||||
const page = ref(1)
|
||||
const pageSize = ref(10)
|
||||
const totalChunks = ref(0)
|
||||
|
||||
const snackbar = ref({
|
||||
show: false,
|
||||
@@ -322,17 +277,12 @@ const showSnackbar = (text: string, color: string = 'success') => {
|
||||
snackbar.value.show = true
|
||||
}
|
||||
|
||||
// 编辑表单
|
||||
const editForm = ref({
|
||||
content: ''
|
||||
})
|
||||
|
||||
// 表格列
|
||||
const headers = [
|
||||
{ title: t('chunks.index'), key: 'chunk_index', width: 100 },
|
||||
{ title: t('chunks.content'), key: 'content', sortable: false },
|
||||
{ title: t('chunks.charCount'), key: 'char_count', width: 150 },
|
||||
{ title: t('chunks.actions'), key: 'actions', sortable: false, align: 'end', width: 150 }
|
||||
{ title: t('chunks.actions'), key: 'actions', sortable: false, width: 150 }
|
||||
]
|
||||
|
||||
// 过滤分块
|
||||
@@ -349,7 +299,7 @@ const loadDocument = async () => {
|
||||
loading.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/kb/document/get', {
|
||||
params: { doc_id: docId.value }
|
||||
params: { doc_id: docId.value, kb_id: kbId.value }
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
document.value = response.data.data
|
||||
@@ -367,10 +317,16 @@ const loadChunks = async () => {
|
||||
loadingChunks.value = true
|
||||
try {
|
||||
const response = await axios.get('/api/kb/chunk/list', {
|
||||
params: { doc_id: docId.value }
|
||||
params: {
|
||||
doc_id: docId.value,
|
||||
kb_id: kbId.value,
|
||||
page: page.value,
|
||||
page_size: pageSize.value
|
||||
}
|
||||
})
|
||||
if (response.data.status === 'ok') {
|
||||
chunks.value = response.data.data.items || []
|
||||
totalChunks.value = response.data.data.total || 0
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to load chunks:', error)
|
||||
@@ -380,81 +336,42 @@ const loadChunks = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
// 处理分页变化
|
||||
const handlePageChange = (newPage: number) => {
|
||||
page.value = newPage
|
||||
loadChunks()
|
||||
}
|
||||
|
||||
const handlePageSizeChange = (newPageSize: number) => {
|
||||
pageSize.value = newPageSize
|
||||
page.value = 1
|
||||
loadChunks()
|
||||
}
|
||||
|
||||
// 查看分块
|
||||
const viewChunk = (chunk: any) => {
|
||||
selectedChunk.value = chunk
|
||||
showViewDialog.value = true
|
||||
}
|
||||
|
||||
// 编辑分块
|
||||
const editChunk = (chunk: any) => {
|
||||
selectedChunk.value = chunk
|
||||
editForm.value.content = chunk.content
|
||||
showEditDialog.value = true
|
||||
}
|
||||
|
||||
// 关闭编辑对话框
|
||||
const closeEditDialog = () => {
|
||||
showEditDialog.value = false
|
||||
selectedChunk.value = null
|
||||
editForm.value.content = ''
|
||||
}
|
||||
|
||||
// 保存分块
|
||||
const saveChunk = async () => {
|
||||
if (!selectedChunk.value) return
|
||||
|
||||
saving.value = true
|
||||
try {
|
||||
const response = await axios.post('/api/kb/chunk/update', {
|
||||
chunk_id: selectedChunk.value.chunk_id,
|
||||
content: editForm.value.content
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
showSnackbar(t('edit.saveSuccess'))
|
||||
closeEditDialog()
|
||||
await loadChunks()
|
||||
} else {
|
||||
showSnackbar(response.data.message || t('edit.saveFailed'), 'error')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to save chunk:', error)
|
||||
showSnackbar(t('edit.saveFailed'), 'error')
|
||||
} finally {
|
||||
saving.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 确认删除分块
|
||||
const confirmDeleteChunk = (chunk: any) => {
|
||||
deleteTarget.value = chunk
|
||||
showDeleteDialog.value = true
|
||||
}
|
||||
|
||||
// 删除分块
|
||||
const deleteChunk = async () => {
|
||||
if (!deleteTarget.value) return
|
||||
|
||||
deleting.value = true
|
||||
const deleteChunk = async (chunk: any) => {
|
||||
if (!confirm(t('chunks.deleteConfirm'))) return
|
||||
try {
|
||||
const response = await axios.post('/api/kb/chunk/delete', {
|
||||
chunk_id: deleteTarget.value.chunk_id
|
||||
chunk_id: chunk.chunk_id,
|
||||
doc_id: docId.value,
|
||||
kb_id: kbId.value
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
showSnackbar(t('delete.deleteSuccess'))
|
||||
showDeleteDialog.value = false
|
||||
await loadChunks()
|
||||
await loadDocument()
|
||||
showSnackbar(t('chunks.deleteSuccess'))
|
||||
loadChunks()
|
||||
} else {
|
||||
showSnackbar(response.data.message || t('delete.deleteFailed'), 'error')
|
||||
showSnackbar(t('chunks.deleteFailed'), 'error')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to delete chunk:', error)
|
||||
showSnackbar(t('delete.deleteFailed'), 'error')
|
||||
} finally {
|
||||
deleting.value = false
|
||||
showSnackbar(t('chunks.deleteFailed'), 'error')
|
||||
}
|
||||
}
|
||||
|
||||
@@ -582,6 +499,10 @@ onMounted(() => {
|
||||
font-family: 'Consolas', 'Monaco', monospace;
|
||||
}
|
||||
|
||||
.gap-2 {
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
/* 响应式设计 */
|
||||
@media (max-width: 768px) {
|
||||
.document-detail-page {
|
||||
|
||||
@@ -40,10 +40,6 @@
|
||||
<v-icon start>mdi-magnify</v-icon>
|
||||
{{ t('tabs.retrieval') }}
|
||||
</v-tab>
|
||||
<v-tab value="sessions">
|
||||
<v-icon start>mdi-account-multiple</v-icon>
|
||||
{{ t('tabs.sessions') }}
|
||||
</v-tab>
|
||||
<v-tab value="settings">
|
||||
<v-icon start>mdi-cog</v-icon>
|
||||
{{ t('tabs.settings') }}
|
||||
@@ -51,7 +47,7 @@
|
||||
</v-tabs>
|
||||
|
||||
<!-- 标签页内容 -->
|
||||
<v-window v-model="activeTab">
|
||||
<v-window v-model="activeTab" style="padding: 8px;">
|
||||
<!-- 概览 -->
|
||||
<v-window-item value="overview">
|
||||
<v-row>
|
||||
@@ -163,12 +159,7 @@
|
||||
|
||||
<!-- 知识库检索 -->
|
||||
<v-window-item value="retrieval">
|
||||
<RetrievalTab :kb-id="kbId" />
|
||||
</v-window-item>
|
||||
|
||||
<!-- 使用会话 -->
|
||||
<v-window-item value="sessions">
|
||||
<SessionsTab :kb-id="kbId" />
|
||||
<RetrievalTab :kb-id="kbId" :kb-name="kb.kb_name"/>
|
||||
</v-window-item>
|
||||
|
||||
<!-- 设置 -->
|
||||
@@ -192,7 +183,6 @@ import axios from 'axios'
|
||||
import { useModuleI18n } from '@/i18n/composables'
|
||||
import DocumentsTab from './components/DocumentsTab.vue'
|
||||
import RetrievalTab from './components/RetrievalTab.vue'
|
||||
import SessionsTab from './components/SessionsTab.vue'
|
||||
import SettingsTab from './components/SettingsTab.vue'
|
||||
|
||||
const { tm: t } = useModuleI18n('features/knowledge-base/detail')
|
||||
@@ -255,7 +245,6 @@ onMounted(() => {
|
||||
|
||||
<style scoped>
|
||||
.kb-detail-page {
|
||||
padding: 24px;
|
||||
max-width: 1400px;
|
||||
margin: 0 auto;
|
||||
animation: fadeIn 0.3s ease;
|
||||
@@ -329,12 +318,11 @@ onMounted(() => {
|
||||
padding: 24px;
|
||||
text-align: center;
|
||||
border-radius: 12px;
|
||||
background: rgba(var(--v-theme-surface-variant), 0.3);
|
||||
background: rgba(var(--v-theme-surface-variant), 0.1);
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.stat-box:hover {
|
||||
transform: translateY(-4px);
|
||||
background: rgba(var(--v-theme-surface-variant), 0.5);
|
||||
}
|
||||
|
||||
@@ -342,21 +330,15 @@ onMounted(() => {
|
||||
font-size: 2rem;
|
||||
font-weight: 600;
|
||||
margin-top: 8px;
|
||||
color: rgb(var(--v-theme-on-surface));
|
||||
}
|
||||
|
||||
.stat-label {
|
||||
font-size: 0.875rem;
|
||||
color: rgb(var(--v-theme-on-surface-variant));
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
/* 响应式设计 */
|
||||
@media (max-width: 768px) {
|
||||
.kb-detail-page {
|
||||
padding: 16px;
|
||||
}
|
||||
|
||||
.kb-title {
|
||||
flex-direction: column;
|
||||
align-items: flex-start;
|
||||
|
||||
@@ -6,32 +6,16 @@
|
||||
<h1 class="text-h4 mb-2">{{ t('list.title') }}</h1>
|
||||
<p class="text-subtitle-1 text-medium-emphasis">{{ t('list.subtitle') }}</p>
|
||||
</div>
|
||||
<v-btn
|
||||
icon="mdi-information-outline"
|
||||
variant="text"
|
||||
size="small"
|
||||
color="grey"
|
||||
href="https://astrbot.app/use/knowledge-base.html"
|
||||
target="_blank"
|
||||
/>
|
||||
<v-btn icon="mdi-information-outline" variant="text" size="small" color="grey"
|
||||
href="https://astrbot.app/use/knowledge-base.html" target="_blank" />
|
||||
</div>
|
||||
|
||||
<!-- 操作按钮栏 -->
|
||||
<div class="action-bar mb-6">
|
||||
<v-btn
|
||||
prepend-icon="mdi-plus"
|
||||
color="primary"
|
||||
variant="elevated"
|
||||
@click="showCreateDialog = true"
|
||||
>
|
||||
<v-btn prepend-icon="mdi-plus" color="primary" variant="elevated" @click="showCreateDialog = true">
|
||||
{{ t('list.create') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
prepend-icon="mdi-refresh"
|
||||
variant="tonal"
|
||||
@click="loadKnowledgeBases"
|
||||
:loading="loading"
|
||||
>
|
||||
<v-btn prepend-icon="mdi-refresh" variant="tonal" @click="loadKnowledgeBases" :loading="loading">
|
||||
{{ t('list.refresh') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
@@ -43,14 +27,8 @@
|
||||
</div>
|
||||
|
||||
<div v-else-if="kbList.length > 0" class="kb-grid">
|
||||
<v-card
|
||||
v-for="kb in kbList"
|
||||
:key="kb.kb_id"
|
||||
class="kb-card"
|
||||
elevation="2"
|
||||
hover
|
||||
@click="navigateToDetail(kb.kb_id)"
|
||||
>
|
||||
<v-card v-for="kb in kbList" :key="kb.kb_id" class="kb-card" elevation="2" hover
|
||||
@click="navigateToDetail(kb.kb_id)">
|
||||
<div class="kb-card-content">
|
||||
<div class="kb-emoji">{{ kb.emoji || '📚' }}</div>
|
||||
<h3 class="kb-name">{{ kb.kb_name }}</h3>
|
||||
@@ -68,20 +46,8 @@
|
||||
</div>
|
||||
|
||||
<div class="kb-actions">
|
||||
<v-btn
|
||||
icon="mdi-pencil"
|
||||
size="small"
|
||||
variant="text"
|
||||
color="info"
|
||||
@click.stop="editKB(kb)"
|
||||
/>
|
||||
<v-btn
|
||||
icon="mdi-delete"
|
||||
size="small"
|
||||
variant="text"
|
||||
color="error"
|
||||
@click.stop="confirmDelete(kb)"
|
||||
/>
|
||||
<v-btn icon="mdi-pencil" size="small" variant="text" color="info" @click.stop="editKB(kb)" />
|
||||
<v-btn icon="mdi-delete" size="small" variant="text" color="error" @click.stop="confirmDelete(kb)" />
|
||||
</div>
|
||||
</div>
|
||||
</v-card>
|
||||
@@ -91,14 +57,8 @@
|
||||
<div v-else class="empty-state">
|
||||
<v-icon size="100" color="grey-lighten-2">mdi-book-open-variant</v-icon>
|
||||
<h2 class="mt-4">{{ t('list.empty') }}</h2>
|
||||
<v-btn
|
||||
class="mt-6"
|
||||
prepend-icon="mdi-plus"
|
||||
color="primary"
|
||||
variant="elevated"
|
||||
size="large"
|
||||
@click="showCreateDialog = true"
|
||||
>
|
||||
<v-btn class="mt-6" prepend-icon="mdi-plus" color="primary" variant="elevated" size="large"
|
||||
@click="showCreateDialog = true">
|
||||
{{ t('list.create') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
@@ -125,35 +85,16 @@
|
||||
|
||||
<!-- 表单 -->
|
||||
<v-form ref="formRef" @submit.prevent="submitForm">
|
||||
<v-text-field
|
||||
v-model="formData.kb_name"
|
||||
:label="t('create.nameLabel')"
|
||||
:placeholder="t('create.namePlaceholder')"
|
||||
variant="outlined"
|
||||
:rules="[v => !!v || t('create.nameRequired')]"
|
||||
required
|
||||
class="mb-4"
|
||||
/>
|
||||
<v-text-field v-model="formData.kb_name" :label="t('create.nameLabel')"
|
||||
:placeholder="t('create.namePlaceholder')" variant="outlined"
|
||||
:rules="[v => !!v || t('create.nameRequired')]" required class="mb-4" />
|
||||
|
||||
<v-textarea
|
||||
v-model="formData.description"
|
||||
:label="t('create.descriptionLabel')"
|
||||
:placeholder="t('create.descriptionPlaceholder')"
|
||||
variant="outlined"
|
||||
rows="3"
|
||||
class="mb-4"
|
||||
/>
|
||||
<v-textarea v-model="formData.description" :label="t('create.descriptionLabel')"
|
||||
:placeholder="t('create.descriptionPlaceholder')" variant="outlined" rows="3" class="mb-4" />
|
||||
|
||||
<v-select
|
||||
v-model="formData.embedding_provider_id"
|
||||
:items="embeddingProviders"
|
||||
:item-title="item => item.embedding_model || item.id"
|
||||
:item-value="'id'"
|
||||
:label="t('create.embeddingModelLabel')"
|
||||
variant="outlined"
|
||||
class="mb-4"
|
||||
@update:model-value="handleEmbeddingProviderChange"
|
||||
>
|
||||
<v-select v-model="formData.embedding_provider_id" :items="embeddingProviders"
|
||||
:item-title="item => item.embedding_model || item.id" :item-value="'id'"
|
||||
:label="t('create.embeddingModelLabel')" variant="outlined" class="mb-4" :disabled="editingKB !== null">
|
||||
<template #item="{ props, item }">
|
||||
<v-list-item v-bind="props">
|
||||
<template #subtitle>
|
||||
@@ -166,20 +107,9 @@
|
||||
</template>
|
||||
</v-select>
|
||||
|
||||
<v-alert type="warning" variant="tonal" density="compact" class="mb-4" v-if="editingKB && showEmbeddingWarning">
|
||||
<strong>注意:</strong> 修改嵌入模型会导致现有的向量数据失效,建议重新上传文档。不同的嵌入模型生成的向量不兼容,可能导致检索结果不准确。
|
||||
</v-alert>
|
||||
|
||||
<v-select
|
||||
v-model="formData.rerank_provider_id"
|
||||
:items="rerankProviders"
|
||||
:item-title="item => item.rerank_model || item.id"
|
||||
:item-value="'id'"
|
||||
:label="t('create.rerankModelLabel')"
|
||||
variant="outlined"
|
||||
clearable
|
||||
class="mb-2"
|
||||
>
|
||||
<v-select v-model="formData.rerank_provider_id" :items="rerankProviders"
|
||||
:item-title="item => item.rerank_model || item.id" :item-value="'id'"
|
||||
:label="t('create.rerankModelLabel')" variant="outlined" clearable class="mb-2">
|
||||
<template #item="{ props, item }">
|
||||
<v-list-item v-bind="props">
|
||||
<template #subtitle>
|
||||
@@ -202,12 +132,7 @@
|
||||
<v-btn variant="text" @click="closeCreateDialog">
|
||||
{{ t('create.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="elevated"
|
||||
@click="submitForm"
|
||||
:loading="saving"
|
||||
>
|
||||
<v-btn color="primary" variant="elevated" @click="submitForm" :loading="saving">
|
||||
{{ editingKB ? t('edit.submit') : t('create.submit') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
@@ -223,12 +148,7 @@
|
||||
<div v-for="category in emojiCategories" :key="category.key" class="mb-4">
|
||||
<p class="text-subtitle-2 mb-2">{{ t(`emoji.categories.${category.key}`) }}</p>
|
||||
<div class="emoji-grid">
|
||||
<div
|
||||
v-for="emoji in category.emojis"
|
||||
:key="emoji"
|
||||
class="emoji-item"
|
||||
@click="selectEmoji(emoji)"
|
||||
>
|
||||
<div v-for="emoji in category.emojis" :key="emoji" class="emoji-item" @click="selectEmoji(emoji)">
|
||||
{{ emoji }}
|
||||
</div>
|
||||
</div>
|
||||
@@ -261,12 +181,7 @@
|
||||
<v-btn variant="text" @click="cancelDelete">
|
||||
{{ t('delete.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="error"
|
||||
variant="elevated"
|
||||
@click="deleteKB"
|
||||
:loading="deleting"
|
||||
>
|
||||
<v-btn color="error" variant="elevated" @click="deleteKB" :loading="deleting">
|
||||
{{ t('delete.confirm') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
@@ -278,38 +193,10 @@
|
||||
{{ snackbar.text }}
|
||||
</v-snackbar>
|
||||
|
||||
<!-- Embedding Provider 修改确认对话框 -->
|
||||
<v-dialog v-model="embeddingChangeDialog" max-width="500px" persistent>
|
||||
<v-card>
|
||||
<v-card-title class="bg-warning text-white">
|
||||
<v-icon class="mr-2">mdi-alert</v-icon>
|
||||
确认修改嵌入模型
|
||||
</v-card-title>
|
||||
<v-card-text class="pa-6">
|
||||
<v-alert type="warning" variant="tonal" class="mb-4">
|
||||
<strong>警告:</strong> 修改嵌入模型将导致以下影响:
|
||||
</v-alert>
|
||||
<ul class="text-body-2">
|
||||
<li>现有的向量数据将失效</li>
|
||||
<li>检索功能可能无法正常工作</li>
|
||||
<li>建议删除现有文档后重新上传</li>
|
||||
<li>不同嵌入模型生成的向量不兼容</li>
|
||||
</ul>
|
||||
<div class="mt-4 text-body-2">
|
||||
您确定要将嵌入模型从 <strong>{{ originalEmbeddingProvider }}</strong> 修改为 <strong>{{ pendingEmbeddingProvider }}</strong> 吗?
|
||||
</div>
|
||||
</v-card-text>
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer />
|
||||
<v-btn variant="text" @click="cancelEmbeddingChange">
|
||||
取消
|
||||
</v-btn>
|
||||
<v-btn color="warning" variant="elevated" @click="confirmEmbeddingChange">
|
||||
确认修改
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
<div class="position-absolute" style="bottom: 0px; right: 16px;">
|
||||
<small @click="router.push('/alkaid/knowledge-base')"><a style="text-decoration: underline; cursor: pointer;">切换到旧版知识库</a></small>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -439,39 +326,6 @@ const editKB = (kb: any) => {
|
||||
showCreateDialog.value = true
|
||||
}
|
||||
|
||||
// 处理 embedding provider 变更
|
||||
const handleEmbeddingProviderChange = (newValue: string | null) => {
|
||||
// 检测是否修改了embedding provider
|
||||
if (newValue && originalEmbeddingProvider.value && newValue !== originalEmbeddingProvider.value) {
|
||||
// 显示二次确认对话框
|
||||
showEmbeddingWarning.value = true
|
||||
pendingEmbeddingProvider.value = newValue
|
||||
embeddingChangeDialog.value = true
|
||||
} else {
|
||||
showEmbeddingWarning.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 确认修改 embedding provider
|
||||
const confirmEmbeddingChange = () => {
|
||||
if (pendingEmbeddingProvider.value) {
|
||||
formData.value.embedding_provider_id = pendingEmbeddingProvider.value
|
||||
// 更新原始值,这样下次比较时不会重复弹窗
|
||||
originalEmbeddingProvider.value = pendingEmbeddingProvider.value
|
||||
}
|
||||
embeddingChangeDialog.value = false
|
||||
showEmbeddingWarning.value = true
|
||||
}
|
||||
|
||||
// 取消修改 embedding provider
|
||||
const cancelEmbeddingChange = () => {
|
||||
// 恢复到原始值
|
||||
formData.value.embedding_provider_id = originalEmbeddingProvider.value
|
||||
embeddingChangeDialog.value = false
|
||||
showEmbeddingWarning.value = false
|
||||
pendingEmbeddingProvider.value = null
|
||||
}
|
||||
|
||||
// 确认删除
|
||||
const confirmDelete = (kb: any) => {
|
||||
deleteTarget.value = kb
|
||||
@@ -640,13 +494,7 @@ onMounted(() => {
|
||||
|
||||
.kb-emoji {
|
||||
font-size: 56px;
|
||||
margin-bottom: 16px;
|
||||
animation: float 3s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes float {
|
||||
0%, 100% { transform: translateY(0); }
|
||||
50% { transform: translateY(-10px); }
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.kb-name {
|
||||
|
||||
@@ -2,41 +2,35 @@
|
||||
<div class="documents-tab">
|
||||
<!-- 操作栏 -->
|
||||
<div class="action-bar mb-4">
|
||||
<v-btn
|
||||
prepend-icon="mdi-upload"
|
||||
color="primary"
|
||||
variant="elevated"
|
||||
@click="showUploadDialog = true"
|
||||
>
|
||||
<v-btn prepend-icon="mdi-upload" color="primary" variant="elevated" @click="showUploadDialog = true">
|
||||
{{ t('documents.upload') }}
|
||||
</v-btn>
|
||||
<v-text-field
|
||||
v-model="searchQuery"
|
||||
prepend-inner-icon="mdi-magnify"
|
||||
:placeholder="'搜索文档...'"
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
hide-details
|
||||
clearable
|
||||
style="max-width: 300px"
|
||||
/>
|
||||
<v-text-field v-model="searchQuery" prepend-inner-icon="mdi-magnify" :placeholder="'搜索文档...'" variant="outlined"
|
||||
density="compact" hide-details clearable style="max-width: 300px" />
|
||||
</div>
|
||||
|
||||
<!-- 文档列表 -->
|
||||
<v-card elevation="2">
|
||||
<v-data-table
|
||||
:headers="headers"
|
||||
:items="documents"
|
||||
:loading="loading"
|
||||
:search="searchQuery"
|
||||
:items-per-page="10"
|
||||
>
|
||||
<v-data-table :headers="headers" :items="documents" :loading="loading" :search="searchQuery" :items-per-page="10">
|
||||
<template #item.doc_name="{ item }">
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-icon :color="getFileColor(item.file_type)">
|
||||
<v-icon :color="getFileColor(item.file_type)" class="mr-2">
|
||||
{{ getFileIcon(item.file_type) }}
|
||||
</v-icon>
|
||||
<span class="font-weight-medium">{{ item.doc_name }}</span>
|
||||
<div class="flex-grow-1" style="padding: 4px 0px;">
|
||||
<span class="font-weight-medium">{{ item.doc_name }}</span>
|
||||
<!-- 上传进度 -->
|
||||
<div v-if="item.uploading" class="mt-1">
|
||||
<div class="text-caption text-medium-emphasis mb-1">
|
||||
{{ getStageText(item.uploadProgress?.stage || 'waiting') }}
|
||||
<span v-if="item.uploadProgress?.current">
|
||||
({{ item.uploadProgress.current }} / {{ item.uploadProgress.total }})
|
||||
</span>
|
||||
</div>
|
||||
<v-progress-linear :model-value="getUploadPercentage(item)" color="primary" height="4" rounded
|
||||
striped />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -49,20 +43,8 @@
|
||||
</template>
|
||||
|
||||
<template #item.actions="{ item }">
|
||||
<v-btn
|
||||
icon="mdi-eye"
|
||||
variant="text"
|
||||
size="small"
|
||||
color="info"
|
||||
@click="viewDocument(item)"
|
||||
/>
|
||||
<v-btn
|
||||
icon="mdi-delete"
|
||||
variant="text"
|
||||
size="small"
|
||||
color="error"
|
||||
@click="confirmDelete(item)"
|
||||
/>
|
||||
<v-btn icon="mdi-eye" variant="text" size="small" color="info" @click="viewDocument(item)" />
|
||||
<v-btn icon="mdi-delete" variant="text" size="small" color="error" @click="confirmDelete(item)" />
|
||||
</template>
|
||||
|
||||
<template #no-data>
|
||||
@@ -75,9 +57,9 @@
|
||||
</v-card>
|
||||
|
||||
<!-- 上传对话框 -->
|
||||
<v-dialog v-model="showUploadDialog" max-width="600px" persistent>
|
||||
<v-dialog v-model="showUploadDialog" max-width="600px" persistent @after-enter="initUploadSettings">
|
||||
<v-card>
|
||||
<v-card-title class="pa-4">
|
||||
<v-card-title class="pa-4 d-flex align-center">
|
||||
<span class="text-h5">{{ t('upload.title') }}</span>
|
||||
<v-spacer />
|
||||
<v-btn icon="mdi-close" variant="text" @click="closeUploadDialog" />
|
||||
@@ -87,84 +69,88 @@
|
||||
|
||||
<v-card-text class="pa-6">
|
||||
<!-- 文件选择 -->
|
||||
<div
|
||||
class="upload-dropzone"
|
||||
:class="{ 'dragover': isDragging }"
|
||||
@drop.prevent="handleDrop"
|
||||
@dragover.prevent="isDragging = true"
|
||||
@dragleave="isDragging = false"
|
||||
@click="$refs.fileInput.click()"
|
||||
>
|
||||
<div class="upload-dropzone" :class="{ 'dragover': isDragging }" @drop.prevent="handleDrop"
|
||||
@dragover.prevent="isDragging = true" @dragleave="isDragging = false" @click="fileInput?.click()">
|
||||
<v-icon size="64" color="primary">mdi-cloud-upload</v-icon>
|
||||
<p class="mt-4 text-h6">{{ t('upload.dropzone') }}</p>
|
||||
<p class="text-caption text-medium-emphasis mt-2">{{ t('upload.supportedFormats') }}</p>
|
||||
<p class="text-caption text-medium-emphasis">{{ t('upload.maxSize') }}</p>
|
||||
<input
|
||||
ref="fileInput"
|
||||
type="file"
|
||||
hidden
|
||||
accept=".txt,.md,.pdf"
|
||||
@change="handleFileSelect"
|
||||
/>
|
||||
<p class="text-caption text-medium-emphasis">最多可上传 10 个文件</p>
|
||||
<input ref="fileInput" type="file" multiple hidden accept=".txt,.md,.pdf" @change="handleFileSelect" />
|
||||
</div>
|
||||
|
||||
<div v-if="selectedFile" class="mt-4 pa-4 rounded bg-surface-variant">
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-icon>{{ getFileIcon(selectedFile.name) }}</v-icon>
|
||||
<div>
|
||||
<div class="font-weight-medium">{{ selectedFile.name }}</div>
|
||||
<div class="text-caption">{{ formatFileSize(selectedFile.size) }}</div>
|
||||
<div v-if="selectedFiles.length > 0" class="mt-4">
|
||||
<div class="d-flex align-center justify-space-between mb-2">
|
||||
<span class="text-subtitle-2">已选择 {{ selectedFiles.length }} 个文件</span>
|
||||
<v-btn variant="text" size="small" @click="selectedFiles = []">清空</v-btn>
|
||||
</div>
|
||||
<div class="files-list">
|
||||
<div v-for="(file, index) in selectedFiles" :key="index"
|
||||
class="file-item pa-3 mb-2 rounded bg-surface-variant">
|
||||
<div class="d-flex align-center justify-space-between">
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-icon>{{ getFileIcon(file.name) }}</v-icon>
|
||||
<div>
|
||||
<div class="font-weight-medium">{{ file.name }}</div>
|
||||
<div class="text-caption">{{ formatFileSize(file.size) }}</div>
|
||||
</div>
|
||||
</div>
|
||||
<v-btn icon="mdi-close" variant="text" size="small" @click="removeFile(index)" />
|
||||
</div>
|
||||
</div>
|
||||
<v-btn icon="mdi-close" variant="text" size="small" @click="selectedFile = null" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 分块设置 -->
|
||||
<v-expansion-panels class="mt-4">
|
||||
<v-expansion-panel>
|
||||
<v-expansion-panel-title>
|
||||
<v-icon start>mdi-cog</v-icon>
|
||||
{{ t('upload.chunkSettings') }}
|
||||
</v-expansion-panel-title>
|
||||
<v-expansion-panel-text>
|
||||
<v-text-field
|
||||
v-model.number="uploadSettings.chunk_size"
|
||||
:label="t('upload.chunkSize')"
|
||||
:hint="t('upload.chunkSizeHint')"
|
||||
type="number"
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
class="mb-2"
|
||||
/>
|
||||
<v-text-field
|
||||
v-model.number="uploadSettings.chunk_overlap"
|
||||
:label="t('upload.chunkOverlap')"
|
||||
:hint="t('upload.chunkOverlapHint')"
|
||||
type="number"
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
/>
|
||||
</v-expansion-panel-text>
|
||||
</v-expansion-panel>
|
||||
</v-expansion-panels>
|
||||
<div class="mt-6">
|
||||
<div class="d-flex align-center mb-4">
|
||||
<h3 class="text-h6">{{ t('upload.chunkSettings') }}</h3>
|
||||
</div>
|
||||
<v-row>
|
||||
<v-col cols="12" sm="6">
|
||||
<v-text-field v-model.number="uploadSettings.chunk_size" :label="t('upload.chunkSize')"
|
||||
:hint="t('upload.chunkSizeHint')" persistent-hint type="number" variant="outlined" density="compact"
|
||||
:placeholder="props.kb?.chunk_size?.toString() || '512'" />
|
||||
</v-col>
|
||||
<v-col cols="12" sm="6">
|
||||
<v-text-field v-model.number="uploadSettings.chunk_overlap" :label="t('upload.chunkOverlap')"
|
||||
:hint="t('upload.chunkOverlapHint')" persistent-hint type="number" variant="outlined"
|
||||
density="compact" :placeholder="props.kb?.chunk_overlap?.toString() || '50'" />
|
||||
</v-col>
|
||||
</v-row>
|
||||
</div>
|
||||
|
||||
<div class="mt-2">
|
||||
<h3 class="text-h6 mb-4">{{ t('upload.batchSettings') }}</h3>
|
||||
<v-row>
|
||||
<v-col cols="12" sm="4">
|
||||
<v-text-field v-model.number="uploadSettings.batch_size" :label="t('upload.batchSize')" hint="每批处理的文本数量"
|
||||
persistent-hint type="number" variant="outlined" density="compact" />
|
||||
</v-col>
|
||||
<v-col cols="12" sm="4">
|
||||
<v-text-field v-model.number="uploadSettings.tasks_limit" :label="t('upload.tasksLimit')"
|
||||
hint="并发任务数量限制" persistent-hint type="number" variant="outlined" density="compact" />
|
||||
</v-col>
|
||||
<v-col cols="12" sm="4">
|
||||
<v-text-field v-model.number="uploadSettings.max_retries" :label="t('upload.maxRetries')"
|
||||
hint="失败时的最大重试次数" persistent-hint type="number" variant="outlined" density="compact" />
|
||||
</v-col>
|
||||
</v-row>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
</v-card-text>
|
||||
|
||||
<v-divider />
|
||||
|
||||
<v-card-actions class="pa-4">
|
||||
<v-spacer />
|
||||
<v-btn variant="text" @click="closeUploadDialog">
|
||||
<v-btn variant="text" @click="closeUploadDialog" :disabled="uploading">
|
||||
{{ t('upload.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="elevated"
|
||||
@click="uploadDocument"
|
||||
:loading="uploading"
|
||||
:disabled="!selectedFile"
|
||||
>
|
||||
<v-btn color="primary" variant="elevated" @click="uploadDocument" :loading="uploading"
|
||||
:disabled="selectedFiles.length === 0">
|
||||
{{ t('upload.submit') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
@@ -201,7 +187,7 @@
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import axios from 'axios'
|
||||
import { useModuleI18n } from '@/i18n/composables'
|
||||
@@ -224,9 +210,14 @@ const documents = ref<any[]>([])
|
||||
const searchQuery = ref('')
|
||||
const showUploadDialog = ref(false)
|
||||
const showDeleteDialog = ref(false)
|
||||
const selectedFile = ref<File | null>(null)
|
||||
const selectedFiles = ref<File[]>([])
|
||||
const deleteTarget = ref<any>(null)
|
||||
const isDragging = ref(false)
|
||||
const fileInput = ref<HTMLInputElement | null>(null)
|
||||
|
||||
// 上传进度 - 用于轮询多个任务
|
||||
const uploadingTasks = ref<Map<string, any>>(new Map())
|
||||
const progressPollingInterval = ref<number | null>(null)
|
||||
|
||||
const snackbar = ref({
|
||||
show: false,
|
||||
@@ -242,10 +233,24 @@ const showSnackbar = (text: string, color: string = 'success') => {
|
||||
|
||||
// 上传设置
|
||||
const uploadSettings = ref({
|
||||
chunk_size: null,
|
||||
chunk_overlap: null
|
||||
chunk_size: null as number | null,
|
||||
chunk_overlap: null as number | null,
|
||||
batch_size: 32,
|
||||
tasks_limit: 3,
|
||||
max_retries: 3
|
||||
})
|
||||
|
||||
// 初始化上传设置
|
||||
const initUploadSettings = () => {
|
||||
uploadSettings.value = {
|
||||
chunk_size: props.kb?.chunk_size || null,
|
||||
chunk_overlap: props.kb?.chunk_overlap || null,
|
||||
batch_size: 32,
|
||||
tasks_limit: 3,
|
||||
max_retries: 3
|
||||
}
|
||||
}
|
||||
|
||||
// 表格列
|
||||
const headers = [
|
||||
{ title: t('documents.name'), key: 'doc_name', sortable: true },
|
||||
@@ -253,7 +258,7 @@ const headers = [
|
||||
{ title: t('documents.size'), key: 'file_size', sortable: true },
|
||||
{ title: t('documents.chunks'), key: 'chunk_count', sortable: true },
|
||||
{ title: t('documents.createdAt'), key: 'created_at', sortable: true },
|
||||
{ title: t('documents.actions'), key: 'actions', sortable: false, align: 'end' }
|
||||
{ title: t('documents.actions'), key: 'actions', sortable: false, align: 'end' as const }
|
||||
]
|
||||
|
||||
// 加载文档列表
|
||||
@@ -277,30 +282,53 @@ const loadDocuments = async () => {
|
||||
// 文件选择
|
||||
const handleFileSelect = (event: Event) => {
|
||||
const target = event.target as HTMLInputElement
|
||||
if (target.files && target.files[0]) {
|
||||
selectedFile.value = target.files[0]
|
||||
if (target.files && target.files.length > 0) {
|
||||
const newFiles = Array.from(target.files)
|
||||
addFiles(newFiles)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加文件(检查数量限制)
|
||||
const addFiles = (files: File[]) => {
|
||||
const totalFiles = selectedFiles.value.length + files.length
|
||||
if (totalFiles > 10) {
|
||||
showSnackbar('最多只能选择 10 个文件', 'warning')
|
||||
return
|
||||
}
|
||||
selectedFiles.value.push(...files)
|
||||
}
|
||||
|
||||
// 移除文件
|
||||
const removeFile = (index: number) => {
|
||||
selectedFiles.value.splice(index, 1)
|
||||
}
|
||||
|
||||
// 拖放上传
|
||||
const handleDrop = (event: DragEvent) => {
|
||||
isDragging.value = false
|
||||
if (event.dataTransfer?.files && event.dataTransfer.files[0]) {
|
||||
selectedFile.value = event.dataTransfer.files[0]
|
||||
if (event.dataTransfer?.files && event.dataTransfer.files.length > 0) {
|
||||
const newFiles = Array.from(event.dataTransfer.files)
|
||||
addFiles(newFiles)
|
||||
}
|
||||
}
|
||||
|
||||
// 上传文档
|
||||
const uploadDocument = async () => {
|
||||
if (!selectedFile.value) {
|
||||
if (selectedFiles.value.length === 0) {
|
||||
showSnackbar(t('upload.fileRequired'), 'warning')
|
||||
return
|
||||
}
|
||||
|
||||
uploading.value = true
|
||||
|
||||
try {
|
||||
const formData = new FormData()
|
||||
formData.append('file', selectedFile.value)
|
||||
|
||||
// 添加所有文件
|
||||
selectedFiles.value.forEach((file, index) => {
|
||||
formData.append(`file${index}`, file)
|
||||
})
|
||||
|
||||
formData.append('kb_id', props.kbId)
|
||||
if (uploadSettings.value.chunk_size) {
|
||||
formData.append('chunk_size', uploadSettings.value.chunk_size.toString())
|
||||
@@ -308,16 +336,47 @@ const uploadDocument = async () => {
|
||||
if (uploadSettings.value.chunk_overlap) {
|
||||
formData.append('chunk_overlap', uploadSettings.value.chunk_overlap.toString())
|
||||
}
|
||||
formData.append('batch_size', uploadSettings.value.batch_size.toString())
|
||||
formData.append('tasks_limit', uploadSettings.value.tasks_limit.toString())
|
||||
formData.append('max_retries', uploadSettings.value.max_retries.toString())
|
||||
|
||||
const response = await axios.post('/api/kb/document/upload', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' }
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
showSnackbar(t('documents.uploadSuccess'))
|
||||
const result = response.data.data
|
||||
const taskId = result.task_id
|
||||
|
||||
showSnackbar(`正在后台上传 ${result.file_count} 个文件...`, 'info')
|
||||
|
||||
// 为每个文件添加占位条目到文档列表
|
||||
const uploadingDocs = selectedFiles.value.map((file, index) => ({
|
||||
doc_id: `uploading_${taskId}_${index}`,
|
||||
doc_name: file.name,
|
||||
file_type: file.name.split('.').pop() || '',
|
||||
file_size: file.size,
|
||||
chunk_count: 0,
|
||||
created_at: new Date().toISOString(),
|
||||
uploading: true,
|
||||
taskId: taskId,
|
||||
uploadProgress: {
|
||||
stage: 'waiting',
|
||||
current: 0,
|
||||
total: 100
|
||||
}
|
||||
}))
|
||||
|
||||
// 添加到文档列表顶部
|
||||
documents.value = [...uploadingDocs, ...documents.value]
|
||||
|
||||
// 关闭对话框
|
||||
closeUploadDialog()
|
||||
await loadDocuments()
|
||||
emit('refresh')
|
||||
|
||||
// 开始轮询进度
|
||||
if (taskId) {
|
||||
startProgressPolling(taskId)
|
||||
}
|
||||
} else {
|
||||
showSnackbar(response.data.message || t('documents.uploadFailed'), 'error')
|
||||
}
|
||||
@@ -329,11 +388,118 @@ const uploadDocument = async () => {
|
||||
}
|
||||
}
|
||||
|
||||
// 开始轮询进度
|
||||
const startProgressPolling = (taskId: string) => {
|
||||
// 如果已经在轮询,先停止
|
||||
if (progressPollingInterval.value) {
|
||||
stopProgressPolling()
|
||||
}
|
||||
|
||||
progressPollingInterval.value = window.setInterval(async () => {
|
||||
try {
|
||||
const response = await axios.get('/api/kb/document/upload/progress', {
|
||||
params: { task_id: taskId }
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
const data = response.data.data
|
||||
const status = data.status
|
||||
|
||||
if (status === 'processing' && data.progress) {
|
||||
// 更新进度
|
||||
const progress = data.progress
|
||||
const fileIndex = progress.file_index || 0
|
||||
|
||||
// 更新对应文件的进度
|
||||
documents.value = documents.value.map(doc => {
|
||||
if (doc.taskId === taskId) {
|
||||
const docIndex = parseInt(doc.doc_id.split('_').pop() || '0')
|
||||
if (docIndex === fileIndex) {
|
||||
return {
|
||||
...doc,
|
||||
uploadProgress: {
|
||||
stage: progress.stage || 'waiting',
|
||||
current: progress.current || 0,
|
||||
total: progress.total || 100
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return doc
|
||||
})
|
||||
} else if (status === 'completed') {
|
||||
// 任务完成
|
||||
stopProgressPolling()
|
||||
|
||||
const result = data.result
|
||||
const successCount = result?.success_count || 0
|
||||
const failedCount = result?.failed_count || 0
|
||||
|
||||
// 移除上传中的占位文档
|
||||
documents.value = documents.value.filter(doc => doc.taskId !== taskId)
|
||||
|
||||
// 重新加载文档列表
|
||||
await loadDocuments()
|
||||
emit('refresh')
|
||||
|
||||
if (failedCount === 0) {
|
||||
showSnackbar(`成功上传 ${successCount} 个文档`)
|
||||
} else {
|
||||
showSnackbar(`上传完成: ${successCount} 个成功, ${failedCount} 个失败`, 'warning')
|
||||
}
|
||||
} else if (status === 'failed') {
|
||||
// 任务失败
|
||||
stopProgressPolling()
|
||||
|
||||
// 移除上传中的占位文档
|
||||
documents.value = documents.value.filter(doc => doc.taskId !== taskId)
|
||||
|
||||
showSnackbar(`上传失败: ${data.error || '未知错误'}`, 'error')
|
||||
}
|
||||
} else {
|
||||
// 任务不存在,停止轮询
|
||||
stopProgressPolling()
|
||||
documents.value = documents.value.filter(doc => doc.taskId !== taskId)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch progress:', error)
|
||||
// 不立即停止,允许重试
|
||||
}
|
||||
}, 500) // 每500ms轮询一次
|
||||
}
|
||||
|
||||
// 停止轮询进度
|
||||
const stopProgressPolling = () => {
|
||||
if (progressPollingInterval.value) {
|
||||
clearInterval(progressPollingInterval.value)
|
||||
progressPollingInterval.value = null
|
||||
}
|
||||
}
|
||||
|
||||
// 获取上传百分比
|
||||
const getUploadPercentage = (item: any) => {
|
||||
if (!item.uploadProgress) return 0
|
||||
const { current, total } = item.uploadProgress
|
||||
if (!total || total === 0) return 0
|
||||
return (current / total) * 100
|
||||
}
|
||||
|
||||
// 获取阶段文本
|
||||
const getStageText = (stage: string) => {
|
||||
const stageMap: Record<string, string> = {
|
||||
'waiting': '等待中...',
|
||||
'parsing': '解析文档...',
|
||||
'chunking': '文本分块...',
|
||||
'embedding': '生成向量...'
|
||||
}
|
||||
return stageMap[stage] || stage
|
||||
}
|
||||
|
||||
// 关闭上传对话框
|
||||
const closeUploadDialog = () => {
|
||||
showUploadDialog.value = false
|
||||
selectedFile.value = null
|
||||
uploadSettings.value = { chunk_size: null, chunk_overlap: null }
|
||||
selectedFiles.value = []
|
||||
initUploadSettings()
|
||||
}
|
||||
|
||||
// 查看文档
|
||||
@@ -357,7 +523,8 @@ const deleteDocument = async () => {
|
||||
deleting.value = true
|
||||
try {
|
||||
const response = await axios.post('/api/kb/document/delete', {
|
||||
doc_id: deleteTarget.value.doc_id
|
||||
doc_id: deleteTarget.value.doc_id,
|
||||
kb_id: props.kbId
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
@@ -419,6 +586,10 @@ const formatDate = (dateStr: string) => {
|
||||
onMounted(() => {
|
||||
loadDocuments()
|
||||
})
|
||||
|
||||
onUnmounted(() => {
|
||||
stopProgressPolling()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
@@ -427,8 +598,13 @@ onMounted(() => {
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; }
|
||||
to { opacity: 1; }
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.action-bar {
|
||||
@@ -456,13 +632,26 @@ onMounted(() => {
|
||||
transform: scale(1.02);
|
||||
}
|
||||
|
||||
.files-list {
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.file-item {
|
||||
transition: all 0.2s ease;
|
||||
}
|
||||
|
||||
.file-item:hover {
|
||||
background: rgba(var(--v-theme-surface-variant), 0.8) !important;
|
||||
}
|
||||
|
||||
@media (max-width: 768px) {
|
||||
.action-bar {
|
||||
flex-direction: column;
|
||||
align-items: stretch;
|
||||
}
|
||||
|
||||
.action-bar > * {
|
||||
.action-bar>* {
|
||||
width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,66 +1,57 @@
|
||||
<template>
|
||||
<div class="retrieval-tab">
|
||||
<v-card elevation="2">
|
||||
<v-card-title class="pa-4">{{ t('retrieval.title') }}</v-card-title>
|
||||
<v-card-subtitle class="px-4 pb-4">
|
||||
<v-card-title class="pa-4 pb-0">{{ t('retrieval.title') }}</v-card-title>
|
||||
<v-card-subtitle class="pb-4 pt-2">
|
||||
{{ t('retrieval.subtitle') }}
|
||||
</v-card-subtitle>
|
||||
|
||||
<v-divider />
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary" height="2" />
|
||||
|
||||
<v-card-text class="pa-6">
|
||||
<!-- 查询输入区域 -->
|
||||
<v-row>
|
||||
<v-row class="mb-4">
|
||||
<v-col cols="12" md="8">
|
||||
<v-textarea
|
||||
v-model="query"
|
||||
:label="t('retrieval.query')"
|
||||
:placeholder="t('retrieval.queryPlaceholder')"
|
||||
variant="outlined"
|
||||
rows="3"
|
||||
auto-grow
|
||||
clearable
|
||||
/>
|
||||
<v-textarea v-model="query" :label="t('retrieval.query')" :placeholder="t('retrieval.queryPlaceholder')"
|
||||
variant="outlined" rows="3" auto-grow clearable />
|
||||
|
||||
<!-- debug -->
|
||||
<div v-if="debugVisualize" class="mt-2">
|
||||
<v-card variant="outlined">
|
||||
<v-img :src="`data:image/png;base64,${debugVisualize}`" :alt="t('retrieval.tsneVisualization')" cover>
|
||||
<template v-slot:placeholder>
|
||||
<div class="d-flex align-center justify-center fill-height">
|
||||
<v-progress-circular indeterminate color="primary" />
|
||||
</div>
|
||||
</template>
|
||||
</v-img>
|
||||
</v-card>
|
||||
</div>
|
||||
</v-col>
|
||||
<v-col cols="12" md="4">
|
||||
<v-card variant="outlined" class="pa-4">
|
||||
<h4 class="text-subtitle-2 mb-3">{{ t('retrieval.settings') }}</h4>
|
||||
|
||||
<v-text-field
|
||||
v-model.number="topK"
|
||||
:label="t('retrieval.topK')"
|
||||
:hint="t('retrieval.topKHint')"
|
||||
type="number"
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
persistent-hint
|
||||
class="mb-3"
|
||||
/>
|
||||
<v-text-field v-model.number="topK" :label="t('retrieval.topK')" :hint="t('retrieval.topKHint')"
|
||||
type="number" variant="outlined" density="compact" persistent-hint class="mb-3" />
|
||||
|
||||
<v-checkbox
|
||||
v-model="enableRerank"
|
||||
:label="t('retrieval.enableRerank')"
|
||||
:hint="t('retrieval.enableRerankHint')"
|
||||
color="primary"
|
||||
density="compact"
|
||||
persistent-hint
|
||||
/>
|
||||
<v-alert v-if="enableRerank" type="info" variant="tonal" class="mt-2" density="compact">
|
||||
如果没有配置重排序模型提供商,将跳过重排序步骤
|
||||
</v-alert>
|
||||
<v-switch v-model="debugMode" :label="t('retrieval.debugMode')" color="primary" density="compact"
|
||||
hide-details>
|
||||
<template v-slot:label>
|
||||
<span class="text-caption">
|
||||
<v-icon size="small" class="mr-1">mdi-bug</v-icon>
|
||||
Debug (t-SNE)
|
||||
</span>
|
||||
</template>
|
||||
</v-switch>
|
||||
</v-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<div class="d-flex justify-end mb-4">
|
||||
<v-btn
|
||||
prepend-icon="mdi-magnify"
|
||||
color="primary"
|
||||
variant="elevated"
|
||||
@click="performRetrieval"
|
||||
:loading="loading"
|
||||
:disabled="!query || query.trim() === ''"
|
||||
>
|
||||
<v-btn prepend-icon="mdi-magnify" color="primary" variant="elevated" @click="performRetrieval"
|
||||
:loading="loading" :disabled="!query || query.trim() === ''">
|
||||
{{ loading ? t('retrieval.searching') : t('retrieval.search') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
@@ -71,28 +62,33 @@
|
||||
|
||||
<div class="d-flex align-center mb-4">
|
||||
<h3 class="text-h6">{{ t('retrieval.results') }}</h3>
|
||||
<v-chip class="ml-3" color="primary" variant="tonal">
|
||||
<v-chip class="ml-3" color="primary" variant="tonal" size="small">
|
||||
{{ results.length }} {{ t('retrieval.results') }}
|
||||
</v-chip>
|
||||
</div>
|
||||
|
||||
<!-- 结果列表 -->
|
||||
<div v-if="results.length > 0" class="results-list">
|
||||
<v-card
|
||||
v-for="(result, index) in results"
|
||||
:key="result.chunk_id"
|
||||
variant="outlined"
|
||||
class="mb-4"
|
||||
>
|
||||
<v-card-title class="d-flex align-center pa-4">
|
||||
<v-chip size="small" color="primary" class="mr-2">
|
||||
<v-card v-for="(result, index) in results" :key="result.chunk_id" variant="outlined" class="mb-4">
|
||||
<v-card-title class="d-flex align-center pa-2">
|
||||
<v-chip size="x-small" color="primary" class="mr-2">
|
||||
#{{ index + 1 }}
|
||||
</v-chip>
|
||||
<span class="text-subtitle-1">
|
||||
{{ t('retrieval.chunk', { index: result.chunk_index }) }}
|
||||
</span>
|
||||
<div class="ml-4">
|
||||
<v-chip size="x-small" variant="tonal" class="mr-2">
|
||||
<v-icon start size="small">mdi-file-document</v-icon>
|
||||
{{ result.doc_name }}
|
||||
</v-chip>
|
||||
<v-chip size="x-small" variant="tonal">
|
||||
<v-icon start size="small">mdi-text</v-icon>
|
||||
{{ t('retrieval.charCount', { count: result.char_count }) }}
|
||||
</v-chip>
|
||||
</div>
|
||||
<v-spacer />
|
||||
<v-chip size="small" :color="getScoreColor(result.score)">
|
||||
<v-chip size="x-small" :color="getScoreColor(result.score)">
|
||||
{{ t('retrieval.score') }}: {{ result.score.toFixed(4) }}
|
||||
</v-chip>
|
||||
</v-card-title>
|
||||
@@ -100,17 +96,6 @@
|
||||
<v-divider />
|
||||
|
||||
<v-card-text class="pa-4">
|
||||
<div class="mb-3">
|
||||
<v-chip size="small" variant="tonal" class="mr-2">
|
||||
<v-icon start size="small">mdi-file-document</v-icon>
|
||||
{{ result.doc_name }}
|
||||
</v-chip>
|
||||
<v-chip size="small" variant="tonal">
|
||||
<v-icon start size="small">mdi-text</v-icon>
|
||||
{{ t('retrieval.charCount', { count: result.char_count }) }}
|
||||
</v-chip>
|
||||
</div>
|
||||
|
||||
<div class="content-box">
|
||||
{{ result.content }}
|
||||
</div>
|
||||
@@ -143,16 +128,18 @@ import { useModuleI18n } from '@/i18n/composables'
|
||||
const { tm: t } = useModuleI18n('features/knowledge-base/detail')
|
||||
|
||||
const props = defineProps<{
|
||||
kbId: string
|
||||
kbId: string,
|
||||
kbName: string,
|
||||
}>()
|
||||
|
||||
// 状态
|
||||
const loading = ref(false)
|
||||
const query = ref('')
|
||||
const topK = ref(5)
|
||||
const enableRerank = ref(false)
|
||||
const debugMode = ref(false)
|
||||
const results = ref<any[]>([])
|
||||
const hasSearched = ref(false)
|
||||
const debugVisualize = ref<string | null>(null)
|
||||
|
||||
const snackbar = ref({
|
||||
show: false,
|
||||
@@ -175,18 +162,24 @@ const performRetrieval = async () => {
|
||||
|
||||
loading.value = true
|
||||
hasSearched.value = false
|
||||
debugVisualize.value = null
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/kb/retrieve', {
|
||||
query: query.value,
|
||||
kb_ids: [props.kbId],
|
||||
kb_names: [props.kbName],
|
||||
top_k: topK.value,
|
||||
enable_rerank: enableRerank.value
|
||||
debug: debugMode.value
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
results.value = response.data.data.results || []
|
||||
hasSearched.value = true
|
||||
|
||||
if (debugMode.value && response.data.data.visualization) {
|
||||
debugVisualize.value = response.data.data.visualization
|
||||
}
|
||||
|
||||
showSnackbar(t('retrieval.searchSuccess', { count: results.value.length }))
|
||||
} else {
|
||||
showSnackbar(response.data.message || t('retrieval.searchFailed'), 'error')
|
||||
@@ -214,8 +207,13 @@ const getScoreColor = (score: number) => {
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; }
|
||||
to { opacity: 1; }
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
.results-section {
|
||||
@@ -227,6 +225,7 @@ const getScoreColor = (score: number) => {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
@@ -234,7 +233,7 @@ const getScoreColor = (score: number) => {
|
||||
}
|
||||
|
||||
.content-box {
|
||||
background: rgba(var(--v-theme-surface-variant), 0.3);
|
||||
background: rgba(var(--v-theme-surface-variant), 0.1);
|
||||
border-radius: 8px;
|
||||
padding: 16px;
|
||||
white-space: pre-wrap;
|
||||
@@ -242,5 +241,8 @@ const getScoreColor = (score: number) => {
|
||||
font-family: 'Consolas', 'Monaco', 'Courier New', monospace;
|
||||
font-size: 0.9rem;
|
||||
line-height: 1.6;
|
||||
height: 120px;
|
||||
overflow-y: auto;
|
||||
font-size: 13px;
|
||||
}
|
||||
</style>
|
||||
|
||||
@@ -1,172 +0,0 @@
|
||||
<template>
|
||||
<div class="sessions-tab">
|
||||
<v-card elevation="2">
|
||||
<v-card-title class="d-flex align-center pa-4">
|
||||
<span>{{ t('sessions.title') }}</span>
|
||||
<v-spacer />
|
||||
<v-btn
|
||||
prepend-icon="mdi-refresh"
|
||||
variant="tonal"
|
||||
size="small"
|
||||
@click="loadSessions"
|
||||
:loading="loading"
|
||||
>
|
||||
{{ t('sessions.refresh') }}
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
|
||||
<v-card-subtitle class="px-4 pb-4">
|
||||
{{ t('sessions.subtitle') }}
|
||||
</v-card-subtitle>
|
||||
|
||||
<v-divider />
|
||||
|
||||
<v-card-text class="pa-0">
|
||||
<v-data-table
|
||||
:headers="headers"
|
||||
:items="sessions"
|
||||
:loading="loading"
|
||||
>
|
||||
<template #item.scope="{ item }">
|
||||
<v-chip :color="item.scope === 'session' ? 'primary' : 'secondary'" size="small" variant="tonal">
|
||||
{{ item.scope === 'session' ? t('sessions.scopeSession') : t('sessions.scopePlatform') }}
|
||||
</v-chip>
|
||||
</template>
|
||||
|
||||
<template #item.enable_rerank="{ item }">
|
||||
<v-icon :color="item.enable_rerank ? 'success' : 'grey'">
|
||||
{{ item.enable_rerank ? 'mdi-check-circle' : 'mdi-close-circle' }}
|
||||
</v-icon>
|
||||
</template>
|
||||
|
||||
<template #item.actions="{ item }">
|
||||
<v-btn
|
||||
icon="mdi-open-in-new"
|
||||
variant="text"
|
||||
size="small"
|
||||
color="primary"
|
||||
@click="goToSessionManagement(item)"
|
||||
:title="t('sessions.viewInSessionManagement')"
|
||||
/>
|
||||
</template>
|
||||
|
||||
<template #no-data>
|
||||
<div class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-2">mdi-account-multiple-outline</v-icon>
|
||||
<p class="mt-4 text-medium-emphasis">{{ t('sessions.empty') }}</p>
|
||||
<v-btn
|
||||
class="mt-4"
|
||||
prepend-icon="mdi-cog"
|
||||
variant="tonal"
|
||||
color="primary"
|
||||
@click="goToSessionManagement()"
|
||||
>
|
||||
{{ t('sessions.goToSessionManagement') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</template>
|
||||
</v-data-table>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
|
||||
<!-- 消息提示 -->
|
||||
<v-snackbar v-model="snackbar.show" :color="snackbar.color">
|
||||
{{ snackbar.text }}
|
||||
</v-snackbar>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted } from 'vue'
|
||||
import { useRouter } from 'vue-router'
|
||||
import axios from 'axios'
|
||||
import { useModuleI18n } from '@/i18n/composables'
|
||||
|
||||
const { tm: t } = useModuleI18n('features/knowledge-base/detail')
|
||||
const router = useRouter()
|
||||
|
||||
const props = defineProps<{
|
||||
kbId: string
|
||||
}>()
|
||||
|
||||
// 状态
|
||||
const loading = ref(false)
|
||||
const sessions = ref<any[]>([])
|
||||
|
||||
const snackbar = ref({
|
||||
show: false,
|
||||
text: '',
|
||||
color: 'success'
|
||||
})
|
||||
|
||||
const showSnackbar = (text: string, color: string = 'success') => {
|
||||
snackbar.value.text = text
|
||||
snackbar.value.color = color
|
||||
snackbar.value.show = true
|
||||
}
|
||||
|
||||
// 表格列
|
||||
const headers = [
|
||||
{ title: t('sessions.scope'), key: 'scope' },
|
||||
{ title: t('sessions.scopeId'), key: 'scope_id' },
|
||||
{ title: t('sessions.topK'), key: 'top_k' },
|
||||
{ title: t('sessions.enableRerank'), key: 'enable_rerank' },
|
||||
{ title: t('sessions.actions'), key: 'actions', sortable: false, align: 'end' }
|
||||
]
|
||||
|
||||
// 加载使用该知识库的会话
|
||||
const loadSessions = async () => {
|
||||
loading.value = true
|
||||
|
||||
|
||||
try {
|
||||
const url = '/api/kb/session/config/list_by_kb'
|
||||
const params = { kb_id: props.kbId }
|
||||
|
||||
|
||||
const response = await axios.get(url, { params })
|
||||
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
sessions.value = response.data.data.sessions
|
||||
|
||||
} else {
|
||||
console.error('[SessionsTab] API返回错误:', response.data.message)
|
||||
showSnackbar(response.data.message || t('sessions.loadFailed'), 'error')
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[SessionsTab] 请求失败:', error)
|
||||
if (error.response) {
|
||||
console.error('[SessionsTab] 错误响应状态:', error.response.status)
|
||||
console.error('[SessionsTab] 错误响应数据:', error.response.data)
|
||||
} else if (error.request) {
|
||||
console.error('[SessionsTab] 请求未收到响应:', error.request)
|
||||
} else {
|
||||
console.error('[SessionsTab] 请求配置错误:', error.message)
|
||||
}
|
||||
showSnackbar(t('sessions.loadFailed'), 'error')
|
||||
} finally {
|
||||
loading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 跳转到会话管理页面
|
||||
const goToSessionManagement = (session?: any) => {
|
||||
router.push({ name: 'SessionManagement' })
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadSessions()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.sessions-tab {
|
||||
animation: fadeIn 0.3s ease;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; }
|
||||
to { opacity: 1; }
|
||||
}
|
||||
</style>
|
||||
@@ -34,7 +34,7 @@
|
||||
<h3 class="text-h6 mb-4 mt-6">{{ t('settings.retrieval') }}</h3>
|
||||
|
||||
<v-row>
|
||||
<v-col cols="12" md="4">
|
||||
<v-col cols="12" md="6">
|
||||
<v-text-field
|
||||
v-model.number="formData.top_k_dense"
|
||||
:label="t('settings.topKDense')"
|
||||
@@ -43,7 +43,7 @@
|
||||
density="comfortable"
|
||||
/>
|
||||
</v-col>
|
||||
<v-col cols="12" md="4">
|
||||
<v-col cols="12" md="6">
|
||||
<v-text-field
|
||||
v-model.number="formData.top_k_sparse"
|
||||
:label="t('settings.topKSparse')"
|
||||
@@ -52,7 +52,7 @@
|
||||
density="comfortable"
|
||||
/>
|
||||
</v-col>
|
||||
<v-col cols="12" md="4">
|
||||
<!-- <v-col cols="12" md="4">
|
||||
<v-text-field
|
||||
v-model.number="formData.top_m_final"
|
||||
:label="t('settings.topMFinal')"
|
||||
@@ -60,23 +60,7 @@
|
||||
variant="outlined"
|
||||
density="comfortable"
|
||||
/>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<v-row>
|
||||
<v-col cols="12">
|
||||
<v-checkbox
|
||||
v-model="formData.enable_rerank"
|
||||
:label="t('settings.enableRerank')"
|
||||
:hint="rerankProviders.length === 0 ? '未检测到可用的重排序模型提供商' : '使用重排序模型提高检索质量'"
|
||||
:disabled="rerankProviders.length === 0"
|
||||
color="primary"
|
||||
persistent-hint
|
||||
/>
|
||||
<v-alert v-if="formData.enable_rerank && rerankProviders.length === 0" type="warning" variant="tonal" class="mt-2" density="compact">
|
||||
当前没有可用的重排序模型提供商,请先在提供商管理中添加支持 rerank 的模型
|
||||
</v-alert>
|
||||
</v-col>
|
||||
</v-col> -->
|
||||
</v-row>
|
||||
|
||||
<!-- 模型设置 -->
|
||||
@@ -93,6 +77,7 @@
|
||||
variant="outlined"
|
||||
density="comfortable"
|
||||
@update:model-value="handleEmbeddingProviderChange"
|
||||
:disabled="true"
|
||||
/>
|
||||
</v-col>
|
||||
<v-col cols="12" md="6">
|
||||
@@ -216,8 +201,6 @@ const formData = ref({
|
||||
chunk_overlap: 50,
|
||||
top_k_dense: 50,
|
||||
top_k_sparse: 50,
|
||||
top_m_final: 5,
|
||||
enable_rerank: false,
|
||||
embedding_provider_id: '',
|
||||
rerank_provider_id: ''
|
||||
})
|
||||
@@ -230,8 +213,7 @@ watch(() => props.kb, (kb) => {
|
||||
chunk_overlap: kb.chunk_overlap || 50,
|
||||
top_k_dense: kb.top_k_dense || 50,
|
||||
top_k_sparse: kb.top_k_sparse || 50,
|
||||
top_m_final: kb.top_m_final || 5,
|
||||
enable_rerank: kb.enable_rerank === true,
|
||||
// top_m_final: kb.top_m_final || 5,
|
||||
embedding_provider_id: kb.embedding_provider_id || '',
|
||||
rerank_provider_id: kb.rerank_provider_id || ''
|
||||
}
|
||||
@@ -299,8 +281,7 @@ const saveSettings = async () => {
|
||||
chunk_overlap: formData.value.chunk_overlap,
|
||||
top_k_dense: formData.value.top_k_dense,
|
||||
top_k_sparse: formData.value.top_k_sparse,
|
||||
top_m_final: formData.value.top_m_final,
|
||||
enable_rerank: formData.value.enable_rerank,
|
||||
// top_m_final: formData.value.top_m_final,
|
||||
rerank_provider_id: formData.value.rerank_provider_id
|
||||
})
|
||||
|
||||
|
||||
@@ -129,7 +129,7 @@ class ProviderCommands:
|
||||
)
|
||||
return
|
||||
i = 1
|
||||
ret = "下面列出了此服务提供商可用模型:"
|
||||
ret = "下面列出了此模型提供商可用模型:"
|
||||
for model in models:
|
||||
ret += f"\n{i}. {model}"
|
||||
i += 1
|
||||
|
||||
@@ -11,19 +11,26 @@ class SIDCommand:
|
||||
self.context = context
|
||||
|
||||
async def sid(self, event: AstrMessageEvent):
|
||||
"""获取会话 ID 和 管理员 ID"""
|
||||
"""获取消息来源信息"""
|
||||
sid = event.unified_msg_origin
|
||||
user_id = str(event.get_sender_id())
|
||||
ret = f"""SID: {sid} 此 ID 可用于设置会话白名单。
|
||||
/wl <SID> 添加白名单, /dwl <SID> 删除白名单。
|
||||
|
||||
UID: {user_id} 此 ID 可用于设置管理员。
|
||||
/op <UID> 授权管理员, /deop <UID> 取消管理员。"""
|
||||
umo_platform = event.session.platform_id
|
||||
umo_msg_type = event.session.message_type.value
|
||||
umo_session_id = event.session.session_id
|
||||
ret = (
|
||||
f"UMO: 「{sid}」 此值可用于设置白名单。\n"
|
||||
f"UID: 「{user_id}」 此值可用于设置管理员。\n"
|
||||
f"消息会话来源信息:\n"
|
||||
f" 机器人 ID: 「{umo_platform}」\n"
|
||||
f" 消息类型: 「{umo_msg_type}」\n"
|
||||
f" 会话 ID: 「{umo_session_id}」\n"
|
||||
f"消息来源可用于配置机器人的配置文件路由。"
|
||||
)
|
||||
|
||||
if (
|
||||
self.context.get_config()["platform_settings"]["unique_session"]
|
||||
and event.get_group_id()
|
||||
):
|
||||
ret += f"\n\n当前处于独立会话模式, 此群 ID: {event.get_group_id()}, 也可将此 ID 加入白名单来放行整个群聊。"
|
||||
ret += f"\n\n当前处于独立会话模式, 此群 ID: 「{event.get_group_id()}」, 也可将此 ID 加入白名单来放行整个群聊。"
|
||||
|
||||
event.set_result(MessageEventResult().message(ret).use_t2i(False))
|
||||
|
||||
@@ -4,7 +4,7 @@ import random
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.platform import MessageType
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot.api.provider import ProviderRequest, Provider
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot import logger
|
||||
from collections import defaultdict
|
||||
@@ -32,6 +32,7 @@ class LongTermMemory:
|
||||
image_caption = (
|
||||
True
|
||||
if cfg["provider_settings"]["default_image_caption_provider_id"]
|
||||
and cfg["provider_ltm_settings"]["image_caption"]
|
||||
else False
|
||||
)
|
||||
image_caption_prompt = cfg["provider_settings"]["image_caption_prompt"]
|
||||
@@ -73,6 +74,8 @@ class LongTermMemory:
|
||||
provider = self.context.get_provider_by_id(image_caption_provider_id)
|
||||
if not provider:
|
||||
raise Exception(f"没有找到 ID 为 {image_caption_provider_id} 的提供商")
|
||||
if not isinstance(provider, Provider):
|
||||
raise Exception(f"提供商类型错误({type(provider)}),无法获取图片描述")
|
||||
response = await provider.text_chat(
|
||||
prompt=image_caption_prompt,
|
||||
session_id=uuid.uuid4().hex,
|
||||
@@ -122,8 +125,11 @@ class LongTermMemory:
|
||||
elif isinstance(comp, Image):
|
||||
if cfg["image_caption"]:
|
||||
try:
|
||||
url = comp.url if comp.url else comp.file
|
||||
if not url:
|
||||
raise Exception("图片 URL 为空")
|
||||
caption = await self.get_image_caption(
|
||||
comp.url if comp.url else comp.file,
|
||||
url,
|
||||
cfg["image_caption_provider_id"],
|
||||
cfg["image_caption_prompt"],
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import astrbot.api.star as star
|
||||
import builtins
|
||||
import datetime
|
||||
@@ -41,7 +42,7 @@ class ProcessLLMRequest:
|
||||
if persona:
|
||||
if prompt := persona["prompt"]:
|
||||
req.system_prompt += prompt
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
if begin_dialogs := copy.deepcopy(persona["_begin_dialogs_processed"]):
|
||||
req.contexts[:0] = begin_dialogs
|
||||
|
||||
# tools select
|
||||
|
||||
@@ -54,6 +54,7 @@ dependencies = [
|
||||
"pypdf>=6.1.1",
|
||||
"aiofiles>=25.1.0",
|
||||
"rank-bm25>=0.2.2",
|
||||
"jieba>=0.42.1",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
Reference in New Issue
Block a user