Compare commits
47 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 82a96a8cce | |||
| 343b153263 | |||
| 3a41b19318 | |||
| af444ea6cc | |||
| cb84db532e | |||
| 99b82f48ec | |||
| 00471f904e | |||
| 5df15c60ff | |||
| 32e523b7da | |||
| 0de4fd9f0d | |||
| e23a7e2505 | |||
| 1ed4d9f484 | |||
| d842155770 | |||
| 7f5cc7cf1a | |||
| f26867c77d | |||
| a14d588b44 | |||
| e236402d92 | |||
| 454841de10 | |||
| 442b5403df | |||
| 9db7bf59b8 | |||
| 3622504021 | |||
| fc42db40ce | |||
| e413a002c1 | |||
| 6437d759a3 | |||
| c758b2d888 | |||
| 510290fe0e | |||
| c61d62edb6 | |||
| 45bce6fe76 | |||
| f156adddf8 | |||
| b5a4b80c36 | |||
| 792fb69d6d | |||
| 300a73ace0 | |||
| a5b9de3695 | |||
| 90142bcafe | |||
| 79d0487c03 | |||
| 4f15102e79 | |||
| ef1feb639c | |||
| 1039a4f864 | |||
| 66e2f49c11 | |||
| c5773fe63e | |||
| 4e9ef48af2 | |||
| 9eafd7b44a | |||
| fc61f7ad32 | |||
| f51810997a | |||
| fb4baf676f | |||
| 71ad974c3c | |||
| f0fff68947 |
+52
-15
@@ -1,27 +1,64 @@
|
||||
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
|
||||
# 本工作流用于标记并关闭长期不活跃的 Issue。
|
||||
# 目前仅针对带 `bug` 标签的 Issue 生效,不会处理 PR。
|
||||
#
|
||||
# You can adjust the behavior by modifying this file.
|
||||
# For more information, see:
|
||||
# https://github.com/actions/stale
|
||||
name: Mark stale issues and pull requests
|
||||
# 文档: https://github.com/actions/stale
|
||||
name: Mark stale bug issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '21 23 * * *'
|
||||
# 每天 UTC 08:30 执行 (北京时间 16:30)
|
||||
- cron: '30 8 * * *'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dry-run:
|
||||
description: '仅预览, 不实际执行 (Dry run mode)'
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: 'Stale issue message'
|
||||
stale-pr-message: 'Stale pull request message'
|
||||
stale-issue-label: 'no-issue-activity'
|
||||
stale-pr-label: 'no-pr-activity'
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
operations-per-run: 200
|
||||
|
||||
# 只处理带 bug 标签的 Issue
|
||||
any-of-labels: 'bug'
|
||||
|
||||
# 不处理 PR
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
|
||||
# 不活跃判定与关闭策略: 先标记 stale, 再延迟关闭
|
||||
days-before-issue-stale: 60
|
||||
days-before-issue-close: 30
|
||||
|
||||
stale-issue-label: 'stale'
|
||||
stale-issue-message: |
|
||||
This issue has been automatically marked as **stale** because it has not had any activity.
|
||||
It will be closed in a certain period of time if no further activity occurs.
|
||||
If this issue is still relevant, please leave a comment.
|
||||
|
||||
---
|
||||
|
||||
该 Issue 已较长时间无活动, 已被标记为 `stale`。
|
||||
如无后续活动, 将在一段时间后自动关闭。
|
||||
如仍需跟进, 请回复评论。
|
||||
close-issue-message: |
|
||||
This issue has been automatically closed due to inactivity.
|
||||
If the problem still exists, feel free to reopen or create a new issue with updated information.
|
||||
|
||||
---
|
||||
|
||||
该 Issue 因长期无活动已自动关闭。
|
||||
如问题仍存在, 欢迎补充复现信息并重新打开或新建 Issue。
|
||||
|
||||
remove-stale-when-updated: true
|
||||
|
||||
debug-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run }}
|
||||
|
||||
@@ -132,6 +132,7 @@ uv run main.py
|
||||
|
||||
**社区维护**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
@@ -208,6 +209,7 @@ pre-commit install
|
||||
- 5 群:822130018
|
||||
- 6 群:753075035
|
||||
- 7 群:743746109
|
||||
- 8 群:1030353265
|
||||
- 开发者群:975206796
|
||||
|
||||
### Telegram 群组
|
||||
|
||||
@@ -134,6 +134,7 @@ Or refer to the official documentation: [Deploy AstrBot from Source](https://ast
|
||||
|
||||
**Community Maintained**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources
|
||||
|
||||
**Maintenues par la communauté**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ uv run main.py
|
||||
|
||||
**コミュニティメンテナンス**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ uv run main.py
|
||||
|
||||
**Поддерживаемые сообществом**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -134,6 +134,7 @@ uv run main.py
|
||||
|
||||
**社群維護**
|
||||
|
||||
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
|
||||
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
|
||||
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
|
||||
- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
|
||||
|
||||
@@ -21,6 +21,9 @@ from astrbot.core.star.register import (
|
||||
from astrbot.core.star.register import register_on_llm_request as on_llm_request
|
||||
from astrbot.core.star.register import register_on_llm_response as on_llm_response
|
||||
from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
|
||||
from astrbot.core.star.register import (
|
||||
register_on_waiting_llm_request as on_waiting_llm_request,
|
||||
)
|
||||
from astrbot.core.star.register import register_permission_type as permission_type
|
||||
from astrbot.core.star.register import (
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
@@ -46,6 +49,7 @@ __all__ = [
|
||||
"on_llm_request",
|
||||
"on_llm_response",
|
||||
"on_platform_loaded",
|
||||
"on_waiting_llm_request",
|
||||
"permission_type",
|
||||
"platform_adapter_type",
|
||||
"regex",
|
||||
|
||||
@@ -100,16 +100,8 @@ class Main(star.Star):
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.on_llm_response()
|
||||
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
|
||||
umo = event.unified_msg_origin
|
||||
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||
show_reasoning = cfg.get("display_reasoning_text", False)
|
||||
if show_reasoning and resp.reasoning_content:
|
||||
resp.completion_text = (
|
||||
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
|
||||
)
|
||||
|
||||
async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
"""在 LLM 响应后记录对话"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.after_req_llm(event, resp)
|
||||
|
||||
@@ -14,13 +14,13 @@ class TTSCommand:
|
||||
async def tts(self, event: AstrMessageEvent):
|
||||
"""开关文本转语音(会话级别)"""
|
||||
umo = event.unified_msg_origin
|
||||
ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo)
|
||||
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
|
||||
cfg = self.context.get_config(umo=umo)
|
||||
tts_enable = cfg["provider_tts_settings"]["enable"]
|
||||
|
||||
# 切换状态
|
||||
new_status = not ses_tts
|
||||
SessionServiceManager.set_tts_status_for_session(umo, new_status)
|
||||
await SessionServiceManager.set_tts_status_for_session(umo, new_status)
|
||||
|
||||
status_text = "已开启" if new_status else "已关闭"
|
||||
|
||||
|
||||
@@ -157,9 +157,8 @@ class Main(star.Star):
|
||||
async def is_docker_available(self) -> bool:
|
||||
"""Check if docker is available"""
|
||||
try:
|
||||
docker = aiodocker.Docker()
|
||||
await docker.version()
|
||||
await docker.close()
|
||||
async with aiodocker.Docker() as docker:
|
||||
await docker.version()
|
||||
return True
|
||||
except BaseException as e:
|
||||
logger.info(f"检查 Docker 可用性: {e}")
|
||||
@@ -279,14 +278,14 @@ class Main(star.Star):
|
||||
@pi.command("repull")
|
||||
async def pi_repull(self, event: AstrMessageEvent):
|
||||
"""重新拉取沙箱镜像"""
|
||||
docker = aiodocker.Docker()
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
await docker.images.delete(image_name, force=True)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
pass
|
||||
await docker.images.pull(image_name)
|
||||
async with aiodocker.Docker() as docker:
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
await docker.images.delete(image_name, force=True)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
pass
|
||||
await docker.images.pull(image_name)
|
||||
yield event.plain_result("重新拉取沙箱镜像成功。")
|
||||
|
||||
@pi.command("file")
|
||||
@@ -371,137 +370,137 @@ class Main(star.Star):
|
||||
obs = ""
|
||||
n = 5
|
||||
|
||||
for i in range(n):
|
||||
if i > 0:
|
||||
logger.info(f"Try {i + 1}/{n}")
|
||||
async with aiodocker.Docker() as docker:
|
||||
for i in range(n):
|
||||
if i > 0:
|
||||
logger.info(f"Try {i + 1}/{n}")
|
||||
|
||||
PROMPT_ = PROMPT.format(
|
||||
prompt=plain_text,
|
||||
extra_input=extra_inputs,
|
||||
extra_prompt=obs,
|
||||
)
|
||||
provider = self.context.get_using_provider()
|
||||
llm_response = await provider.text_chat(
|
||||
prompt=PROMPT_,
|
||||
session_id=f"{event.session_id}_{magic_code}_{i!s}",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"code interpreter llm gened code:" + llm_response.completion_text,
|
||||
)
|
||||
|
||||
# 整理代码并保存
|
||||
code_clean = await self.tidy_code(llm_response.completion_text)
|
||||
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
|
||||
f.write(code_clean)
|
||||
|
||||
# 启动容器
|
||||
docker = aiodocker.Docker()
|
||||
|
||||
# 检查有没有image
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
# 拉取镜像
|
||||
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
|
||||
await docker.images.pull(image_name)
|
||||
|
||||
yield event.plain_result(
|
||||
f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})",
|
||||
)
|
||||
|
||||
self.docker_host_astrbot_abs_path = self.config.get(
|
||||
"docker_host_astrbot_abs_path",
|
||||
"",
|
||||
)
|
||||
if self.docker_host_astrbot_abs_path:
|
||||
host_shared = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
self.shared_path,
|
||||
PROMPT_ = PROMPT.format(
|
||||
prompt=plain_text,
|
||||
extra_input=extra_inputs,
|
||||
extra_prompt=obs,
|
||||
)
|
||||
host_output = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
output_path,
|
||||
)
|
||||
host_workplace = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
workplace_path,
|
||||
provider = self.context.get_using_provider()
|
||||
llm_response = await provider.text_chat(
|
||||
prompt=PROMPT_,
|
||||
session_id=f"{event.session_id}_{magic_code}_{i!s}",
|
||||
)
|
||||
|
||||
else:
|
||||
host_shared = os.path.abspath(self.shared_path)
|
||||
host_output = os.path.abspath(output_path)
|
||||
host_workplace = os.path.abspath(workplace_path)
|
||||
logger.debug(
|
||||
"code interpreter llm gened code:" + llm_response.completion_text,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}",
|
||||
)
|
||||
# 整理代码并保存
|
||||
code_clean = await self.tidy_code(llm_response.completion_text)
|
||||
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
|
||||
f.write(code_clean)
|
||||
|
||||
container = await docker.containers.run(
|
||||
{
|
||||
"Image": image_name,
|
||||
"Cmd": ["python", "exec.py"],
|
||||
"Memory": 512 * 1024 * 1024,
|
||||
"NanoCPUs": 1000000000,
|
||||
"HostConfig": {
|
||||
"Binds": [
|
||||
f"{host_shared}:/astrbot_sandbox/shared:ro",
|
||||
f"{host_output}:/astrbot_sandbox/output:rw",
|
||||
f"{host_workplace}:/astrbot_sandbox:rw",
|
||||
],
|
||||
},
|
||||
"Env": [f"MAGIC_CODE={magic_code}"],
|
||||
"AutoRemove": True,
|
||||
},
|
||||
)
|
||||
# 检查有没有image
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
# 拉取镜像
|
||||
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
|
||||
await docker.images.pull(image_name)
|
||||
|
||||
logger.debug(f"Container {container.id} created.")
|
||||
logs = await self.run_container(container)
|
||||
yield event.plain_result(
|
||||
f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})",
|
||||
)
|
||||
|
||||
logger.debug(f"Container {container.id} finished.")
|
||||
logger.debug(f"Container {container.id} logs: {logs}")
|
||||
|
||||
# 发送结果
|
||||
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
|
||||
ok = False
|
||||
traceback = ""
|
||||
for idx, log in enumerate(logs):
|
||||
match = re.match(pattern, log)
|
||||
if match:
|
||||
ok = True
|
||||
if match.group(1) == "TEXT":
|
||||
yield event.plain_result(match.group(2))
|
||||
elif match.group(1) == "IMAGE":
|
||||
image_path = os.path.join(workplace_path, match.group(2))
|
||||
logger.debug(f"Sending image: {image_path}")
|
||||
yield event.image_result(image_path)
|
||||
elif match.group(1) == "FILE":
|
||||
file_path = os.path.join(workplace_path, match.group(2))
|
||||
# logger.debug(f"Sending file: {file_path}")
|
||||
# file_s3_url = await self.file_upload(file_path)
|
||||
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
||||
file_name = os.path.basename(file_path)
|
||||
chain: list[BaseMessageComponent] = [
|
||||
File(name=file_name, file=file_path)
|
||||
]
|
||||
yield event.set_result(MessageEventResult(chain=chain))
|
||||
|
||||
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
|
||||
traceback = "\n".join(logs[idx:])
|
||||
|
||||
if not ok:
|
||||
if traceback:
|
||||
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
|
||||
else:
|
||||
logger.warning(
|
||||
f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}",
|
||||
self.docker_host_astrbot_abs_path = self.config.get(
|
||||
"docker_host_astrbot_abs_path",
|
||||
"",
|
||||
)
|
||||
if self.docker_host_astrbot_abs_path:
|
||||
host_shared = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
self.shared_path,
|
||||
)
|
||||
break
|
||||
else:
|
||||
# 成功了
|
||||
self.user_file_msg_buffer.pop(event.get_session_id())
|
||||
return
|
||||
host_output = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
output_path,
|
||||
)
|
||||
host_workplace = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
workplace_path,
|
||||
)
|
||||
|
||||
else:
|
||||
host_shared = os.path.abspath(self.shared_path)
|
||||
host_output = os.path.abspath(output_path)
|
||||
host_workplace = os.path.abspath(workplace_path)
|
||||
|
||||
logger.debug(
|
||||
f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}",
|
||||
)
|
||||
|
||||
container = await docker.containers.run(
|
||||
{
|
||||
"Image": image_name,
|
||||
"Cmd": ["python", "exec.py"],
|
||||
"Memory": 512 * 1024 * 1024,
|
||||
"NanoCPUs": 1000000000,
|
||||
"HostConfig": {
|
||||
"Binds": [
|
||||
f"{host_shared}:/astrbot_sandbox/shared:ro",
|
||||
f"{host_output}:/astrbot_sandbox/output:rw",
|
||||
f"{host_workplace}:/astrbot_sandbox:rw",
|
||||
],
|
||||
},
|
||||
"Env": [f"MAGIC_CODE={magic_code}"],
|
||||
"AutoRemove": True,
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(f"Container {container.id} created.")
|
||||
logs = await self.run_container(container)
|
||||
|
||||
logger.debug(f"Container {container.id} finished.")
|
||||
logger.debug(f"Container {container.id} logs: {logs}")
|
||||
|
||||
# 发送结果
|
||||
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
|
||||
ok = False
|
||||
traceback = ""
|
||||
for idx, log in enumerate(logs):
|
||||
match = re.match(pattern, log)
|
||||
if match:
|
||||
ok = True
|
||||
if match.group(1) == "TEXT":
|
||||
yield event.plain_result(match.group(2))
|
||||
elif match.group(1) == "IMAGE":
|
||||
image_path = os.path.join(workplace_path, match.group(2))
|
||||
logger.debug(f"Sending image: {image_path}")
|
||||
yield event.image_result(image_path)
|
||||
elif match.group(1) == "FILE":
|
||||
file_path = os.path.join(workplace_path, match.group(2))
|
||||
# logger.debug(f"Sending file: {file_path}")
|
||||
# file_s3_url = await self.file_upload(file_path)
|
||||
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
||||
file_name = os.path.basename(file_path)
|
||||
chain: list[BaseMessageComponent] = [
|
||||
File(name=file_name, file=file_path)
|
||||
]
|
||||
yield event.set_result(MessageEventResult(chain=chain))
|
||||
|
||||
elif (
|
||||
"Traceback (most recent call last)" in log or "[Error]: " in log
|
||||
):
|
||||
traceback = "\n".join(logs[idx:])
|
||||
|
||||
if not ok:
|
||||
if traceback:
|
||||
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
|
||||
else:
|
||||
logger.warning(
|
||||
f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}",
|
||||
)
|
||||
break
|
||||
else:
|
||||
# 成功了
|
||||
self.user_file_msg_buffer.pop(event.get_session_id())
|
||||
return
|
||||
|
||||
yield event.plain_result(
|
||||
"经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。",
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.10.3"
|
||||
__version__ = "4.11.0"
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot import logger
|
||||
else:
|
||||
try:
|
||||
from astrbot import logger
|
||||
except ImportError:
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.truncator import ContextTruncator
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ContextCompressor(Protocol):
|
||||
"""
|
||||
Protocol for context compressors.
|
||||
Provides an interface for compressing message lists.
|
||||
"""
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens for the model.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Compress the message list.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
The compressed message list.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class TruncateByTurnsCompressor:
|
||||
"""Truncate by turns compressor implementation.
|
||||
Truncates the message list by removing older turns.
|
||||
"""
|
||||
|
||||
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
|
||||
"""Initialize the truncate by turns compressor.
|
||||
|
||||
Args:
|
||||
truncate_turns: The number of turns to remove when truncating (default: 1).
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.truncate_turns = truncate_turns
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
truncator = ContextTruncator()
|
||||
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
|
||||
messages,
|
||||
drop_turns=self.truncate_turns,
|
||||
)
|
||||
return truncated_messages
|
||||
|
||||
|
||||
def split_history(
|
||||
messages: list[Message], keep_recent: int
|
||||
) -> tuple[list[Message], list[Message], list[Message]]:
|
||||
"""Split the message list into system messages, messages to summarize, and recent messages.
|
||||
|
||||
Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
keep_recent: The number of latest messages to keep.
|
||||
|
||||
Returns:
|
||||
tuple: (system_messages, messages_to_summarize, recent_messages)
|
||||
"""
|
||||
# keep the system messages
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) <= keep_recent:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
# Find the split point, ensuring recent_messages starts with a user message
|
||||
# This maintains complete conversation turns
|
||||
split_index = len(non_system_messages) - keep_recent
|
||||
|
||||
# Search backward from split_index to find the first user message
|
||||
# This ensures recent_messages starts with a user message (complete turn)
|
||||
while split_index > 0 and non_system_messages[split_index].role != "user":
|
||||
# TODO: +=1 or -=1 ? calculate by tokens
|
||||
split_index -= 1
|
||||
|
||||
# If we couldn't find a user message, keep all messages as recent
|
||||
if split_index == 0:
|
||||
return system_messages, [], non_system_messages
|
||||
|
||||
messages_to_summarize = non_system_messages[:split_index]
|
||||
recent_messages = non_system_messages[split_index:]
|
||||
|
||||
return system_messages, messages_to_summarize, recent_messages
|
||||
|
||||
|
||||
class LLMSummaryCompressor:
|
||||
"""LLM-based summary compressor.
|
||||
Uses LLM to summarize the old conversation history, keeping the latest messages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider: "Provider",
|
||||
keep_recent: int = 4,
|
||||
instruction_text: str | None = None,
|
||||
compression_threshold: float = 0.82,
|
||||
):
|
||||
"""Initialize the LLM summary compressor.
|
||||
|
||||
Args:
|
||||
provider: The LLM provider instance.
|
||||
keep_recent: The number of latest messages to keep (default: 4).
|
||||
instruction_text: Custom instruction for summary generation.
|
||||
compression_threshold: The compression trigger threshold (default: 0.82).
|
||||
"""
|
||||
self.provider = provider
|
||||
self.keep_recent = keep_recent
|
||||
self.compression_threshold = compression_threshold
|
||||
|
||||
self.instruction_text = instruction_text or (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
)
|
||||
|
||||
def should_compress(
|
||||
self, messages: list[Message], current_tokens: int, max_tokens: int
|
||||
) -> bool:
|
||||
"""Check if compression is needed.
|
||||
|
||||
Args:
|
||||
messages: The message list to evaluate.
|
||||
current_tokens: The current token count.
|
||||
max_tokens: The maximum allowed tokens.
|
||||
|
||||
Returns:
|
||||
True if compression is needed, False otherwise.
|
||||
"""
|
||||
if max_tokens <= 0 or current_tokens <= 0:
|
||||
return False
|
||||
usage_rate = current_tokens / max_tokens
|
||||
return usage_rate > self.compression_threshold
|
||||
|
||||
async def __call__(self, messages: list[Message]) -> list[Message]:
|
||||
"""Use LLM to generate a summary of the conversation history.
|
||||
|
||||
Process:
|
||||
1. Divide messages: keep the system message and the latest N messages.
|
||||
2. Send the old messages + the instruction message to the LLM.
|
||||
3. Reconstruct the message list: [system message, summary message, latest messages].
|
||||
"""
|
||||
if len(messages) <= self.keep_recent + 1:
|
||||
return messages
|
||||
|
||||
system_messages, messages_to_summarize, recent_messages = split_history(
|
||||
messages, self.keep_recent
|
||||
)
|
||||
|
||||
if not messages_to_summarize:
|
||||
return messages
|
||||
|
||||
# build payload
|
||||
instruction_message = Message(role="user", content=self.instruction_text)
|
||||
llm_payload = messages_to_summarize + [instruction_message]
|
||||
|
||||
# generate summary
|
||||
try:
|
||||
response = await self.provider.text_chat(contexts=llm_payload)
|
||||
summary_content = response.completion_text
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate summary: {e}")
|
||||
return messages
|
||||
|
||||
# build result
|
||||
result = []
|
||||
result.extend(system_messages)
|
||||
|
||||
result.append(
|
||||
Message(
|
||||
role="user",
|
||||
content=f"Our previous history conversation summary: {summary_content}",
|
||||
)
|
||||
)
|
||||
result.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Acknowledged the summary of our previous conversation history.",
|
||||
)
|
||||
)
|
||||
|
||||
result.extend(recent_messages)
|
||||
|
||||
return result
|
||||
@@ -0,0 +1,35 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .compressor import ContextCompressor
|
||||
from .token_counter import TokenCounter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContextConfig:
|
||||
"""Context configuration class."""
|
||||
|
||||
max_context_tokens: int = 0
|
||||
"""Maximum number of context tokens. <= 0 means no limit."""
|
||||
enforce_max_turns: int = -1 # -1 means no limit
|
||||
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
|
||||
truncate_turns: int = 1
|
||||
"""Number of conversation turns to discard at once when truncation is triggered.
|
||||
Two processes will use this value:
|
||||
|
||||
1. Enforce max turns truncation.
|
||||
2. Truncation by turns compression strategy.
|
||||
"""
|
||||
llm_compress_instruction: str | None = None
|
||||
"""Instruction prompt for LLM-based compression."""
|
||||
llm_compress_keep_recent: int = 0
|
||||
"""Number of recent messages to keep during LLM-based compression."""
|
||||
llm_compress_provider: "Provider | None" = None
|
||||
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
|
||||
custom_token_counter: TokenCounter | None = None
|
||||
"""Custom token counting method. If None, the default method is used."""
|
||||
custom_compressor: ContextCompressor | None = None
|
||||
"""Custom context compression method. If None, the default method is used."""
|
||||
@@ -0,0 +1,120 @@
|
||||
from astrbot import logger
|
||||
|
||||
from ..message import Message
|
||||
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
|
||||
from .config import ContextConfig
|
||||
from .token_counter import EstimateTokenCounter
|
||||
from .truncator import ContextTruncator
|
||||
|
||||
|
||||
class ContextManager:
|
||||
"""Context compression manager."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ContextConfig,
|
||||
):
|
||||
"""Initialize the context manager.
|
||||
|
||||
There are two strategies to handle context limit reached:
|
||||
1. Truncate by turns: remove older messages by turns.
|
||||
2. LLM-based compression: use LLM to summarize old messages.
|
||||
|
||||
Args:
|
||||
config: The context configuration.
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
|
||||
self.truncator = ContextTruncator()
|
||||
|
||||
if config.custom_compressor:
|
||||
self.compressor = config.custom_compressor
|
||||
elif config.llm_compress_provider:
|
||||
self.compressor = LLMSummaryCompressor(
|
||||
provider=config.llm_compress_provider,
|
||||
keep_recent=config.llm_compress_keep_recent,
|
||||
instruction_text=config.llm_compress_instruction,
|
||||
)
|
||||
else:
|
||||
self.compressor = TruncateByTurnsCompressor(
|
||||
truncate_turns=config.truncate_turns
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> list[Message]:
|
||||
"""Process the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
|
||||
Returns:
|
||||
The processed message list.
|
||||
"""
|
||||
try:
|
||||
result = messages
|
||||
|
||||
# 1. 基于轮次的截断 (Enforce max turns)
|
||||
if self.config.enforce_max_turns != -1:
|
||||
result = self.truncator.truncate_by_turns(
|
||||
result,
|
||||
keep_most_recent_turns=self.config.enforce_max_turns,
|
||||
drop_turns=self.config.truncate_turns,
|
||||
)
|
||||
|
||||
# 2. 基于 token 的压缩
|
||||
if self.config.max_context_tokens > 0:
|
||||
total_tokens = self.token_counter.count_tokens(
|
||||
result, trusted_token_usage
|
||||
)
|
||||
|
||||
if self.compressor.should_compress(
|
||||
result, total_tokens, self.config.max_context_tokens
|
||||
):
|
||||
result = await self._run_compression(result, total_tokens)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error during context processing: {e}", exc_info=True)
|
||||
return messages
|
||||
|
||||
async def _run_compression(
|
||||
self, messages: list[Message], prev_tokens: int
|
||||
) -> list[Message]:
|
||||
"""
|
||||
Compress/truncate the messages.
|
||||
|
||||
Args:
|
||||
messages: The original message list.
|
||||
prev_tokens: The token count before compression.
|
||||
|
||||
Returns:
|
||||
The compressed/truncated message list.
|
||||
"""
|
||||
logger.debug("Compress triggered, starting compression...")
|
||||
|
||||
messages = await self.compressor(messages)
|
||||
|
||||
# double check
|
||||
tokens_after_summary = self.token_counter.count_tokens(messages)
|
||||
|
||||
# calculate compress rate
|
||||
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
|
||||
logger.info(
|
||||
f"Compress completed."
|
||||
f" {prev_tokens} -> {tokens_after_summary} tokens,"
|
||||
f" compression rate: {compress_rate:.2f}%.",
|
||||
)
|
||||
|
||||
# last check
|
||||
if self.compressor.should_compress(
|
||||
messages, tokens_after_summary, self.config.max_context_tokens
|
||||
):
|
||||
logger.info(
|
||||
"Context still exceeds max tokens after compression, applying halving truncation..."
|
||||
)
|
||||
# still need compress, truncate by half
|
||||
messages = self.truncator.truncate_by_halving(messages)
|
||||
|
||||
return messages
|
||||
@@ -0,0 +1,64 @@
|
||||
import json
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
from ..message import Message, TextPart
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class TokenCounter(Protocol):
|
||||
"""
|
||||
Protocol for token counters.
|
||||
Provides an interface for counting tokens in message lists.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
"""Count the total tokens in the message list.
|
||||
|
||||
Args:
|
||||
messages: The message list.
|
||||
trusted_token_usage: The total token usage that LLM API returned.
|
||||
For some cases, this value is more accurate.
|
||||
But some API does not return it, so the value defaults to 0.
|
||||
|
||||
Returns:
|
||||
The total token count.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class EstimateTokenCounter:
|
||||
"""Estimate token counter implementation.
|
||||
Provides a simple estimation of token count based on character types.
|
||||
"""
|
||||
|
||||
def count_tokens(
|
||||
self, messages: list[Message], trusted_token_usage: int = 0
|
||||
) -> int:
|
||||
if trusted_token_usage > 0:
|
||||
return trusted_token_usage
|
||||
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
total += self._estimate_tokens(content)
|
||||
elif isinstance(content, list):
|
||||
# 处理多模态内容
|
||||
for part in content:
|
||||
if isinstance(part, TextPart):
|
||||
total += self._estimate_tokens(part.text)
|
||||
|
||||
# 处理 Tool Calls
|
||||
if msg.tool_calls:
|
||||
for tc in msg.tool_calls:
|
||||
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
|
||||
total += self._estimate_tokens(tc_str)
|
||||
|
||||
return total
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
|
||||
other_count = len(text) - chinese_count
|
||||
return int(chinese_count * 0.6 + other_count * 0.3)
|
||||
@@ -0,0 +1,141 @@
|
||||
from ..message import Message
|
||||
|
||||
|
||||
class ContextTruncator:
|
||||
"""Context truncator."""
|
||||
|
||||
def fix_messages(self, messages: list[Message]) -> list[Message]:
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.role == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
|
||||
def truncate_by_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
keep_most_recent_turns: int,
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""截断上下文列表,确保不超过最大长度。
|
||||
一个 turn 包含一个 user 消息和一个 assistant 消息。
|
||||
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
|
||||
|
||||
Args:
|
||||
messages: 上下文列表
|
||||
keep_most_recent_turns: 保留最近的对话轮数
|
||||
drop_turns: 一次性丢弃的对话轮数
|
||||
|
||||
Returns:
|
||||
截断后的上下文列表
|
||||
"""
|
||||
if keep_most_recent_turns == -1:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= keep_most_recent_turns:
|
||||
return messages
|
||||
|
||||
num_to_keep = keep_most_recent_turns - drop_turns + 1
|
||||
if num_to_keep <= 0:
|
||||
truncated_contexts = []
|
||||
else:
|
||||
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
|
||||
|
||||
# 找到第一个 role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
result = system_messages + truncated_contexts
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_dropping_oldest_turns(
|
||||
self,
|
||||
messages: list[Message],
|
||||
drop_turns: int = 1,
|
||||
) -> list[Message]:
|
||||
"""丢弃最旧的 N 个对话轮次。"""
|
||||
if drop_turns <= 0:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
if len(non_system_messages) // 2 <= drop_turns:
|
||||
truncated_non_system = []
|
||||
else:
|
||||
truncated_non_system = non_system_messages[drop_turns * 2 :]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
elif truncated_non_system:
|
||||
truncated_non_system = []
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
|
||||
def truncate_by_halving(
|
||||
self,
|
||||
messages: list[Message],
|
||||
) -> list[Message]:
|
||||
"""对半砍策略,删除 50% 的消息"""
|
||||
if len(messages) <= 2:
|
||||
return messages
|
||||
|
||||
first_non_system = 0
|
||||
for i, msg in enumerate(messages):
|
||||
if msg.role != "system":
|
||||
first_non_system = i
|
||||
break
|
||||
|
||||
system_messages = messages[:first_non_system]
|
||||
non_system_messages = messages[first_non_system:]
|
||||
|
||||
messages_to_delete = len(non_system_messages) // 2
|
||||
if messages_to_delete == 0:
|
||||
return messages
|
||||
|
||||
truncated_non_system = non_system_messages[messages_to_delete:]
|
||||
|
||||
index = next(
|
||||
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
|
||||
None,
|
||||
)
|
||||
if index is not None:
|
||||
truncated_non_system = truncated_non_system[index:]
|
||||
|
||||
result = system_messages + truncated_non_system
|
||||
|
||||
return self.fix_messages(result)
|
||||
@@ -12,7 +12,7 @@ class ContentPart(BaseModel):
|
||||
|
||||
__content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
|
||||
|
||||
type: str
|
||||
type: Literal["text", "think", "image_url", "audio_url"]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
@@ -63,6 +63,28 @@ class TextPart(ContentPart):
|
||||
text: str
|
||||
|
||||
|
||||
class ThinkPart(ContentPart):
|
||||
"""
|
||||
>>> ThinkPart(think="I think I need to think about this.").model_dump()
|
||||
{'type': 'think', 'think': 'I think I need to think about this.', 'encrypted': None}
|
||||
"""
|
||||
|
||||
type: str = "think"
|
||||
think: str
|
||||
encrypted: str | None = None
|
||||
"""Encrypted thinking content, or signature."""
|
||||
|
||||
def merge_in_place(self, other: Any) -> bool:
|
||||
if not isinstance(other, ThinkPart):
|
||||
return False
|
||||
if self.encrypted:
|
||||
return False
|
||||
self.think += other.think
|
||||
if other.encrypted:
|
||||
self.encrypted = other.encrypted
|
||||
return True
|
||||
|
||||
|
||||
class ImageURLPart(ContentPart):
|
||||
"""
|
||||
>>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
|
||||
|
||||
@@ -13,6 +13,7 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import TextPart, ThinkPart
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
@@ -24,6 +25,10 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..context.compressor import ContextCompressor
|
||||
from ..context.config import ContextConfig
|
||||
from ..context.manager import ContextManager
|
||||
from ..context.token_counter import TokenCounter
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
@@ -46,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
run_context: ContextWrapper[TContext],
|
||||
tool_executor: BaseFunctionToolExecutor[TContext],
|
||||
agent_hooks: BaseAgentRunHooks[TContext],
|
||||
streaming: bool = False,
|
||||
# enforce max turns, will discard older turns when exceeded BEFORE compression
|
||||
# -1 means no limit
|
||||
enforce_max_turns: int = -1,
|
||||
# llm compressor
|
||||
llm_compress_instruction: str | None = None,
|
||||
llm_compress_keep_recent: int = 0,
|
||||
llm_compress_provider: Provider | None = None,
|
||||
# truncate by turns compressor
|
||||
truncate_turns: int = 1,
|
||||
# customize
|
||||
custom_token_counter: TokenCounter | None = None,
|
||||
custom_compressor: ContextCompressor | None = None,
|
||||
**kwargs: T.Any,
|
||||
) -> None:
|
||||
self.req = request
|
||||
self.streaming = kwargs.get("streaming", False)
|
||||
self.streaming = streaming
|
||||
self.enforce_max_turns = enforce_max_turns
|
||||
self.llm_compress_instruction = llm_compress_instruction
|
||||
self.llm_compress_keep_recent = llm_compress_keep_recent
|
||||
self.llm_compress_provider = llm_compress_provider
|
||||
self.truncate_turns = truncate_turns
|
||||
self.custom_token_counter = custom_token_counter
|
||||
self.custom_compressor = custom_compressor
|
||||
# we will do compress when:
|
||||
# 1. before requesting LLM
|
||||
# TODO: 2. after LLM output a tool call
|
||||
self.context_config = ContextConfig(
|
||||
# <=0 will never do compress
|
||||
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
|
||||
# enforce max turns before compression
|
||||
enforce_max_turns=self.enforce_max_turns,
|
||||
truncate_turns=self.truncate_turns,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self.llm_compress_provider,
|
||||
custom_token_counter=self.custom_token_counter,
|
||||
custom_compressor=self.custom_compressor,
|
||||
)
|
||||
self.context_manager = ContextManager(self.context_config)
|
||||
|
||||
self.provider = provider
|
||||
self.final_llm_resp = None
|
||||
self._state = AgentState.IDLE
|
||||
@@ -109,6 +151,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self._transition_state(AgentState.RUNNING)
|
||||
llm_resp_result = None
|
||||
|
||||
# do truncate and compress
|
||||
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
|
||||
self.run_context.messages = await self.context_manager.process(
|
||||
self.run_context.messages, trusted_token_usage=token_usage
|
||||
)
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
@@ -169,13 +217,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
|
||||
# record the final assistant message
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=llm_resp.completion_text or "*No response*",
|
||||
),
|
||||
)
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
parts.append(TextPart(text=llm_resp.completion_text or "*No response*"))
|
||||
self.run_context.messages.append(Message(role="assistant", content=parts))
|
||||
|
||||
# call the on_agent_done hook
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
@@ -214,10 +269,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
data=AgentResponseData(chain=result),
|
||||
)
|
||||
# 将结果添加到上下文中
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
parts.append(TextPart(text=llm_resp.completion_text or "*No response*"))
|
||||
tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=AssistantMessageSegment(
|
||||
tool_calls=llm_resp.to_openai_to_calls_model(),
|
||||
content=llm_resp.completion_text,
|
||||
content=parts,
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
|
||||
@@ -13,6 +13,12 @@ from astrbot.core.star.star_handler import EventType
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
if llm_response and llm_response.reasoning_content:
|
||||
# we will use this in result_decorate stage to inject reasoning content to chain
|
||||
run_context.context.event.set_extra(
|
||||
"_llm_reasoning_content", llm_response.reasoning_content
|
||||
)
|
||||
|
||||
await call_event_hook(
|
||||
run_context.context.event,
|
||||
EventType.OnLLMResponseEvent,
|
||||
|
||||
@@ -447,6 +447,7 @@ class AstrBotExporter:
|
||||
"version": BACKUP_MANIFEST_VERSION,
|
||||
"astrbot_version": VERSION,
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传
|
||||
"schema_version": {
|
||||
"main_db": "v4",
|
||||
"kb_db": "v1",
|
||||
|
||||
@@ -80,6 +80,8 @@ class AstrBotConfig(dict):
|
||||
if v["type"] == "object":
|
||||
conf[k] = {}
|
||||
_parse_schema(v["items"], conf[k])
|
||||
elif v["type"] == "template_list":
|
||||
conf[k] = default
|
||||
else:
|
||||
conf[k] = default
|
||||
|
||||
|
||||
+135
-33
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.10.3"
|
||||
VERSION = "4.11.0"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
|
||||
"default_personality": "default",
|
||||
"persona_pool": ["*"],
|
||||
"prompt_prefix": "{{prompt}}",
|
||||
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
|
||||
"llm_compress_instruction": (
|
||||
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
|
||||
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
|
||||
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
|
||||
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
|
||||
"4. Write the summary in the user's language.\n"
|
||||
),
|
||||
"llm_compress_keep_recent": 4,
|
||||
"llm_compress_provider_id": "",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
|
||||
model: str
|
||||
modalities: list
|
||||
custom_extra_body: dict[str, Any]
|
||||
max_context_tokens: int
|
||||
|
||||
|
||||
CHAT_PROVIDER_TEMPLATE = {
|
||||
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
|
||||
"model": "",
|
||||
"modalities": [],
|
||||
"custom_extra_body": {},
|
||||
"max_context_tokens": 0,
|
||||
}
|
||||
|
||||
"""
|
||||
@@ -227,7 +239,7 @@ CONFIG_METADATA_2 = {
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6196,
|
||||
},
|
||||
"OneBot v11": {
|
||||
"OneBot v11 (QQ 个人号等)": {
|
||||
"id": "default",
|
||||
"type": "aiocqhttp",
|
||||
"enable": False,
|
||||
@@ -235,16 +247,6 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "",
|
||||
},
|
||||
"WeChatPadPro": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
"enable": False,
|
||||
"admin_key": "stay33",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 8059,
|
||||
"wpp_active_message_poll": False,
|
||||
"wpp_active_message_poll_interval": 3,
|
||||
},
|
||||
"微信公众平台": {
|
||||
"id": "weixin_official_account",
|
||||
"type": "weixin_official_account",
|
||||
@@ -374,6 +376,16 @@ CONFIG_METADATA_2 = {
|
||||
"satori_heartbeat_interval": 10,
|
||||
"satori_reconnect_delay": 5,
|
||||
},
|
||||
"WeChatPadPro": {
|
||||
"id": "wechatpadpro",
|
||||
"type": "wechatpadpro",
|
||||
"enable": False,
|
||||
"admin_key": "stay33",
|
||||
"host": "这里填写你的局域网IP或者公网服务器IP",
|
||||
"port": 8059,
|
||||
"wpp_active_message_poll": False,
|
||||
"wpp_active_message_poll_interval": 3,
|
||||
},
|
||||
# "WebChat": {
|
||||
# "id": "webchat",
|
||||
# "type": "webchat",
|
||||
@@ -905,6 +917,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
"anth_thinking_config": {"budget": 0},
|
||||
},
|
||||
"Moonshot": {
|
||||
"id": "moonshot",
|
||||
@@ -920,7 +933,7 @@ CONFIG_METADATA_2 = {
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"type": "xai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
@@ -1286,7 +1299,7 @@ CONFIG_METADATA_2 = {
|
||||
"minimax-is-timber-weight": False,
|
||||
"minimax-voice-id": "female-shaonv",
|
||||
"minimax-timber-weight": '[\n {\n "voice_id": "Chinese (Mandarin)_Warm_Girl",\n "weight": 25\n },\n {\n "voice_id": "Chinese (Mandarin)_BashfulGirl",\n "weight": 50\n }\n]',
|
||||
"minimax-voice-emotion": "neutral",
|
||||
"minimax-voice-emotion": "auto",
|
||||
"minimax-voice-latex": False,
|
||||
"minimax-voice-english-normalization": False,
|
||||
"timeout": 20,
|
||||
@@ -1450,7 +1463,32 @@ CONFIG_METADATA_2 = {
|
||||
"description": "自定义请求体参数",
|
||||
"type": "dict",
|
||||
"items": {},
|
||||
"hint": "此处添加的键值对将被合并到发送给 API 的 extra_body 中。值可以是字符串、数字或布尔值。",
|
||||
"hint": "用于在请求时添加额外的参数,如 temperature、top_p、max_tokens 等。",
|
||||
"template_schema": {
|
||||
"temperature": {
|
||||
"name": "Temperature",
|
||||
"description": "温度参数",
|
||||
"hint": "控制输出的随机性,范围通常为 0-2。值越高越随机。",
|
||||
"type": "float",
|
||||
"default": 0.6,
|
||||
"slider": {"min": 0, "max": 2, "step": 0.1},
|
||||
},
|
||||
"top_p": {
|
||||
"name": "Top-p",
|
||||
"description": "Top-p 采样",
|
||||
"hint": "核采样参数,范围通常为 0-1。控制模型考虑的概率质量。",
|
||||
"type": "float",
|
||||
"default": 1.0,
|
||||
"slider": {"min": 0, "max": 1, "step": 0.01},
|
||||
},
|
||||
"max_tokens": {
|
||||
"name": "Max Tokens",
|
||||
"description": "最大令牌数",
|
||||
"hint": "生成的最大令牌数。",
|
||||
"type": "int",
|
||||
"default": 8192,
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider": {
|
||||
"type": "string",
|
||||
@@ -1787,6 +1825,17 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"anth_thinking_config": {
|
||||
"description": "Thinking Config",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"budget": {
|
||||
"description": "Thinking Budget",
|
||||
"type": "int",
|
||||
"hint": "Anthropic thinking.budget_tokens param. Must >= 1024. See: https://platform.claude.com/docs/en/build-with-claude/extended-thinking",
|
||||
},
|
||||
},
|
||||
},
|
||||
"minimax-group-id": {
|
||||
"type": "string",
|
||||
"description": "用户组",
|
||||
@@ -1858,15 +1907,18 @@ CONFIG_METADATA_2 = {
|
||||
"minimax-voice-emotion": {
|
||||
"type": "string",
|
||||
"description": "情绪",
|
||||
"hint": "控制合成语音的情绪",
|
||||
"hint": "控制合成语音的情绪。当为 auto 时,将根据文本内容自动选择情绪。",
|
||||
"options": [
|
||||
"auto",
|
||||
"happy",
|
||||
"sad",
|
||||
"angry",
|
||||
"fearful",
|
||||
"disgusted",
|
||||
"surprised",
|
||||
"neutral",
|
||||
"calm",
|
||||
"fluent",
|
||||
"whisper",
|
||||
],
|
||||
},
|
||||
"minimax-voice-latex": {
|
||||
@@ -1993,6 +2045,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_context_tokens": {
|
||||
"description": "模型上下文窗口大小",
|
||||
"type": "int",
|
||||
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
"type": "string",
|
||||
@@ -2500,6 +2557,66 @@ CONFIG_METADATA_3 = {
|
||||
# "provider_settings.enable": True,
|
||||
# },
|
||||
# },
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"type": "string",
|
||||
"options": ["truncate_by_turns", "llm_compress"],
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
"hint": "",
|
||||
},
|
||||
"provider_settings.llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"type": "text",
|
||||
"hint": "如果为空则使用默认提示词。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"type": "int",
|
||||
"hint": "始终保留的最近 N 轮对话。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"type": "string",
|
||||
"_special": "select_provider",
|
||||
"hint": "留空时将降级为“按对话轮数截断”的策略。",
|
||||
"condition": {
|
||||
"provider_settings.context_limit_reached_strategy": "llm_compress",
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"type": "object",
|
||||
@@ -2564,22 +2681,6 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.streaming_response": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"type": "int",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
},
|
||||
},
|
||||
"provider_settings.wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀 ",
|
||||
"type": "string",
|
||||
@@ -3049,4 +3150,5 @@ DEFAULT_VALUE_MAP = {
|
||||
"text": "",
|
||||
"list": [],
|
||||
"object": {},
|
||||
"template_list": [],
|
||||
}
|
||||
|
||||
@@ -69,6 +69,7 @@ class ConversationManager:
|
||||
persona_id=conv_v2.persona_id,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
token_usage=conv_v2.token_usage,
|
||||
)
|
||||
|
||||
async def new_conversation(
|
||||
@@ -256,6 +257,7 @@ class ConversationManager:
|
||||
history: list[dict] | None = None,
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""更新会话的对话.
|
||||
|
||||
@@ -263,6 +265,7 @@ class ConversationManager:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
token_usage (int | None): token 使用量。None 表示不更新
|
||||
|
||||
"""
|
||||
if not conversation_id:
|
||||
@@ -274,6 +277,7 @@ class ConversationManager:
|
||||
title=title,
|
||||
persona_id=persona_id,
|
||||
content=history,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
async def update_conversation_title(
|
||||
|
||||
@@ -90,6 +90,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化 UMOP 配置路由器
|
||||
self.umop_config_router = UmopConfigRouter(sp=sp)
|
||||
await self.umop_config_router.initialize()
|
||||
|
||||
# 初始化 AstrBot 配置管理器
|
||||
self.astrbot_config_mgr = AstrBotConfigManager(
|
||||
|
||||
@@ -152,6 +152,7 @@ class BaseDatabase(abc.ABC):
|
||||
title: str | None = None,
|
||||
persona_id: str | None = None,
|
||||
content: list[dict] | None = None,
|
||||
token_usage: int | None = None,
|
||||
) -> None:
|
||||
"""Update a conversation's history."""
|
||||
...
|
||||
|
||||
@@ -0,0 +1,61 @@
|
||||
"""Migration script to add token_usage column to conversations table.
|
||||
|
||||
This migration adds the token_usage field to track token consumption for each conversation.
|
||||
|
||||
Changes:
|
||||
- Adds token_usage column to conversations table (default: 0)
|
||||
"""
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from astrbot.api import logger, sp
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
|
||||
async def migrate_token_usage(db_helper: BaseDatabase):
|
||||
"""Add token_usage column to conversations table.
|
||||
|
||||
This migration adds a new column to track token consumption in conversations.
|
||||
"""
|
||||
# 检查是否已经完成迁移
|
||||
migration_done = await db_helper.get_preference(
|
||||
"global", "global", "migration_done_token_usage_1"
|
||||
)
|
||||
if migration_done:
|
||||
return
|
||||
|
||||
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
|
||||
|
||||
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
|
||||
|
||||
try:
|
||||
async with db_helper.get_db() as session:
|
||||
# 检查列是否已存在
|
||||
result = await session.execute(text("PRAGMA table_info(conversations)"))
|
||||
columns = result.fetchall()
|
||||
column_names = [col[1] for col in columns]
|
||||
|
||||
if "token_usage" in column_names:
|
||||
logger.info("token_usage 列已存在,跳过迁移")
|
||||
await sp.put_async(
|
||||
"global", "global", "migration_done_token_usage_1", True
|
||||
)
|
||||
return
|
||||
|
||||
# 添加 token_usage 列
|
||||
await session.execute(
|
||||
text(
|
||||
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
logger.info("token_usage 列添加成功")
|
||||
|
||||
# 标记迁移完成
|
||||
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
|
||||
logger.info("token_usage 迁移完成")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
|
||||
)
|
||||
title: str | None = Field(default=None, max_length=255)
|
||||
persona_id: str | None = Field(default=None)
|
||||
token_usage: int = Field(default=0, nullable=False)
|
||||
"""content is a list of OpenAI-formated messages in list[dict] format.
|
||||
token_usage is the total token value of the messages.
|
||||
when 0, will use estimated token counter.
|
||||
"""
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
@@ -313,6 +318,8 @@ class Conversation:
|
||||
persona_id: str | None = ""
|
||||
created_at: int = 0
|
||||
updated_at: int = 0
|
||||
token_usage: int = 0
|
||||
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
|
||||
@@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
session.add(new_conversation)
|
||||
return new_conversation
|
||||
|
||||
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
|
||||
async def update_conversation(
|
||||
self, cid, title=None, persona_id=None, content=None, token_usage=None
|
||||
):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
@@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
|
||||
values["persona_id"] = persona_id
|
||||
if content is not None:
|
||||
values["content"] = content
|
||||
if token_usage is not None:
|
||||
values["token_usage"] = token_usage
|
||||
if not values:
|
||||
return None
|
||||
query = query.values(**values)
|
||||
|
||||
@@ -149,8 +149,16 @@ class RecursiveCharacterChunker(BaseChunker):
|
||||
分割后的文本块列表
|
||||
|
||||
"""
|
||||
chunk_size = chunk_size or self.chunk_size
|
||||
overlap = overlap or self.chunk_overlap
|
||||
if chunk_size is None:
|
||||
chunk_size = self.chunk_size
|
||||
if overlap is None:
|
||||
overlap = self.chunk_overlap
|
||||
if chunk_size <= 0:
|
||||
raise ValueError("chunk_size must be greater than 0")
|
||||
if overlap < 0:
|
||||
raise ValueError("chunk_overlap must be non-negative")
|
||||
if overlap >= chunk_size:
|
||||
raise ValueError("chunk_overlap must be less than chunk_size")
|
||||
result = []
|
||||
for i in range(0, len(text), chunk_size - overlap):
|
||||
end = min(i + chunk_size, len(text))
|
||||
|
||||
@@ -38,7 +38,7 @@ class AgentRequestSubStage(Stage):
|
||||
)
|
||||
return
|
||||
|
||||
if not SessionServiceManager.should_process_llm_request(event):
|
||||
if not await SessionServiceManager.should_process_llm_request(event):
|
||||
logger.debug(
|
||||
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
|
||||
)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""本地 Agent 模式的 LLM 调用 Stage"""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.response import AgentStats
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
@@ -23,6 +24,7 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType, star_map
|
||||
from astrbot.core.utils.file_extract import extract_file_moonshotai
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.utils.session_lock import session_lock_manager
|
||||
|
||||
@@ -40,11 +42,6 @@ class InternalAgentSubStage(Stage):
|
||||
self.ctx = ctx
|
||||
conf = ctx.astrbot_config
|
||||
settings = conf["provider_settings"]
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
self.streaming_response: bool = settings["streaming_response"]
|
||||
self.unsupported_streaming_strategy: str = settings[
|
||||
"unsupported_streaming_strategy"
|
||||
@@ -64,6 +61,25 @@ class InternalAgentSubStage(Stage):
|
||||
"moonshotai_api_key", ""
|
||||
)
|
||||
|
||||
# 上下文管理相关
|
||||
self.context_limit_reached_strategy: str = settings.get(
|
||||
"context_limit_reached_strategy", "truncate_by_turns"
|
||||
)
|
||||
self.llm_compress_instruction: str = settings.get(
|
||||
"llm_compress_instruction", ""
|
||||
)
|
||||
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
|
||||
self.llm_compress_provider_id: str = settings.get(
|
||||
"llm_compress_provider_id", ""
|
||||
)
|
||||
self.max_context_length = settings["max_context_length"] # int
|
||||
self.dequeue_context_length: int = min(
|
||||
max(1, settings["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
)
|
||||
if self.dequeue_context_length <= 0:
|
||||
self.dequeue_context_length = 1
|
||||
|
||||
self.conv_manager = ctx.plugin_manager.context.conversation_manager
|
||||
|
||||
def _select_provider(self, event: AstrMessageEvent):
|
||||
@@ -166,34 +182,6 @@ class InternalAgentSubStage(Stage):
|
||||
},
|
||||
)
|
||||
|
||||
def _truncate_contexts(
|
||||
self,
|
||||
contexts: list[dict],
|
||||
) -> list[dict]:
|
||||
"""截断上下文列表,确保不超过最大长度"""
|
||||
if self.max_context_length == -1:
|
||||
return contexts
|
||||
|
||||
if len(contexts) // 2 <= self.max_context_length:
|
||||
return contexts
|
||||
|
||||
truncated_contexts = contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(truncated_contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
truncated_contexts = truncated_contexts[index:]
|
||||
|
||||
return truncated_contexts
|
||||
|
||||
def _modalities_fix(
|
||||
self,
|
||||
provider: Provider,
|
||||
@@ -294,6 +282,8 @@ class InternalAgentSubStage(Stage):
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
runner_stats: AgentStats | None,
|
||||
):
|
||||
if (
|
||||
not req
|
||||
@@ -307,222 +297,255 @@ class InternalAgentSubStage(Stage):
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
if req.contexts is None:
|
||||
req.contexts = []
|
||||
# using agent context messages to save to history
|
||||
message_to_save = []
|
||||
for message in all_messages:
|
||||
if message.role == "system":
|
||||
# we do not save system messages to history
|
||||
continue
|
||||
if message.role in ["assistant", "user"] and getattr(
|
||||
message, "_no_save", None
|
||||
):
|
||||
# we do not save user and assistant messages that are marked as _no_save
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
# get token usage from agent runner stats
|
||||
token_usage = None
|
||||
if runner_stats:
|
||||
token_usage = runner_stats.token_usage.total
|
||||
|
||||
# 历史上下文
|
||||
messages = copy.deepcopy(req.contexts)
|
||||
# 这一轮对话请求的用户输入
|
||||
messages.append(await req.assemble_context())
|
||||
# 这一轮对话的 LLM 响应
|
||||
if req.tool_calls_result:
|
||||
if not isinstance(req.tool_calls_result, list):
|
||||
messages.extend(req.tool_calls_result.to_openai_messages())
|
||||
elif isinstance(req.tool_calls_result, list):
|
||||
for tcr in req.tool_calls_result:
|
||||
messages.extend(tcr.to_openai_messages())
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text or "*No response*",
|
||||
}
|
||||
)
|
||||
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=messages,
|
||||
history=message_to_save,
|
||||
token_usage=token_usage,
|
||||
)
|
||||
|
||||
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
"""验证并且修复上下文"""
|
||||
fixed_messages = []
|
||||
for message in messages:
|
||||
if message.get("role") == "tool":
|
||||
# tool block 前面必须要有 user 和 assistant block
|
||||
if len(fixed_messages) < 2:
|
||||
# 这种情况可能是上下文被截断导致的
|
||||
# 我们直接将之前的上下文都清空
|
||||
fixed_messages = []
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
else:
|
||||
fixed_messages.append(message)
|
||||
return fixed_messages
|
||||
def _get_compress_provider(self) -> Provider | None:
|
||||
if not self.llm_compress_provider_id:
|
||||
return None
|
||||
if self.context_limit_reached_strategy != "llm_compress":
|
||||
return None
|
||||
provider = self.ctx.plugin_manager.context.get_provider_by_id(
|
||||
self.llm_compress_provider_id,
|
||||
)
|
||||
if provider is None:
|
||||
logger.warning(
|
||||
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
|
||||
)
|
||||
return None
|
||||
if not isinstance(provider, Provider):
|
||||
logger.warning(
|
||||
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
|
||||
)
|
||||
return None
|
||||
return provider
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent, provider_wake_prefix: str
|
||||
) -> AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
logger.debug("ready to request llm provider")
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
try:
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(
|
||||
f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。"
|
||||
)
|
||||
return
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
logger.debug("ready to request llm provider")
|
||||
|
||||
# 通知等待调用 LLM(在获取锁之前)
|
||||
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
|
||||
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
# truncate contexts to fit max length
|
||||
# NOW moved to ContextManager inside ToolLoopAgentRunner
|
||||
# if req.contexts:
|
||||
# req.contexts = self._truncate_contexts(req.contexts)
|
||||
# self._fix_messages(req.contexts)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
|
||||
# inject model context length limit
|
||||
if provider.provider_config.get("max_context_tokens", 0) <= 0:
|
||||
model = provider.get_model()
|
||||
if model_info := LLM_METADATAS.get(model):
|
||||
provider.provider_config["max_context_tokens"] = model_info[
|
||||
"limit"
|
||||
]["context"]
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
llm_compress_instruction=self.llm_compress_instruction,
|
||||
llm_compress_keep_recent=self.llm_compress_keep_recent,
|
||||
llm_compress_provider=self._get_compress_provider(),
|
||||
truncate_turns=self.dequeue_context_length,
|
||||
enforce_max_turns=self.max_context_length,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
agent_runner.get_final_llm_resp(),
|
||||
agent_runner.run_context.messages,
|
||||
agent_runner.stats,
|
||||
)
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while processing agent: {e}")
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"Error occurred while processing agent request: {e}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -98,6 +98,9 @@ class ResultDecorateStage(Stage):
|
||||
self.content_safe_check_stage = stage_cls()
|
||||
await self.content_safe_check_stage.initialize(ctx)
|
||||
|
||||
provider_cfg = ctx.astrbot_config.get("provider_settings", {})
|
||||
self.show_reasoning = provider_cfg.get("display_reasoning_text", False)
|
||||
|
||||
def _split_text_by_words(self, text: str) -> list[str]:
|
||||
"""使用分段词列表分段文本"""
|
||||
if not self.split_words_pattern:
|
||||
@@ -254,70 +257,75 @@ class ResultDecorateStage(Stage):
|
||||
event.unified_msg_origin,
|
||||
)
|
||||
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
should_tts = (
|
||||
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"])
|
||||
and result.is_llm_result()
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
should_tts = self.tts_trigger_probability >= 1.0 or (
|
||||
self.tts_trigger_probability > 0.0
|
||||
and random.random() <= self.tts_trigger_probability
|
||||
and await SessionServiceManager.should_process_tts_request(event)
|
||||
and random.random() <= self.tts_trigger_probability
|
||||
and tts_provider
|
||||
)
|
||||
if should_tts and not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
|
||||
if not should_tts:
|
||||
logger.debug("跳过 TTS:触发概率未命中。")
|
||||
elif not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
else:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
|
||||
)
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
if (
|
||||
not should_tts
|
||||
and self.show_reasoning
|
||||
and event.get_extra("_llm_reasoning_content")
|
||||
):
|
||||
# inject reasoning content to chain
|
||||
reasoning_content = event.get_extra("_llm_reasoning_content")
|
||||
result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n"))
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path,
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
),
|
||||
if should_tts and tts_provider:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path,
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
),
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
|
||||
# 文本转图片
|
||||
elif (
|
||||
|
||||
@@ -21,7 +21,7 @@ class SessionStatusCheckStage(Stage):
|
||||
event: AstrMessageEvent,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
# 检查会话是否整体启用
|
||||
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
|
||||
|
||||
# workaround for #2309
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
@@ -13,6 +14,23 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
|
||||
UNIQUE_SESSION_ID_BUILDERS: dict[str, Callable[[AstrMessageEvent], str | None]] = {
|
||||
"aiocqhttp": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
|
||||
"slack": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
|
||||
"dingtalk": lambda e: e.get_sender_id(),
|
||||
"qq_official": lambda e: e.get_sender_id(),
|
||||
"qq_official_webhook": lambda e: e.get_sender_id(),
|
||||
"lark": lambda e: f"{e.get_sender_id()}%{e.get_group_id()}",
|
||||
"misskey": lambda e: f"{e.get_session_id()}_{e.get_sender_id()}",
|
||||
"wechatpadpro": lambda e: f"{e.get_group_id()}#{e.get_sender_id()}",
|
||||
}
|
||||
|
||||
|
||||
def build_unique_session_id(event: AstrMessageEvent) -> str | None:
|
||||
platform = event.get_platform_name()
|
||||
builder = UNIQUE_SESSION_ID_BUILDERS.get(platform)
|
||||
return builder(event) if builder else None
|
||||
|
||||
|
||||
@register_stage
|
||||
class WakingCheckStage(Stage):
|
||||
@@ -53,18 +71,27 @@ class WakingCheckStage(Stage):
|
||||
self.disable_builtin_commands = self.ctx.astrbot_config.get(
|
||||
"disable_builtin_commands", False
|
||||
)
|
||||
platform_settings = self.ctx.astrbot_config.get("platform_settings", {})
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
# apply unique session
|
||||
if self.unique_session and event.message_obj.type == MessageType.GROUP_MESSAGE:
|
||||
sid = build_unique_session_id(event)
|
||||
if sid:
|
||||
event.session_id = sid
|
||||
|
||||
# ignore bot self message
|
||||
if (
|
||||
self.ignore_bot_self_message
|
||||
and event.get_self_id() == event.get_sender_id()
|
||||
):
|
||||
# 忽略机器人自己发送的消息
|
||||
event.stop_event()
|
||||
return
|
||||
|
||||
# 设置 sender 身份
|
||||
event.message_str = event.message_str.strip()
|
||||
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
||||
@@ -200,7 +227,7 @@ class WakingCheckStage(Stage):
|
||||
event._extras.pop("parsed_params", None)
|
||||
|
||||
# 根据会话配置过滤插件处理器
|
||||
activated_handlers = SessionPluginManager.filter_handlers_by_session(
|
||||
activated_handlers = await SessionPluginManager.filter_handlers_by_session(
|
||||
event,
|
||||
activated_handlers,
|
||||
)
|
||||
|
||||
@@ -41,7 +41,6 @@ class AiocqhttpAdapter(Platform):
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.host = platform_config["ws_reverse_host"]
|
||||
self.port = platform_config["ws_reverse_port"]
|
||||
|
||||
@@ -136,14 +135,11 @@ class AiocqhttpAdapter(Platform):
|
||||
abm.group_id = str(event.group_id)
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.timestamp = int(time.time())
|
||||
@@ -164,16 +160,11 @@ class AiocqhttpAdapter(Platform):
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = (
|
||||
str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||
) # 也保留群组 id
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.raw_message = event
|
||||
@@ -210,16 +201,11 @@ class AiocqhttpAdapter(Platform):
|
||||
abm.group.group_name = event.get("group_name", "N/A")
|
||||
elif event["message_type"] == "private":
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = (
|
||||
abm.sender.user_id + "_" + str(event.group_id)
|
||||
) # 也保留群组 id
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
|
||||
abm.message_id = str(event.message_id)
|
||||
abm.message = []
|
||||
|
||||
@@ -50,8 +50,6 @@ class DingtalkPlatformAdapter(Platform):
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.client_id = platform_config["client_id"]
|
||||
self.client_secret = platform_config["client_secret"]
|
||||
|
||||
@@ -129,10 +127,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
if id := self._id_to_sid(user.dingtalk_id):
|
||||
abm.message.append(At(qq=id))
|
||||
abm.group_id = message.conversation_id
|
||||
if self.unique_session:
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
abm.session_id = abm.group_id
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
|
||||
@@ -25,6 +25,20 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
client: dingtalk_stream.ChatbotHandler,
|
||||
message: MessageChain,
|
||||
):
|
||||
icm = cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message)
|
||||
ats = []
|
||||
# fixes: #4218
|
||||
# 钉钉 at 机器人需要使用 sender_staff_id 而不是 sender_id
|
||||
for i in message.chain:
|
||||
if isinstance(i, Comp.At):
|
||||
print(i.qq, icm.sender_id, icm.sender_staff_id)
|
||||
if str(i.qq) in str(icm.sender_id or ""):
|
||||
# 适配器会将开头的 $:LWCP_v1:$ 去掉,因此我们用 in 判断
|
||||
ats.append(f"@{icm.sender_staff_id}")
|
||||
else:
|
||||
ats.append(f"@{i.qq}")
|
||||
at_str = " ".join(ats)
|
||||
|
||||
for segment in message.chain:
|
||||
if isinstance(segment, Comp.Plain):
|
||||
segment.text = segment.text.strip()
|
||||
@@ -32,7 +46,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
None,
|
||||
client.reply_markdown,
|
||||
segment.text,
|
||||
segment.text,
|
||||
f"{at_str} {segment.text}".strip(),
|
||||
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
|
||||
@@ -44,8 +44,6 @@ class LarkPlatformAdapter(Platform):
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.appid = platform_config["app_id"]
|
||||
self.appsecret = platform_config["app_secret"]
|
||||
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
|
||||
@@ -317,14 +315,8 @@ class LarkPlatformAdapter(Platform):
|
||||
user_id=event.event.sender.sender_id.open_id,
|
||||
nickname=event.event.sender.sender_id.open_id[:8],
|
||||
)
|
||||
# 独立会话
|
||||
if not self.unique_session:
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
|
||||
@@ -91,8 +91,6 @@ class MisskeyPlatformAdapter(Platform):
|
||||
except Exception:
|
||||
self.max_download_bytes = None
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.api: MisskeyAPI | None = None
|
||||
self._running = False
|
||||
self.client_self_id = ""
|
||||
@@ -641,7 +639,6 @@ class MisskeyPlatformAdapter(Platform):
|
||||
sender_info,
|
||||
self.client_self_id,
|
||||
is_chat=False,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
cache_user_info(
|
||||
self._user_cache,
|
||||
@@ -690,7 +687,6 @@ class MisskeyPlatformAdapter(Platform):
|
||||
sender_info,
|
||||
self.client_self_id,
|
||||
is_chat=True,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
cache_user_info(
|
||||
self._user_cache,
|
||||
@@ -720,7 +716,6 @@ class MisskeyPlatformAdapter(Platform):
|
||||
self.client_self_id,
|
||||
is_chat=False,
|
||||
room_id=room_id,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
|
||||
cache_user_info(
|
||||
|
||||
@@ -338,7 +338,6 @@ def create_base_message(
|
||||
client_self_id: str,
|
||||
is_chat: bool = False,
|
||||
room_id: str | None = None,
|
||||
unique_session: bool = False,
|
||||
) -> AstrBotMessage:
|
||||
"""创建基础消息对象"""
|
||||
message = AstrBotMessage()
|
||||
@@ -353,8 +352,6 @@ def create_base_message(
|
||||
if room_id:
|
||||
session_prefix = "room"
|
||||
session_id = f"{session_prefix}%{room_id}"
|
||||
if unique_session:
|
||||
session_id += f"_{sender_info['sender_id']}"
|
||||
message.type = MessageType.GROUP_MESSAGE
|
||||
message.group_id = room_id
|
||||
elif is_chat:
|
||||
|
||||
@@ -44,11 +44,8 @@ class botClient(Client):
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id
|
||||
if self.platform.unique_session
|
||||
else cast(str, message.group_openid)
|
||||
)
|
||||
abm.group_id = cast(str, message.group_openid)
|
||||
abm.session_id = abm.group_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到频道消息
|
||||
@@ -57,9 +54,8 @@ class botClient(Client):
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id if self.platform.unique_session else message.channel_id
|
||||
)
|
||||
abm.group_id = message.channel_id
|
||||
abm.session_id = abm.group_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到私聊消息
|
||||
@@ -104,7 +100,6 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
|
||||
self.appid = platform_config["appid"]
|
||||
self.secret = platform_config["secret"]
|
||||
self.unique_session: bool = platform_settings["unique_session"]
|
||||
qq_group = platform_config["enable_group_c2c"]
|
||||
guild_dm = platform_config["enable_guild_direct_message"]
|
||||
|
||||
|
||||
@@ -35,11 +35,8 @@ class botClient(Client):
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id
|
||||
if self.platform.unique_session
|
||||
else cast(str, message.group_openid)
|
||||
)
|
||||
abm.group_id = cast(str, message.group_openid)
|
||||
abm.session_id = abm.group_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到频道消息
|
||||
@@ -48,9 +45,8 @@ class botClient(Client):
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id if self.platform.unique_session else message.channel_id
|
||||
)
|
||||
abm.group_id = message.channel_id
|
||||
abm.session_id = abm.group_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到私聊消息
|
||||
@@ -95,7 +91,6 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
|
||||
self.appid = platform_config["appid"]
|
||||
self.secret = platform_config["secret"]
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
|
||||
|
||||
intents = botpy.Intents(
|
||||
|
||||
@@ -142,7 +142,12 @@ class SatoriPlatformAdapter(Platform):
|
||||
raise ValueError(f"WebSocket URL必须以ws://或wss://开头: {self.endpoint}")
|
||||
|
||||
try:
|
||||
websocket = await connect(self.endpoint, additional_headers={})
|
||||
websocket = await connect(
|
||||
self.endpoint,
|
||||
additional_headers={},
|
||||
max_size=10 * 1024 * 1024, # 10MB
|
||||
)
|
||||
|
||||
self.ws = websocket
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -41,7 +41,6 @@ class SlackAdapter(Platform):
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
self.bot_token = platform_config.get("bot_token")
|
||||
self.app_token = platform_config.get("app_token")
|
||||
@@ -147,12 +146,10 @@ class SlackAdapter(Platform):
|
||||
abm.group_id = channel_id
|
||||
|
||||
# 设置会话ID
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = f"{user_id}_{channel_id}"
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = (
|
||||
channel_id if abm.type == MessageType.GROUP_MESSAGE else user_id
|
||||
)
|
||||
abm.session_id = user_id
|
||||
|
||||
abm.message_id = event.get("client_msg_id", uuid.uuid4().hex)
|
||||
abm.timestamp = int(float(event.get("ts", time.time())))
|
||||
|
||||
@@ -79,7 +79,6 @@ class WebChatAdapter(Platform):
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
|
||||
@@ -47,7 +47,6 @@ class WeChatPadProAdapter(Platform):
|
||||
self._shutdown_event = None
|
||||
self.wxnewpass = None
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="wechatpadpro",
|
||||
@@ -509,11 +508,10 @@ class WeChatPadProAdapter(Platform):
|
||||
if accurate_nickname:
|
||||
abm.sender.nickname = accurate_nickname
|
||||
|
||||
# 对于群聊,session_id 可以是群聊 ID 或发送者 ID + 群聊 ID (如果 unique_session 为 True)
|
||||
if self.unique_session:
|
||||
abm.session_id = f"{from_user_name}#{abm.sender.user_id}"
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = from_user_name
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
msg_source = raw_message.get("msg_source", "")
|
||||
if self.wxid in msg_source:
|
||||
|
||||
@@ -191,7 +191,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
||||
if self.active_send_mode:
|
||||
await self.convert_message(msg, None)
|
||||
else:
|
||||
if msg.id in self.wexin_event_workers:
|
||||
if str(msg.id) in self.wexin_event_workers:
|
||||
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
|
||||
logger.debug(f"duplicate message id checked: {msg.id}")
|
||||
else:
|
||||
|
||||
@@ -272,6 +272,8 @@ class LLMResponse:
|
||||
"""Tool call extra content. tool_call_id -> extra_content dict"""
|
||||
reasoning_content: str = ""
|
||||
"""The reasoning content extracted from the LLM, if any."""
|
||||
reasoning_signature: str | None = None
|
||||
"""The signature of the reasoning content, if any."""
|
||||
|
||||
raw_completion: (
|
||||
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
|
||||
@@ -292,12 +294,14 @@ class LLMResponse:
|
||||
def __init__(
|
||||
self,
|
||||
role: str,
|
||||
completion_text: str = "",
|
||||
completion_text: str | None = None,
|
||||
result_chain: MessageChain | None = None,
|
||||
tools_call_args: list[dict[str, Any]] | None = None,
|
||||
tools_call_name: list[str] | None = None,
|
||||
tools_call_ids: list[str] | None = None,
|
||||
tools_call_extra_content: dict[str, dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
reasoning_signature: str | None = None,
|
||||
raw_completion: ChatCompletion
|
||||
| GenerateContentResponse
|
||||
| AnthropicMessage
|
||||
@@ -317,6 +321,8 @@ class LLMResponse:
|
||||
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
|
||||
|
||||
"""
|
||||
if reasoning_content is None:
|
||||
reasoning_content = ""
|
||||
if tools_call_args is None:
|
||||
tools_call_args = []
|
||||
if tools_call_name is None:
|
||||
@@ -333,6 +339,8 @@ class LLMResponse:
|
||||
self.tools_call_name = tools_call_name
|
||||
self.tools_call_ids = tools_call_ids
|
||||
self.tools_call_extra_content = tools_call_extra_content
|
||||
self.reasoning_content = reasoning_content
|
||||
self.reasoning_signature = reasoning_signature
|
||||
self.raw_completion = raw_completion
|
||||
self.is_chunk = is_chunk
|
||||
|
||||
|
||||
@@ -119,19 +119,34 @@ class ProviderManager:
|
||||
TTSProvider,
|
||||
):
|
||||
self.curr_tts_provider_inst = prov
|
||||
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
||||
await sp.put_async(
|
||||
key="curr_provider_tts",
|
||||
value=provider_id,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
||||
prov,
|
||||
STTProvider,
|
||||
):
|
||||
self.curr_stt_provider_inst = prov
|
||||
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
||||
await sp.put_async(
|
||||
key="curr_provider_stt",
|
||||
value=provider_id,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
||||
prov,
|
||||
Provider,
|
||||
):
|
||||
self.curr_provider_inst = prov
|
||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||
await sp.put_async(
|
||||
key="curr_provider",
|
||||
value=provider_id,
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
|
||||
"""根据提供商 ID 获取提供商实例"""
|
||||
@@ -206,21 +221,21 @@ class ProviderManager:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(e)
|
||||
|
||||
selected_provider_id = sp.get(
|
||||
"curr_provider",
|
||||
self.provider_settings.get("default_provider_id"),
|
||||
selected_provider_id = await sp.get_async(
|
||||
key="curr_provider",
|
||||
default=self.provider_settings.get("default_provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
selected_stt_provider_id = sp.get(
|
||||
"curr_provider_stt",
|
||||
self.provider_stt_settings.get("provider_id"),
|
||||
selected_stt_provider_id = await sp.get_async(
|
||||
key="curr_provider_stt",
|
||||
default=self.provider_stt_settings.get("provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
selected_tts_provider_id = sp.get(
|
||||
"curr_provider_tts",
|
||||
self.provider_tts_settings.get("provider_id"),
|
||||
selected_tts_provider_id = await sp.get_async(
|
||||
key="curr_provider_tts",
|
||||
default=self.provider_tts_settings.get("provider_id"),
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
|
||||
@@ -48,6 +48,8 @@ class ProviderAnthropic(Provider):
|
||||
base_url=self.base_url,
|
||||
)
|
||||
|
||||
self.thinking_config = provider_config.get("anth_thinking_config", {})
|
||||
|
||||
self.set_model(provider_config.get("model", "unknown"))
|
||||
|
||||
def _prepare_payload(self, messages: list[dict]):
|
||||
@@ -64,11 +66,32 @@ class ProviderAnthropic(Provider):
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
system_prompt = message["content"]
|
||||
system_prompt = message["content"] or "<empty system prompt>"
|
||||
elif message["role"] == "assistant":
|
||||
blocks = []
|
||||
if isinstance(message["content"], str):
|
||||
reasoning_content = ""
|
||||
thinking_signature = ""
|
||||
if isinstance(message["content"], str) and message["content"].strip():
|
||||
blocks.append({"type": "text", "text": message["content"]})
|
||||
elif isinstance(message["content"], list):
|
||||
for part in message["content"]:
|
||||
if part.get("type") == "think":
|
||||
# only pick the last think part for now
|
||||
reasoning_content = part.get("think")
|
||||
thinking_signature = part.get("encrypted")
|
||||
else:
|
||||
blocks.append(part)
|
||||
|
||||
if reasoning_content and thinking_signature:
|
||||
blocks.insert(
|
||||
0,
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": reasoning_content,
|
||||
"signature": thinking_signature,
|
||||
},
|
||||
)
|
||||
|
||||
if "tool_calls" in message and isinstance(message["tool_calls"], list):
|
||||
for tool_call in message["tool_calls"]:
|
||||
blocks.append( # noqa: PERF401
|
||||
@@ -100,7 +123,7 @@ class ProviderAnthropic(Provider):
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
"content": message["content"] or "<empty response>",
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -135,6 +158,11 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
if "max_tokens" not in payloads:
|
||||
payloads["max_tokens"] = 1024
|
||||
if self.thinking_config.get("budget"):
|
||||
payloads["thinking"] = {
|
||||
"budget_tokens": self.thinking_config.get("budget"),
|
||||
"type": "enabled",
|
||||
}
|
||||
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
@@ -153,6 +181,11 @@ class ProviderAnthropic(Provider):
|
||||
completion_text = str(content_block.text).strip()
|
||||
llm_response.completion_text = completion_text
|
||||
|
||||
if content_block.type == "thinking":
|
||||
reasoning_content = str(content_block.thinking).strip()
|
||||
llm_response.reasoning_content = reasoning_content
|
||||
llm_response.reasoning_signature = content_block.signature
|
||||
|
||||
if content_block.type == "tool_use":
|
||||
llm_response.tools_call_args.append(content_block.input)
|
||||
llm_response.tools_call_name.append(content_block.name)
|
||||
@@ -184,9 +217,16 @@ class ProviderAnthropic(Provider):
|
||||
id = None
|
||||
usage = TokenUsage()
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
reasoning_content = ""
|
||||
reasoning_signature = ""
|
||||
|
||||
if "max_tokens" not in payloads:
|
||||
payloads["max_tokens"] = 1024
|
||||
if self.thinking_config.get("budget"):
|
||||
payloads["thinking"] = {
|
||||
"budget_tokens": self.thinking_config.get("budget"),
|
||||
"type": "enabled",
|
||||
}
|
||||
|
||||
async with self.client.messages.stream(
|
||||
**payloads, extra_body=extra_body
|
||||
@@ -226,6 +266,21 @@ class ProviderAnthropic(Provider):
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
elif event.delta.type == "thinking_delta":
|
||||
# 思考增量
|
||||
reasoning = event.delta.thinking
|
||||
if reasoning:
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
reasoning_content=reasoning,
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
reasoning_signature=reasoning_signature or None,
|
||||
)
|
||||
reasoning_content += reasoning
|
||||
elif event.delta.type == "signature_delta":
|
||||
reasoning_signature = event.delta.signature
|
||||
elif event.delta.type == "input_json_delta":
|
||||
# 工具调用参数增量
|
||||
if event.index in tool_use_buffer:
|
||||
@@ -282,6 +337,8 @@ class ProviderAnthropic(Provider):
|
||||
is_chunk=False,
|
||||
usage=usage,
|
||||
id=id,
|
||||
reasoning_content=reasoning_content,
|
||||
reasoning_signature=reasoning_signature or None,
|
||||
)
|
||||
|
||||
if final_tool_calls:
|
||||
|
||||
@@ -321,9 +321,37 @@ class ProviderGoogleGenAI(Provider):
|
||||
append_or_extend(gemini_contents, parts, types.UserContent)
|
||||
|
||||
elif role == "assistant":
|
||||
if content:
|
||||
if isinstance(content, str):
|
||||
parts = [types.Part.from_text(text=content)]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
elif isinstance(content, list):
|
||||
parts = []
|
||||
thinking_signature = None
|
||||
text = ""
|
||||
for part in content:
|
||||
# for most cases, assistant content only contains two parts: think and text
|
||||
if part.get("type") == "think":
|
||||
thinking_signature = part.get("encrypted") or None
|
||||
else:
|
||||
text += str(part.get("text"))
|
||||
|
||||
if thinking_signature and isinstance(thinking_signature, str):
|
||||
try:
|
||||
thinking_signature = base64.b64decode(thinking_signature)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to decode google gemini thinking signature: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
thinking_signature = None
|
||||
parts.append(
|
||||
types.Part(
|
||||
text=text,
|
||||
thought_signature=thinking_signature,
|
||||
)
|
||||
)
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
|
||||
elif not native_tool_enabled and "tool_calls" in message:
|
||||
parts = []
|
||||
for tool in message["tool_calls"]:
|
||||
@@ -441,7 +469,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
for part in result_parts:
|
||||
if part.text:
|
||||
chain.append(Comp.Plain(part.text))
|
||||
elif (
|
||||
|
||||
if (
|
||||
part.function_call
|
||||
and part.function_call.name is not None
|
||||
and part.function_call.args is not None
|
||||
@@ -458,13 +487,18 @@ class ProviderGoogleGenAI(Provider):
|
||||
llm_response.tools_call_extra_content[tool_call_id] = {
|
||||
"google": {"thought_signature": ts_bs64}
|
||||
}
|
||||
elif (
|
||||
|
||||
if (
|
||||
part.inline_data
|
||||
and part.inline_data.mime_type
|
||||
and part.inline_data.mime_type.startswith("image/")
|
||||
and part.inline_data.data
|
||||
):
|
||||
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
||||
|
||||
if ts := part.thought_signature:
|
||||
# only keep the last thinking signature
|
||||
llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8")
|
||||
return MessageChain(chain=chain)
|
||||
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
|
||||
@@ -51,7 +51,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
"voice_id": ""
|
||||
if self.is_timber_weight
|
||||
else provider_config.get("minimax-voice-id", ""),
|
||||
"emotion": provider_config.get("minimax-voice-emotion", "neutral"),
|
||||
"emotion": provider_config.get("minimax-voice-emotion", "auto"),
|
||||
"latex_read": provider_config.get("minimax-voice-latex", False),
|
||||
"english_normalization": provider_config.get(
|
||||
"minimax-voice-english-normalization",
|
||||
@@ -59,6 +59,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
),
|
||||
}
|
||||
|
||||
if self.voice_setting["emotion"] == "auto":
|
||||
self.voice_setting.pop("emotion", None)
|
||||
|
||||
self.audio_setting: dict = {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
|
||||
@@ -74,28 +74,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
self.reasoning_key = "reasoning_content"
|
||||
|
||||
def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
|
||||
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
|
||||
|
||||
- 仅在 provider_config.xai_native_search 为 True 时生效
|
||||
- 默认注入 {"mode": "auto"}
|
||||
- 允许通过 kwargs 使用 xai_search_mode 覆盖(on/auto/off)
|
||||
"""
|
||||
if not bool(self.provider_config.get("xai_native_search", False)):
|
||||
return
|
||||
|
||||
mode = kwargs.get("xai_search_mode", "auto")
|
||||
mode = str(mode).lower()
|
||||
if mode not in ("auto", "on", "off"):
|
||||
mode = "auto"
|
||||
|
||||
# off 时不注入,保持与未开启一致
|
||||
if mode == "off":
|
||||
return
|
||||
|
||||
# OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body
|
||||
payloads["search_parameters"] = {"mode": mode}
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
models_str = []
|
||||
@@ -134,10 +112,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
model = payloads.get("model", "").lower()
|
||||
|
||||
# 针对 deepseek 模型的特殊处理:deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
|
||||
if model == "deepseek-reasoner" and "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False,
|
||||
@@ -251,10 +225,14 @@ class ProviderOpenAIOfficial(Provider):
|
||||
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
|
||||
ptd = usage.prompt_tokens_details
|
||||
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
|
||||
prompt_tokens = 0 if usage.prompt_tokens is None else usage.prompt_tokens
|
||||
completion_tokens = (
|
||||
0 if usage.completion_tokens is None else usage.completion_tokens
|
||||
)
|
||||
return TokenUsage(
|
||||
input_other=usage.prompt_tokens - cached,
|
||||
input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0,
|
||||
output=usage.completion_tokens,
|
||||
input_other=prompt_tokens - cached,
|
||||
input_cached=cached,
|
||||
output=completion_tokens,
|
||||
)
|
||||
|
||||
async def _parse_openai_completion(
|
||||
@@ -381,11 +359,28 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
# xAI origin search tool inject
|
||||
self._maybe_inject_xai_search(payloads, **kwargs)
|
||||
self._finally_convert_payload(payloads)
|
||||
|
||||
return payloads, context_query
|
||||
|
||||
def _finally_convert_payload(self, payloads: dict):
|
||||
"""Finally convert the payload. Such as think part conversion, tool inject."""
|
||||
for message in payloads.get("messages", []):
|
||||
if message.get("role") == "assistant" and isinstance(
|
||||
message.get("content"), list
|
||||
):
|
||||
reasoning_content = ""
|
||||
new_content = [] # not including think part
|
||||
for part in message["content"]:
|
||||
if part.get("type") == "think":
|
||||
reasoning_content += str(part.get("think"))
|
||||
else:
|
||||
new_content.append(part)
|
||||
message["content"] = new_content
|
||||
# reasoning key is "reasoning_content"
|
||||
if reasoning_content:
|
||||
message["reasoning_content"] = reasoning_content
|
||||
|
||||
async def _handle_api_error(
|
||||
self,
|
||||
e: Exception,
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
from ..register import register_provider_adapter
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"xai_chat_completion", "xAI Chat Completion Provider Adapter"
|
||||
)
|
||||
class ProviderXAI(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
|
||||
def _maybe_inject_xai_search(self, payloads: dict):
|
||||
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
|
||||
|
||||
- 仅在 provider_config.xai_native_search 为 True 时生效
|
||||
- 默认注入 {"mode": "auto"}
|
||||
"""
|
||||
if not bool(self.provider_config.get("xai_native_search", False)):
|
||||
return
|
||||
# OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body
|
||||
payloads["search_parameters"] = {"mode": "auto"}
|
||||
|
||||
def _finally_convert_payload(self, payloads: dict):
|
||||
self._maybe_inject_xai_search(payloads)
|
||||
super()._finally_convert_payload(payloads)
|
||||
@@ -8,7 +8,10 @@ from xinference_client.client.restful.async_restful_client import (
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.tencent_record_helper import (
|
||||
convert_to_pcm_wav,
|
||||
tencent_silk_to_wav,
|
||||
)
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import STTProvider
|
||||
@@ -111,17 +114,22 @@ class ProviderXinferenceSTT(STTProvider):
|
||||
return ""
|
||||
|
||||
# 2. Check for conversion
|
||||
needs_conversion = False
|
||||
if (
|
||||
audio_url.endswith((".amr", ".silk"))
|
||||
or is_tencent
|
||||
or b"SILK" in audio_bytes[:8]
|
||||
):
|
||||
needs_conversion = True
|
||||
conversion_type = None
|
||||
|
||||
if b"SILK" in audio_bytes[:8]:
|
||||
conversion_type = "silk"
|
||||
elif b"#!AMR" in audio_bytes[:6]:
|
||||
conversion_type = "amr"
|
||||
elif audio_url.endswith(".silk") or is_tencent:
|
||||
conversion_type = "silk"
|
||||
elif audio_url.endswith(".amr"):
|
||||
conversion_type = "amr"
|
||||
|
||||
# 3. Perform conversion if needed
|
||||
if needs_conversion:
|
||||
logger.info("Audio requires conversion, using temporary files...")
|
||||
if conversion_type:
|
||||
logger.info(
|
||||
f"Audio requires conversion ({conversion_type}), using temporary files..."
|
||||
)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
@@ -132,8 +140,12 @@ class ProviderXinferenceSTT(STTProvider):
|
||||
with open(input_path, "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
logger.info("Converting silk/amr file to wav ...")
|
||||
await tencent_silk_to_wav(input_path, output_path)
|
||||
if conversion_type == "silk":
|
||||
logger.info("Converting silk to wav ...")
|
||||
await tencent_silk_to_wav(input_path, output_path)
|
||||
elif conversion_type == "amr":
|
||||
logger.info("Converting amr to wav ...")
|
||||
await convert_to_pcm_wav(input_path, output_path)
|
||||
|
||||
with open(output_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
|
||||
@@ -149,9 +149,12 @@ class Context:
|
||||
contexts: context messages for the LLM
|
||||
max_steps: Maximum number of tool calls before stopping the loop
|
||||
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
|
||||
stream: bool - whether to stream the LLM response
|
||||
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
|
||||
agent_context: AstrAgentContext - context to use for the agent
|
||||
|
||||
other kwargs will be DIRECTLY passed to the runner.reset() method
|
||||
|
||||
Returns:
|
||||
The final LLMResponse after tool calls are completed.
|
||||
|
||||
@@ -194,6 +197,15 @@ class Context:
|
||||
)
|
||||
agent_runner = ToolLoopAgentRunner()
|
||||
tool_executor = FunctionToolExecutor()
|
||||
|
||||
streaming = kwargs.get("stream", False)
|
||||
|
||||
other_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
if k not in ["stream", "agent_hooks", "agent_context"]
|
||||
}
|
||||
|
||||
await agent_runner.reset(
|
||||
provider=prov,
|
||||
request=request,
|
||||
@@ -203,7 +215,8 @@ class Context:
|
||||
),
|
||||
tool_executor=tool_executor,
|
||||
agent_hooks=agent_hooks,
|
||||
streaming=kwargs.get("stream", False),
|
||||
streaming=streaming,
|
||||
**other_kwargs,
|
||||
)
|
||||
async for _ in agent_runner.step_until_done(max_steps):
|
||||
pass
|
||||
|
||||
@@ -12,6 +12,7 @@ from .star_handler import (
|
||||
register_on_llm_request,
|
||||
register_on_llm_response,
|
||||
register_on_platform_loaded,
|
||||
register_on_waiting_llm_request,
|
||||
register_permission_type,
|
||||
register_platform_adapter_type,
|
||||
register_regex,
|
||||
@@ -30,6 +31,7 @@ __all__ = [
|
||||
"register_on_llm_request",
|
||||
"register_on_llm_response",
|
||||
"register_on_platform_loaded",
|
||||
"register_on_waiting_llm_request",
|
||||
"register_permission_type",
|
||||
"register_platform_adapter_type",
|
||||
"register_regex",
|
||||
|
||||
@@ -339,6 +339,30 @@ def register_on_platform_loaded(**kwargs):
|
||||
return decorator
|
||||
|
||||
|
||||
def register_on_waiting_llm_request(**kwargs):
|
||||
"""当等待调用 LLM 时的通知事件(在获取锁之前)
|
||||
|
||||
此钩子在消息确定要调用 LLM 但还未开始排队等锁时触发,
|
||||
适合用于发送"正在思考中..."等用户反馈提示。
|
||||
|
||||
Examples:
|
||||
```py
|
||||
@on_waiting_llm_request()
|
||||
async def on_waiting_llm(self, event: AstrMessageEvent) -> None:
|
||||
await event.send("🤔 正在思考中...")
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(
|
||||
awaitable, EventType.OnWaitingLLMRequestEvent, **kwargs
|
||||
)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def register_on_llm_request(**kwargs):
|
||||
"""当有 LLM 请求时的事件
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class SessionServiceManager:
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_llm_enabled_for_session(session_id: str) -> bool:
|
||||
async def is_llm_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查LLM是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
@@ -23,11 +23,11 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_services = sp.get(
|
||||
"session_service_config",
|
||||
{},
|
||||
session_services = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
|
||||
# 如果配置了该会话的LLM状态,返回该状态
|
||||
@@ -39,7 +39,7 @@ class SessionServiceManager:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
async def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置LLM在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
@@ -48,18 +48,24 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
or {}
|
||||
)
|
||||
session_config["llm_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config",
|
||||
session_config,
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
value=session_config,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
async def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理LLM请求
|
||||
|
||||
Args:
|
||||
@@ -70,14 +76,14 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_llm_enabled_for_session(session_id)
|
||||
return await SessionServiceManager.is_llm_enabled_for_session(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# TTS 相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_tts_enabled_for_session(session_id: str) -> bool:
|
||||
async def is_tts_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查TTS是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
@@ -88,11 +94,11 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_services = sp.get(
|
||||
"session_service_config",
|
||||
{},
|
||||
session_services = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
|
||||
# 如果配置了该会话的TTS状态,返回该状态
|
||||
@@ -104,7 +110,7 @@ class SessionServiceManager:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
async def set_tts_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置TTS在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
@@ -113,14 +119,20 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
session_config = (
|
||||
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
|
||||
await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
or {}
|
||||
)
|
||||
session_config["tts_enabled"] = enabled
|
||||
sp.put(
|
||||
"session_service_config",
|
||||
session_config,
|
||||
await sp.put_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
value=session_config,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -128,7 +140,7 @@ class SessionServiceManager:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_process_tts_request(event: AstrMessageEvent) -> bool:
|
||||
async def should_process_tts_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理TTS请求
|
||||
|
||||
Args:
|
||||
@@ -139,14 +151,14 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
return await SessionServiceManager.is_tts_enabled_for_session(session_id)
|
||||
|
||||
# =============================================================================
|
||||
# 会话整体启停相关方法
|
||||
# =============================================================================
|
||||
|
||||
@staticmethod
|
||||
def is_session_enabled(session_id: str) -> bool:
|
||||
async def is_session_enabled(session_id: str) -> bool:
|
||||
"""检查会话是否整体启用
|
||||
|
||||
Args:
|
||||
@@ -157,11 +169,11 @@ class SessionServiceManager:
|
||||
|
||||
"""
|
||||
# 获取会话服务配置
|
||||
session_services = sp.get(
|
||||
"session_service_config",
|
||||
{},
|
||||
session_services = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_service_config",
|
||||
default={},
|
||||
)
|
||||
|
||||
# 如果配置了该会话的整体状态,返回该状态
|
||||
|
||||
@@ -8,7 +8,10 @@ class SessionPluginManager:
|
||||
"""管理会话级别的插件启停状态"""
|
||||
|
||||
@staticmethod
|
||||
def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
|
||||
async def is_plugin_enabled_for_session(
|
||||
session_id: str,
|
||||
plugin_name: str,
|
||||
) -> bool:
|
||||
"""检查插件是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
@@ -20,11 +23,11 @@ class SessionPluginManager:
|
||||
|
||||
"""
|
||||
# 获取会话插件配置
|
||||
session_plugin_config = sp.get(
|
||||
"session_plugin_config",
|
||||
{},
|
||||
session_plugin_config = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_plugin_config",
|
||||
default={},
|
||||
)
|
||||
session_config = session_plugin_config.get(session_id, {})
|
||||
|
||||
@@ -43,7 +46,10 @@ class SessionPluginManager:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list:
|
||||
async def filter_handlers_by_session(
|
||||
event: AstrMessageEvent,
|
||||
handlers: list,
|
||||
) -> list:
|
||||
"""根据会话配置过滤处理器列表
|
||||
|
||||
Args:
|
||||
@@ -59,6 +65,15 @@ class SessionPluginManager:
|
||||
session_id = event.unified_msg_origin
|
||||
filtered_handlers = []
|
||||
|
||||
session_plugin_config = await sp.get_async(
|
||||
scope="umo",
|
||||
scope_id=session_id,
|
||||
key="session_plugin_config",
|
||||
default={},
|
||||
)
|
||||
session_config = session_plugin_config.get(session_id, {})
|
||||
disabled_plugins = session_config.get("disabled_plugins", [])
|
||||
|
||||
for handler in handlers:
|
||||
# 获取处理器对应的插件
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
@@ -76,14 +91,11 @@ class SessionPluginManager:
|
||||
continue
|
||||
|
||||
# 检查插件是否在当前会话中启用
|
||||
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||
session_id,
|
||||
plugin.name,
|
||||
):
|
||||
filtered_handlers.append(handler)
|
||||
else:
|
||||
if plugin.name in disabled_plugins:
|
||||
logger.debug(
|
||||
f"插件 {plugin.name} 在会话 {session_id} 中被禁用,跳过处理器 {handler.handler_name}",
|
||||
)
|
||||
else:
|
||||
filtered_handlers.append(handler)
|
||||
|
||||
return filtered_handlers
|
||||
|
||||
@@ -184,6 +184,7 @@ class EventType(enum.Enum):
|
||||
OnPlatformLoadedEvent = enum.auto() # 平台加载完成
|
||||
|
||||
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
||||
OnWaitingLLMRequestEvent = enum.auto() # 等待调用 LLM(在获取锁之前,仅通知)
|
||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||
OnLLMResponseEvent = enum.auto() # LLM 响应后
|
||||
OnDecoratingResultEvent = enum.auto() # 发送消息前
|
||||
|
||||
@@ -944,8 +944,49 @@ class PluginManager:
|
||||
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
|
||||
dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()
|
||||
desti_dir = os.path.join(self.plugin_store_path, dir_name)
|
||||
|
||||
# 第一步:检查是否已安装同目录名的插件,先终止旧插件
|
||||
existing_plugin = None
|
||||
for star in self.context.get_all_stars():
|
||||
if star.root_dir_name == dir_name:
|
||||
existing_plugin = star
|
||||
break
|
||||
|
||||
if existing_plugin:
|
||||
logger.info(f"检测到插件 {existing_plugin.name} 已安装,正在终止旧插件...")
|
||||
try:
|
||||
await self._terminate_plugin(existing_plugin)
|
||||
except Exception:
|
||||
logger.warning(traceback.format_exc())
|
||||
if existing_plugin.name and existing_plugin.module_path:
|
||||
await self._unbind_plugin(
|
||||
existing_plugin.name, existing_plugin.module_path
|
||||
)
|
||||
|
||||
self.updator.unzip_file(zip_file_path, desti_dir)
|
||||
|
||||
# 第二步:解压后,读取新插件的 metadata.yaml,检查是否存在同名但不同目录的插件
|
||||
try:
|
||||
new_metadata = self._load_plugin_metadata(desti_dir)
|
||||
if new_metadata and new_metadata.name:
|
||||
for star in self.context.get_all_stars():
|
||||
if (
|
||||
star.name == new_metadata.name
|
||||
and star.root_dir_name != dir_name
|
||||
):
|
||||
logger.warning(
|
||||
f"检测到同名插件 {star.name} 存在于不同目录 {star.root_dir_name},正在终止..."
|
||||
)
|
||||
try:
|
||||
await self._terminate_plugin(star)
|
||||
except Exception:
|
||||
logger.warning(traceback.format_exc())
|
||||
if star.name and star.module_path:
|
||||
await self._unbind_plugin(star.name, star.module_path)
|
||||
break # 只处理第一个匹配的
|
||||
except Exception as e:
|
||||
logger.debug(f"读取新插件 metadata.yaml 失败,跳过同名检查: {e!s}")
|
||||
|
||||
# remove the zip
|
||||
try:
|
||||
os.remove(zip_file_path)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import fnmatch
|
||||
|
||||
from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
|
||||
|
||||
@@ -9,14 +11,15 @@ class UmopConfigRouter:
|
||||
"""UMOP 到配置文件 ID 的映射"""
|
||||
self.sp = sp
|
||||
|
||||
self._load_routing_table()
|
||||
async def initialize(self):
|
||||
await self._load_routing_table()
|
||||
|
||||
def _load_routing_table(self):
|
||||
async def _load_routing_table(self):
|
||||
"""加载路由表"""
|
||||
# 从 SharedPreferences 中加载 umop_to_conf_id 映射
|
||||
sp_data = self.sp.get(
|
||||
"umop_config_routing",
|
||||
{},
|
||||
sp_data = await self.sp.get_async(
|
||||
key="umop_config_routing",
|
||||
default={},
|
||||
scope="global",
|
||||
scope_id="global",
|
||||
)
|
||||
@@ -30,7 +33,7 @@ class UmopConfigRouter:
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
return all(p == "" or fnmatch.fnmatchcase(t, p) for p, t in zip(p1_ls, p2_ls))
|
||||
|
||||
def get_conf_id_for_umop(self, umo: str) -> str | None:
|
||||
"""根据 UMO 获取对应的配置文件 ID
|
||||
|
||||
@@ -3,6 +3,7 @@ import traceback
|
||||
from astrbot.core import astrbot_config, logger
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager
|
||||
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
|
||||
from astrbot.core.db.migration.migra_token_usage import migrate_token_usage
|
||||
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
|
||||
|
||||
|
||||
@@ -139,6 +140,13 @@ async def migra(
|
||||
logger.error(f"Migration for webchat session failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migration for token_usage column
|
||||
try:
|
||||
await migrate_token_usage(db)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for token_usage column failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# migra third party agent runner configs
|
||||
_c = False
|
||||
providers = astrbot_config["provider"]
|
||||
|
||||
@@ -1,10 +1,29 @@
|
||||
import asyncio
|
||||
import locale
|
||||
import logging
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
def _robust_decode(line: bytes) -> str:
|
||||
"""解码字节流,兼容不同平台的编码"""
|
||||
try:
|
||||
return line.decode("utf-8").strip()
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
try:
|
||||
return line.decode(locale.getpreferredencoding(False)).strip()
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
if sys.platform.startswith("win"):
|
||||
try:
|
||||
return line.decode("gbk").strip()
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
return line.decode("utf-8", errors="replace").strip()
|
||||
|
||||
|
||||
class PipInstaller:
|
||||
def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None):
|
||||
self.pip_install_arg = pip_install_arg
|
||||
@@ -42,7 +61,7 @@ class PipInstaller:
|
||||
|
||||
assert process.stdout is not None
|
||||
async for line in process.stdout:
|
||||
logger.info(line.decode().strip())
|
||||
logger.info(_robust_decode(line))
|
||||
|
||||
await process.wait()
|
||||
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
"""备份管理 API 路由"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import jwt
|
||||
from quart import request, send_file
|
||||
|
||||
from astrbot.core import logger
|
||||
@@ -22,6 +27,10 @@ from astrbot.core.utils.astrbot_path import (
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
# 分片上传常量
|
||||
CHUNK_SIZE = 1024 * 1024 # 1MB
|
||||
UPLOAD_EXPIRE_SECONDS = 3600 # 上传会话过期时间(1小时)
|
||||
|
||||
|
||||
def secure_filename(filename: str) -> str:
|
||||
"""清洗文件名,移除路径遍历字符和危险字符
|
||||
@@ -54,17 +63,17 @@ def secure_filename(filename: str) -> str:
|
||||
|
||||
|
||||
def generate_unique_filename(original_filename: str) -> str:
|
||||
"""生成唯一的文件名,添加时间戳前缀
|
||||
"""生成唯一的文件名,在原文件名后添加时间戳后缀避免重名
|
||||
|
||||
Args:
|
||||
original_filename: 原始文件名(已清洗)
|
||||
|
||||
Returns:
|
||||
唯一的文件名
|
||||
添加了时间戳后缀的唯一文件名,格式为 {原文件名}_{YYYYMMDD_HHMMSS}.{扩展名}
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
name, ext = os.path.splitext(original_filename)
|
||||
return f"uploaded_{timestamp}_{name}{ext}"
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return f"{name}_{timestamp}{ext}"
|
||||
|
||||
|
||||
class BackupRoute(Route):
|
||||
@@ -84,21 +93,34 @@ class BackupRoute(Route):
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.backup_dir = get_astrbot_backups_path()
|
||||
self.data_dir = get_astrbot_data_path()
|
||||
self.chunks_dir = os.path.join(self.backup_dir, ".chunks")
|
||||
|
||||
# 任务状态跟踪
|
||||
self.backup_tasks: dict[str, dict] = {}
|
||||
self.backup_progress: dict[str, dict] = {}
|
||||
|
||||
# 分片上传会话跟踪
|
||||
# upload_id -> {filename, total_chunks, received_chunks, last_activity, chunk_dir}
|
||||
self.upload_sessions: dict[str, dict] = {}
|
||||
|
||||
# 后台清理任务句柄
|
||||
self._cleanup_task: asyncio.Task | None = None
|
||||
|
||||
# 注册路由
|
||||
self.routes = {
|
||||
"/backup/list": ("GET", self.list_backups),
|
||||
"/backup/export": ("POST", self.export_backup),
|
||||
"/backup/upload": ("POST", self.upload_backup), # 上传文件
|
||||
"/backup/upload": ("POST", self.upload_backup), # 上传文件(兼容小文件)
|
||||
"/backup/upload/init": ("POST", self.upload_init), # 分片上传初始化
|
||||
"/backup/upload/chunk": ("POST", self.upload_chunk), # 上传分片
|
||||
"/backup/upload/complete": ("POST", self.upload_complete), # 完成分片上传
|
||||
"/backup/upload/abort": ("POST", self.upload_abort), # 取消上传
|
||||
"/backup/check": ("POST", self.check_backup), # 预检查
|
||||
"/backup/import": ("POST", self.import_backup), # 确认导入
|
||||
"/backup/progress": ("GET", self.get_progress),
|
||||
"/backup/download": ("GET", self.download_backup),
|
||||
"/backup/delete": ("POST", self.delete_backup),
|
||||
"/backup/rename": ("POST", self.rename_backup), # 重命名备份
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
@@ -173,7 +195,81 @@ class BackupRoute(Route):
|
||||
|
||||
return _callback
|
||||
|
||||
def _ensure_cleanup_task_started(self):
|
||||
"""确保后台清理任务已启动(在异步上下文中延迟启动)"""
|
||||
if self._cleanup_task is None or self._cleanup_task.done():
|
||||
try:
|
||||
self._cleanup_task = asyncio.create_task(
|
||||
self._cleanup_expired_uploads()
|
||||
)
|
||||
except RuntimeError:
|
||||
# 如果没有运行中的事件循环,跳过(等待下次异步调用时启动)
|
||||
pass
|
||||
|
||||
async def _cleanup_expired_uploads(self):
|
||||
"""定期清理过期的上传会话
|
||||
|
||||
基于 last_activity 字段判断过期,避免清理活跃的上传会话。
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # 每5分钟检查一次
|
||||
current_time = time.time()
|
||||
expired_sessions = []
|
||||
|
||||
for upload_id, session in self.upload_sessions.items():
|
||||
# 使用 last_activity 判断过期,而非 created_at
|
||||
last_activity = session.get("last_activity", session["created_at"])
|
||||
if current_time - last_activity > UPLOAD_EXPIRE_SECONDS:
|
||||
expired_sessions.append(upload_id)
|
||||
|
||||
for upload_id in expired_sessions:
|
||||
await self._cleanup_upload_session(upload_id)
|
||||
logger.info(f"清理过期的上传会话: {upload_id}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 任务被取消,正常退出
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期上传会话失败: {e}")
|
||||
|
||||
async def _cleanup_upload_session(self, upload_id: str):
|
||||
"""清理上传会话"""
|
||||
if upload_id in self.upload_sessions:
|
||||
session = self.upload_sessions[upload_id]
|
||||
chunk_dir = session.get("chunk_dir")
|
||||
if chunk_dir and os.path.exists(chunk_dir):
|
||||
try:
|
||||
shutil.rmtree(chunk_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理分片目录失败: {e}")
|
||||
del self.upload_sessions[upload_id]
|
||||
|
||||
def _get_backup_manifest(self, zip_path: str) -> dict | None:
|
||||
"""从备份文件读取 manifest.json
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 文件路径
|
||||
|
||||
Returns:
|
||||
dict | None: manifest 内容,如果不是有效备份则返回 None
|
||||
"""
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
if "manifest.json" in zf.namelist():
|
||||
manifest_data = zf.read("manifest.json")
|
||||
return json.loads(manifest_data.decode("utf-8"))
|
||||
else:
|
||||
# 没有 manifest.json,不是有效的 AstrBot 备份
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"读取备份 manifest 失败: {e}")
|
||||
return None # 无法读取,不是有效备份
|
||||
|
||||
async def list_backups(self):
|
||||
# 确保后台清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
"""获取备份列表
|
||||
|
||||
Query 参数:
|
||||
@@ -190,16 +286,34 @@ class BackupRoute(Route):
|
||||
# 获取所有备份文件
|
||||
backup_files = []
|
||||
for filename in os.listdir(self.backup_dir):
|
||||
if filename.endswith(".zip") and filename.startswith("astrbot_backup_"):
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
stat = os.stat(file_path)
|
||||
backup_files.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"size": stat.st_size,
|
||||
"created_at": stat.st_mtime,
|
||||
}
|
||||
)
|
||||
# 只处理 .zip 文件,排除隐藏文件和目录
|
||||
if not filename.endswith(".zip") or filename.startswith("."):
|
||||
continue
|
||||
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.isfile(file_path):
|
||||
continue
|
||||
|
||||
# 读取 manifest.json 获取备份信息
|
||||
# 如果返回 None,说明不是有效的 AstrBot 备份,跳过
|
||||
manifest = self._get_backup_manifest(file_path)
|
||||
if manifest is None:
|
||||
logger.debug(f"跳过无效备份文件: {filename}")
|
||||
continue
|
||||
|
||||
stat = os.stat(file_path)
|
||||
backup_files.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"size": stat.st_size,
|
||||
"created_at": stat.st_mtime,
|
||||
"type": manifest.get(
|
||||
"origin", "exported"
|
||||
), # 老版本没有 origin 默认为 exported
|
||||
"astrbot_version": manifest.get("astrbot_version", "未知"),
|
||||
"exported_at": manifest.get("exported_at"),
|
||||
}
|
||||
)
|
||||
|
||||
# 按创建时间倒序排序
|
||||
backup_files.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
@@ -345,6 +459,309 @@ class BackupRoute(Route):
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"上传备份文件失败: {e!s}").__dict__
|
||||
|
||||
async def upload_init(self):
|
||||
"""初始化分片上传
|
||||
|
||||
创建一个上传会话,返回 upload_id 供后续分片上传使用。
|
||||
|
||||
JSON Body:
|
||||
- filename: 原始文件名
|
||||
- total_size: 文件总大小(字节)
|
||||
|
||||
返回:
|
||||
- upload_id: 上传会话 ID
|
||||
- chunk_size: 分片大小(由后端决定)
|
||||
- total_chunks: 分片总数(由后端根据 total_size 和 chunk_size 计算)
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
total_size = data.get("total_size", 0)
|
||||
|
||||
if not filename:
|
||||
return Response().error("缺少 filename 参数").__dict__
|
||||
|
||||
if not filename.endswith(".zip"):
|
||||
return Response().error("请上传 ZIP 格式的备份文件").__dict__
|
||||
|
||||
if total_size <= 0:
|
||||
return Response().error("无效的文件大小").__dict__
|
||||
|
||||
# 由后端计算分片总数,确保前后端一致
|
||||
import math
|
||||
|
||||
total_chunks = math.ceil(total_size / CHUNK_SIZE)
|
||||
|
||||
# 生成上传 ID
|
||||
upload_id = str(uuid.uuid4())
|
||||
|
||||
# 创建分片存储目录
|
||||
chunk_dir = os.path.join(self.chunks_dir, upload_id)
|
||||
Path(chunk_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 清洗文件名
|
||||
safe_filename = secure_filename(filename)
|
||||
unique_filename = generate_unique_filename(safe_filename)
|
||||
|
||||
# 创建上传会话
|
||||
current_time = time.time()
|
||||
self.upload_sessions[upload_id] = {
|
||||
"filename": unique_filename,
|
||||
"original_filename": filename,
|
||||
"total_size": total_size,
|
||||
"total_chunks": total_chunks,
|
||||
"received_chunks": set(),
|
||||
"created_at": current_time,
|
||||
"last_activity": current_time, # 用于判断会话是否活跃
|
||||
"chunk_dir": chunk_dir,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"初始化分片上传: upload_id={upload_id}, "
|
||||
f"filename={unique_filename}, total_chunks={total_chunks}"
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"upload_id": upload_id,
|
||||
"chunk_size": CHUNK_SIZE,
|
||||
"total_chunks": total_chunks,
|
||||
"filename": unique_filename,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"初始化分片上传失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"初始化分片上传失败: {e!s}").__dict__
|
||||
|
||||
async def upload_chunk(self):
|
||||
"""上传分片
|
||||
|
||||
上传单个分片数据。
|
||||
|
||||
Form Data:
|
||||
- upload_id: 上传会话 ID
|
||||
- chunk_index: 分片索引(从 0 开始)
|
||||
- chunk: 分片数据
|
||||
|
||||
返回:
|
||||
- received: 已接收的分片数量
|
||||
- total: 分片总数
|
||||
"""
|
||||
try:
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
|
||||
upload_id = form.get("upload_id")
|
||||
chunk_index_str = form.get("chunk_index")
|
||||
|
||||
if not upload_id or chunk_index_str is None:
|
||||
return Response().error("缺少必要参数").__dict__
|
||||
|
||||
try:
|
||||
chunk_index = int(chunk_index_str)
|
||||
except ValueError:
|
||||
return Response().error("无效的分片索引").__dict__
|
||||
|
||||
if "chunk" not in files:
|
||||
return Response().error("缺少分片数据").__dict__
|
||||
|
||||
# 验证上传会话
|
||||
if upload_id not in self.upload_sessions:
|
||||
return Response().error("上传会话不存在或已过期").__dict__
|
||||
|
||||
session = self.upload_sessions[upload_id]
|
||||
|
||||
# 验证分片索引
|
||||
if chunk_index < 0 or chunk_index >= session["total_chunks"]:
|
||||
return Response().error("分片索引超出范围").__dict__
|
||||
|
||||
# 保存分片
|
||||
chunk_file = files["chunk"]
|
||||
chunk_path = os.path.join(session["chunk_dir"], f"{chunk_index}.part")
|
||||
await chunk_file.save(chunk_path)
|
||||
|
||||
# 记录已接收的分片,并更新最后活动时间
|
||||
session["received_chunks"].add(chunk_index)
|
||||
session["last_activity"] = time.time() # 刷新活动时间,防止活跃上传被清理
|
||||
|
||||
received_count = len(session["received_chunks"])
|
||||
total_chunks = session["total_chunks"]
|
||||
|
||||
logger.debug(
|
||||
f"接收分片: upload_id={upload_id}, "
|
||||
f"chunk={chunk_index + 1}/{total_chunks}"
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"received": received_count,
|
||||
"total": total_chunks,
|
||||
"chunk_index": chunk_index,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"上传分片失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"上传分片失败: {e!s}").__dict__
|
||||
|
||||
def _mark_backup_as_uploaded(self, zip_path: str) -> None:
|
||||
"""修改备份文件的 manifest.json,将 origin 设置为 uploaded
|
||||
|
||||
使用 zipfile 的 append 模式添加新的 manifest.json,
|
||||
ZIP 规范中后添加的同名文件会覆盖先前的文件。
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 文件路径
|
||||
"""
|
||||
try:
|
||||
# 读取原有 manifest
|
||||
manifest = {"origin": "uploaded", "uploaded_at": datetime.now().isoformat()}
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
if "manifest.json" in zf.namelist():
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data.decode("utf-8"))
|
||||
manifest["origin"] = "uploaded"
|
||||
manifest["uploaded_at"] = datetime.now().isoformat()
|
||||
|
||||
# 使用 append 模式添加新的 manifest.json
|
||||
# ZIP 规范中,后添加的同名文件会覆盖先前的
|
||||
with zipfile.ZipFile(zip_path, "a") as zf:
|
||||
new_manifest = json.dumps(manifest, ensure_ascii=False, indent=2)
|
||||
zf.writestr("manifest.json", new_manifest)
|
||||
|
||||
logger.debug(f"已标记备份为上传来源: {zip_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"标记备份来源失败: {e}")
|
||||
|
||||
async def upload_complete(self):
|
||||
"""完成分片上传
|
||||
|
||||
合并所有分片为完整文件。
|
||||
|
||||
JSON Body:
|
||||
- upload_id: 上传会话 ID
|
||||
|
||||
返回:
|
||||
- filename: 合并后的文件名
|
||||
- size: 文件大小
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
upload_id = data.get("upload_id")
|
||||
|
||||
if not upload_id:
|
||||
return Response().error("缺少 upload_id 参数").__dict__
|
||||
|
||||
# 验证上传会话
|
||||
if upload_id not in self.upload_sessions:
|
||||
return Response().error("上传会话不存在或已过期").__dict__
|
||||
|
||||
session = self.upload_sessions[upload_id]
|
||||
|
||||
# 检查是否所有分片都已接收
|
||||
received = session["received_chunks"]
|
||||
total = session["total_chunks"]
|
||||
|
||||
if len(received) != total:
|
||||
missing = set(range(total)) - received
|
||||
return (
|
||||
Response()
|
||||
.error(f"分片不完整,缺少: {sorted(missing)[:10]}...")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 合并分片
|
||||
chunk_dir = session["chunk_dir"]
|
||||
filename = session["filename"]
|
||||
|
||||
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||
output_path = os.path.join(self.backup_dir, filename)
|
||||
|
||||
try:
|
||||
with open(output_path, "wb") as outfile:
|
||||
for i in range(total):
|
||||
chunk_path = os.path.join(chunk_dir, f"{i}.part")
|
||||
with open(chunk_path, "rb") as chunk_file:
|
||||
# 分块读取,避免内存溢出
|
||||
while True:
|
||||
data_block = chunk_file.read(8192)
|
||||
if not data_block:
|
||||
break
|
||||
outfile.write(data_block)
|
||||
|
||||
file_size = os.path.getsize(output_path)
|
||||
|
||||
# 标记备份为上传来源(修改 manifest.json 中的 origin 字段)
|
||||
self._mark_backup_as_uploaded(output_path)
|
||||
|
||||
logger.info(
|
||||
f"分片上传完成: {filename}, size={file_size}, chunks={total}"
|
||||
)
|
||||
|
||||
# 清理分片目录
|
||||
await self._cleanup_upload_session(upload_id)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"filename": filename,
|
||||
"original_filename": session["original_filename"],
|
||||
"size": file_size,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
# 如果合并失败,删除不完整的文件
|
||||
if os.path.exists(output_path):
|
||||
os.remove(output_path)
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"完成分片上传失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"完成分片上传失败: {e!s}").__dict__
|
||||
|
||||
async def upload_abort(self):
|
||||
"""取消分片上传
|
||||
|
||||
取消上传并清理已上传的分片。
|
||||
|
||||
JSON Body:
|
||||
- upload_id: 上传会话 ID
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
upload_id = data.get("upload_id")
|
||||
|
||||
if not upload_id:
|
||||
return Response().error("缺少 upload_id 参数").__dict__
|
||||
|
||||
if upload_id not in self.upload_sessions:
|
||||
# 会话已不存在,可能已过期或已完成
|
||||
return Response().ok(message="上传已取消").__dict__
|
||||
|
||||
# 清理会话
|
||||
await self._cleanup_upload_session(upload_id)
|
||||
|
||||
logger.info(f"取消分片上传: {upload_id}")
|
||||
|
||||
return Response().ok(message="上传已取消").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"取消上传失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"取消上传失败: {e!s}").__dict__
|
||||
|
||||
async def check_backup(self):
|
||||
"""预检查备份文件
|
||||
|
||||
@@ -537,12 +954,33 @@ class BackupRoute(Route):
|
||||
|
||||
Query 参数:
|
||||
- filename: 备份文件名 (必填)
|
||||
- token: JWT token (必填,用于浏览器原生下载鉴权)
|
||||
|
||||
注意: 此路由已被添加到 auth_middleware 白名单中,
|
||||
使用 URL 参数中的 token 进行鉴权,以支持浏览器原生下载。
|
||||
"""
|
||||
try:
|
||||
filename = request.args.get("filename")
|
||||
token = request.args.get("token")
|
||||
|
||||
if not filename:
|
||||
return Response().error("缺少参数 filename").__dict__
|
||||
|
||||
if not token:
|
||||
return Response().error("缺少参数 token").__dict__
|
||||
|
||||
# 验证 JWT token
|
||||
try:
|
||||
jwt_secret = self.config.get("dashboard", {}).get("jwt_secret")
|
||||
if not jwt_secret:
|
||||
return Response().error("服务器配置错误").__dict__
|
||||
|
||||
jwt.decode(token, jwt_secret, algorithms=["HS256"])
|
||||
except jwt.ExpiredSignatureError:
|
||||
return Response().error("Token 已过期,请刷新页面后重试").__dict__
|
||||
except jwt.InvalidTokenError:
|
||||
return Response().error("Token 无效").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
@@ -587,3 +1025,69 @@ class BackupRoute(Route):
|
||||
logger.error(f"删除备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除备份失败: {e!s}").__dict__
|
||||
|
||||
async def rename_backup(self):
|
||||
"""重命名备份文件
|
||||
|
||||
Body:
|
||||
- filename: 当前文件名 (必填)
|
||||
- new_name: 新文件名 (必填,不含扩展名)
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
new_name = data.get("new_name")
|
||||
|
||||
if not filename:
|
||||
return Response().error("缺少参数 filename").__dict__
|
||||
|
||||
if not new_name:
|
||||
return Response().error("缺少参数 new_name").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
# 清洗新文件名(移除路径和危险字符)
|
||||
new_name = secure_filename(new_name)
|
||||
|
||||
# 移除新文件名中的扩展名(如果有的话)
|
||||
if new_name.endswith(".zip"):
|
||||
new_name = new_name[:-4]
|
||||
|
||||
# 验证新文件名不为空
|
||||
if not new_name or new_name.replace("_", "") == "":
|
||||
return Response().error("新文件名无效").__dict__
|
||||
|
||||
# 强制使用 .zip 扩展名
|
||||
new_filename = f"{new_name}.zip"
|
||||
|
||||
# 检查原文件是否存在
|
||||
old_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(old_path):
|
||||
return Response().error("备份文件不存在").__dict__
|
||||
|
||||
# 检查新文件名是否已存在
|
||||
new_path = os.path.join(self.backup_dir, new_filename)
|
||||
if os.path.exists(new_path):
|
||||
return Response().error(f"文件名 '{new_filename}' 已存在").__dict__
|
||||
|
||||
# 执行重命名
|
||||
os.rename(old_path, new_path)
|
||||
|
||||
logger.info(f"备份文件重命名: {filename} -> {new_filename}")
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"old_filename": filename,
|
||||
"new_filename": new_filename,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"重命名备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"重命名备份失败: {e!s}").__dict__
|
||||
|
||||
@@ -46,6 +46,46 @@ def try_cast(value: Any, type_: str):
|
||||
return None
|
||||
|
||||
|
||||
def _expect_type(value, expected_type, path_key, errors, expected_name=None):
|
||||
if not isinstance(value, expected_type):
|
||||
errors.append(
|
||||
f"错误的类型 {path_key}: 期望是 {expected_name or expected_type.__name__}, "
|
||||
f"得到了 {type(value).__name__}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _validate_template_list(value, meta, path_key, errors, validate_fn):
|
||||
if not _expect_type(value, list, path_key, errors, "list"):
|
||||
return
|
||||
|
||||
templates = meta.get("templates")
|
||||
if not isinstance(templates, dict):
|
||||
templates = {}
|
||||
|
||||
for idx, item in enumerate(value):
|
||||
item_path = f"{path_key}[{idx}]"
|
||||
if not _expect_type(item, dict, item_path, errors, "dict"):
|
||||
continue
|
||||
|
||||
template_key = item.get("__template_key") or item.get("template")
|
||||
if not template_key:
|
||||
errors.append(f"缺少模板选择 {item_path}: 需要 __template_key")
|
||||
continue
|
||||
|
||||
template_meta = templates.get(template_key)
|
||||
if not template_meta:
|
||||
errors.append(f"未知模板 {item_path}: {template_key}")
|
||||
continue
|
||||
|
||||
validate_fn(
|
||||
item,
|
||||
template_meta.get("items", {}),
|
||||
path=f"{item_path}.",
|
||||
)
|
||||
|
||||
|
||||
def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]:
|
||||
errors = []
|
||||
|
||||
@@ -61,6 +101,11 @@ def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]
|
||||
if value is None:
|
||||
data[key] = DEFAULT_VALUE_MAP[meta["type"]]
|
||||
continue
|
||||
|
||||
if meta["type"] == "template_list":
|
||||
_validate_template_list(value, meta, f"{path}{key}", errors, validate)
|
||||
continue
|
||||
|
||||
if meta["type"] == "list" and not isinstance(value, list):
|
||||
errors.append(
|
||||
f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}",
|
||||
|
||||
@@ -115,6 +115,7 @@ class AstrBotDashboard:
|
||||
"/api/file",
|
||||
"/api/platform/webhook",
|
||||
"/api/stat/start-time",
|
||||
"/api/backup/download", # 备份下载使用 URL 参数传递 token
|
||||
]
|
||||
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
|
||||
- 修复钉钉适配器中"回复消息 At 发送人"功能失效的问题
|
||||
- 修复 Xinference STT 在部分情况下无法使用的问题
|
||||
- 修复"会话隔离"功能在非默认配置下无法生效的问题
|
||||
- 修复部分 LLM 中转商因 token 使用情况不符合 OpenAI 标准接口规范导致请求报错的问题
|
||||
- 修复 Deepseek 模型开启思考模式后工具调用报错的问题
|
||||
- 修复部分操作系统环境下 pip 安装依赖时出现 `UnicodeDecodeError` 错误的问题
|
||||
|
||||
### 优化
|
||||
|
||||
- 全面优化对思考型模型的支持(如 Anthropic Extended Thinking、Deepseek 思考模式),完整回传 thinking 内容,提升模型推理性能
|
||||
- 优化 WebUI 记忆侧边栏中"更多功能"和"平台日志"模块的展开状态记忆
|
||||
- 为 MiniMax TTS 新增 "auto" 音色情绪选项,支持模型根据文本内容自动选择情绪
|
||||
- 优化备份功能,支持大文件分片下载
|
||||
- 为 WebSocket 连接添加 max_size 参数,以处理更大的消息并防止接收来自 Satori 平台的大负载时连接断开
|
||||
- 优化插件安装流程,通过文件安装插件时,若插件已加载则先终止再重新加载,避免重复加载
|
||||
- 知识库支持将 overlap 参数设置为 0
|
||||
|
||||
### 新增
|
||||
|
||||
- 为 `dict` 类型的 Schema 新增 JSON value 和 template schema 功能。详见 [dict-类型的-schema](https://docs.astrbot.app/dev/star/guides/plugin-config.html#dict-%E7%B1%BB%E5%9E%8B%E7%9A%84-schema)。
|
||||
- 新增 `template_list` 类型的 Schema,支持渲染指定 template 下的列表。详见 [template-list-类型的-schema](https://docs.astrbot.app/dev/star/guides/plugin-config.html#template-list-%E7%B1%BB%E5%9E%8B%E7%9A%84-schema)。
|
||||
@@ -0,0 +1,5 @@
|
||||
## What's Changed
|
||||
|
||||
hotfix of v4.10.4
|
||||
|
||||
fix: 部分配置项的输入框不显示,如飞书机器人配置的部分配置项。(#4268)
|
||||
@@ -0,0 +1,11 @@
|
||||
## What's Changed
|
||||
|
||||
hotfix of v4.10.4
|
||||
|
||||
fix:
|
||||
|
||||
1. ‼️ 部分情况下使用 OpenAI 接口报错与 reasoning_content 有关的问题;
|
||||
|
||||
feat:
|
||||
|
||||
1. WebUI 已安装插件页支持记忆视图类型(列表/卡片),列表视图显示插件的人类友好名称和 logo。
|
||||
@@ -0,0 +1,19 @@
|
||||
## What's Changed
|
||||
|
||||
### 新增
|
||||
|
||||
- 支持上下文自动压缩功能。入口:配置文件 -> 上下文管理策略 -> 超出模型上下文窗口时的处理方式。详情请查看: [自动上下文压缩](https://docs.astrbot.app/use/context-compress.html) ([#4322](https://github.com/AstrBotDevs/AstrBot/issues/4322))
|
||||
- 新增 `on_waiting_llm_request` 事件钩子 ([#4319](https://github.com/AstrBotDevs/AstrBot/issues/4319))
|
||||
- WebUI 支持强制更新插件 ([#4293](https://github.com/AstrBotDevs/AstrBot/issues/4293))
|
||||
- 社区已提供适用于 [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter) 平台的适配器插件
|
||||
|
||||
### 修复
|
||||
|
||||
- 修复微信公众号中由于 msg.id 数据类型不匹配导致的重试失败问题 ([#4292](https://github.com/AstrBotDevs/AstrBot/issues/4292))
|
||||
- 修复调用 TTS 命令时出现的数据库锁定错误 ([#4313](https://github.com/AstrBotDevs/AstrBot/issues/4313))
|
||||
- 修复 Anthropic 提供商中 token 用量始终为 0 的问题 ([#4328](https://github.com/AstrBotDevs/AstrBot/issues/4328))
|
||||
|
||||
### 优化
|
||||
|
||||
- 完善共享组件的国际化支持 ([#4327](https://github.com/AstrBotDevs/AstrBot/issues/4327))
|
||||
- 优化下载大型备份文件时的稳定性,减少失败情况 ([#4329](https://github.com/AstrBotDevs/AstrBot/issues/4329))
|
||||
@@ -82,7 +82,7 @@
|
||||
{{ tm('availability.test') }}
|
||||
<template #activator="{ props }">
|
||||
<v-btn
|
||||
icon="mdi-wrench"
|
||||
icon="mdi-connection"
|
||||
size="small"
|
||||
variant="text"
|
||||
:disabled="!entry.provider.enable"
|
||||
@@ -93,6 +93,19 @@
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
<v-tooltip location="top" max-width="300">
|
||||
{{ tm('models.configure') }}
|
||||
<template #activator="{ props }">
|
||||
<v-btn
|
||||
icon="mdi-cog"
|
||||
size="small"
|
||||
variant="text"
|
||||
v-bind="props"
|
||||
@click.stop="emit('open-provider-edit', entry.provider)"
|
||||
></v-btn>
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
<v-btn icon="mdi-delete" size="small" variant="text" color="error" @click.stop="emit('delete-provider', entry.provider)"></v-btn>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
<script setup>
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { ref, computed } from 'vue'
|
||||
import ListConfigItem from './ListConfigItem.vue'
|
||||
import ObjectEditor from './ObjectEditor.vue'
|
||||
import ProviderSelector from './ProviderSelector.vue'
|
||||
import PersonaSelector from './PersonaSelector.vue'
|
||||
import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue'
|
||||
import ConfigItemRenderer from './ConfigItemRenderer.vue'
|
||||
import TemplateListEditor from './TemplateListEditor.vue'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
import axios from 'axios'
|
||||
import { useToast } from '@/utils/toast'
|
||||
@@ -159,6 +156,30 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Template List -->
|
||||
<div v-else-if="metadata[metadataKey].items[key]?.type === 'template_list'" class="nested-object w-100">
|
||||
<div v-if="!metadata[metadataKey].items[key]?.invisible && shouldShowItem(metadata[metadataKey].items[key], key)" class="nested-container">
|
||||
<div class="config-section mb-2">
|
||||
<v-list-item-title class="config-title">
|
||||
<span v-if="metadata[metadataKey].items[key]?.description">
|
||||
{{ metadata[metadataKey].items[key]?.description }}
|
||||
<span class="property-key">({{ key }})</span>
|
||||
</span>
|
||||
<span v-else>{{ key }}</span>
|
||||
</v-list-item-title>
|
||||
<v-list-item-subtitle class="config-hint">
|
||||
<span v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint" class="important-hint">‼️</span>
|
||||
{{ metadata[metadataKey].items[key]?.hint }}
|
||||
</v-list-item-subtitle>
|
||||
</div>
|
||||
<TemplateListEditor
|
||||
v-model="iterable[key]"
|
||||
:templates="metadata[metadataKey].items[key]?.templates || {}"
|
||||
class="config-field"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Regular Property -->
|
||||
<template v-else>
|
||||
<v-row v-if="!metadata[metadataKey].items[key]?.invisible && shouldShowItem(metadata[metadataKey].items[key], key)" class="config-row">
|
||||
@@ -181,202 +202,14 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" sm="6" class="config-input">
|
||||
<div v-if="metadata[metadataKey].items[key]" class="w-100">
|
||||
<!-- Special handling for specific metadata types -->
|
||||
<div v-if="metadata[metadataKey].items[key]?._special === 'select_provider'">
|
||||
<ProviderSelector
|
||||
v-model="iterable[key]"
|
||||
:provider-type="'chat_completion'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_provider_stt'">
|
||||
<ProviderSelector
|
||||
v-model="iterable[key]"
|
||||
:provider-type="'speech_to_text'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_provider_tts'">
|
||||
<ProviderSelector
|
||||
v-model="iterable[key]"
|
||||
:provider-type="'text_to_speech'"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_persona'">
|
||||
<PersonaSelector
|
||||
v-model="iterable[key]"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'select_knowledgebase'">
|
||||
<KnowledgeBaseSelector
|
||||
v-model="iterable[key]"
|
||||
/>
|
||||
</div>
|
||||
<!-- Numeric input with get_embedding_dim button -->
|
||||
<div v-else-if="metadata[metadataKey].items[key]?._special === 'get_embedding_dim'"
|
||||
class="d-flex align-center gap-2">
|
||||
<v-text-field
|
||||
v-model="iterable[key]"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
type="number"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
size="small"
|
||||
@click="getEmbeddingDimensions(iterable)"
|
||||
:loading="loadingEmbeddingDim"
|
||||
class="ml-2"
|
||||
>
|
||||
自动检测
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- List item with options-->
|
||||
<div v-else-if="metadata[metadataKey].items[key]?.type === 'list' && metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible && metadata[metadataKey].items[key]?.render_type === 'checkbox'"
|
||||
class="d-flex flex-wrap gap-20">
|
||||
<v-checkbox
|
||||
v-for="(option, index) in metadata[metadataKey].items[key]?.options"
|
||||
v-model="iterable[key]"
|
||||
:label="metadata[metadataKey].items[key]?.labels ? metadata[metadataKey].items[key].labels[index] : option"
|
||||
:value="option"
|
||||
class="mr-2"
|
||||
color="primary"
|
||||
hide-details
|
||||
></v-checkbox>
|
||||
</div>
|
||||
<!-- List item with options-->
|
||||
<v-combobox
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'list' && metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
:items="metadata[metadataKey].items[key]?.options"
|
||||
:disabled="metadata[metadataKey].items[key]?.readonly"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
chips
|
||||
multiple
|
||||
></v-combobox>
|
||||
<!-- Select input -->
|
||||
<v-select
|
||||
v-else-if="metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
:items="metadata[metadataKey].items[key]?.options"
|
||||
:disabled="metadata[metadataKey].items[key]?.readonly"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-select>
|
||||
|
||||
<!-- Code Editor with Full Screen Option -->
|
||||
<div v-else-if="metadata[metadataKey].items[key]?.editor_mode && !metadata[metadataKey].items[key]?.invisible" class="editor-container">
|
||||
<VueMonacoEditor
|
||||
:theme="metadata[metadataKey].items[key]?.editor_theme || 'vs-light'"
|
||||
:language="metadata[metadataKey].items[key]?.editor_language || 'json'"
|
||||
style="min-height: 100px; flex-grow: 1; border: 1px solid rgba(0, 0, 0, 0.1);"
|
||||
v-model:value="iterable[key]"
|
||||
>
|
||||
</VueMonacoEditor>
|
||||
<v-btn
|
||||
icon
|
||||
size="small"
|
||||
variant="text"
|
||||
color="primary"
|
||||
class="editor-fullscreen-btn"
|
||||
@click="openEditorDialog(key, iterable, metadata[metadataKey].items[key]?.editor_theme, metadata[metadataKey].items[key]?.editor_language)"
|
||||
:title="t('core.common.editor.fullscreen')"
|
||||
>
|
||||
<v-icon>mdi-fullscreen</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- String input -->
|
||||
<v-text-field
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'string' && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
|
||||
<!-- Numeric input with optional slider -->
|
||||
<div
|
||||
v-else-if="(metadata[metadataKey].items[key]?.type === 'int' || metadata[metadataKey].items[key]?.type === 'float') && !metadata[metadataKey]?.invisible"
|
||||
class="d-flex align-center gap-3"
|
||||
>
|
||||
<v-slider
|
||||
v-if="metadata[metadataKey].items[key]?.slider"
|
||||
v-model.number="iterable[key]"
|
||||
:min="metadata[metadataKey].items[key]?.slider?.min ?? 0"
|
||||
:max="metadata[metadataKey].items[key]?.slider?.max ?? 100"
|
||||
:step="metadata[metadataKey].items[key]?.slider?.step ?? 1"
|
||||
color="primary"
|
||||
density="compact"
|
||||
hide-details
|
||||
class="flex-grow-1"
|
||||
></v-slider>
|
||||
<v-text-field
|
||||
v-model.number="iterable[key]"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
type="number"
|
||||
hide-details
|
||||
style="max-width: 140px;"
|
||||
></v-text-field>
|
||||
</div>
|
||||
|
||||
<!-- Text area -->
|
||||
<v-textarea
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'text' && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
variant="outlined"
|
||||
rows="3"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-textarea>
|
||||
|
||||
<!-- Boolean switch -->
|
||||
<v-switch
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'bool' && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
color="primary"
|
||||
inset
|
||||
density="compact"
|
||||
hide-details
|
||||
></v-switch>
|
||||
|
||||
<!-- List item -->
|
||||
<ListConfigItem
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'list' && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
class="config-field"
|
||||
/>
|
||||
|
||||
<!-- Dict item (key-value editor) -->
|
||||
<ObjectEditor
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'dict' && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
class="config-field"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<!-- Fallback for unknown metadata -->
|
||||
<div v-else class="w-100">
|
||||
<v-text-field
|
||||
v-model="iterable[key]"
|
||||
:label="key"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
</div>
|
||||
<ConfigItemRenderer
|
||||
v-model="iterable[key]"
|
||||
:item-meta="metadata[metadataKey].items[key] || null"
|
||||
:loading="loadingEmbeddingDim"
|
||||
:show-fullscreen-btn="!!metadata[metadataKey].items[key]?.editor_mode"
|
||||
@get-embedding-dim="getEmbeddingDimensions(iterable)"
|
||||
@open-fullscreen="openEditorDialog(key, iterable, metadata[metadataKey].items[key]?.editor_theme, metadata[metadataKey].items[key]?.editor_language)"
|
||||
/>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
@@ -406,84 +239,17 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
</v-col>
|
||||
|
||||
<v-col cols="12" sm="5" class="config-input">
|
||||
<div class="w-100">
|
||||
<!-- Select input -->
|
||||
<v-select
|
||||
v-if="metadata[metadataKey]?.options && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]"
|
||||
:items="metadata[metadataKey]?.options"
|
||||
:disabled="metadata[metadataKey]?.readonly"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-select>
|
||||
|
||||
<!-- String input -->
|
||||
<v-text-field
|
||||
v-else-if="metadata[metadataKey]?.type === 'string' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
|
||||
<!-- Numeric input with optional slider -->
|
||||
<div
|
||||
v-else-if="(metadata[metadataKey]?.type === 'int' || metadata[metadataKey]?.type === 'float') && !metadata[metadataKey]?.invisible"
|
||||
class="d-flex align-center gap-3"
|
||||
>
|
||||
<v-slider
|
||||
v-if="metadata[metadataKey]?.slider"
|
||||
v-model.number="iterable[metadataKey]"
|
||||
:min="metadata[metadataKey]?.slider?.min ?? 0"
|
||||
:max="metadata[metadataKey]?.slider?.max ?? 100"
|
||||
:step="metadata[metadataKey]?.slider?.step ?? 1"
|
||||
color="primary"
|
||||
density="compact"
|
||||
hide-details
|
||||
class="flex-grow-1"
|
||||
></v-slider>
|
||||
<v-text-field
|
||||
v-model.number="iterable[metadataKey]"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
type="number"
|
||||
hide-details
|
||||
style="max-width: 140px;"
|
||||
></v-text-field>
|
||||
</div>
|
||||
|
||||
<!-- Text area -->
|
||||
<v-textarea
|
||||
v-else-if="metadata[metadataKey]?.type === 'text' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]"
|
||||
variant="outlined"
|
||||
auto-grow
|
||||
rows="3"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-textarea>
|
||||
|
||||
<!-- Boolean switch -->
|
||||
<v-switch
|
||||
v-else-if="metadata[metadataKey]?.type === 'bool' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]"
|
||||
color="primary"
|
||||
inset
|
||||
density="compact"
|
||||
hide-details
|
||||
></v-switch>
|
||||
|
||||
<!-- List item -->
|
||||
<ListConfigItem
|
||||
v-else-if="metadata[metadataKey]?.type === 'list' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]"
|
||||
class="config-field"
|
||||
/>
|
||||
</div>
|
||||
<TemplateListEditor
|
||||
v-if="metadata[metadataKey]?.type === 'template_list' && !metadata[metadataKey]?.invisible"
|
||||
v-model="iterable[metadataKey]"
|
||||
:templates="metadata[metadataKey]?.templates || {}"
|
||||
class="config-field"
|
||||
/>
|
||||
<ConfigItemRenderer
|
||||
v-else
|
||||
v-model="iterable[metadataKey]"
|
||||
:item-meta="metadata[metadataKey]"
|
||||
/>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
|
||||
@@ -1,13 +1,8 @@
|
||||
<script setup>
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { ref, computed } from 'vue'
|
||||
import ListConfigItem from './ListConfigItem.vue'
|
||||
import ObjectEditor from './ObjectEditor.vue'
|
||||
import ProviderSelector from './ProviderSelector.vue'
|
||||
import PersonaSelector from './PersonaSelector.vue'
|
||||
import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue'
|
||||
import PluginSetSelector from './PluginSetSelector.vue'
|
||||
import T2ITemplateEditor from './T2ITemplateEditor.vue'
|
||||
import ConfigItemRenderer from './ConfigItemRenderer.vue'
|
||||
import TemplateListEditor from './TemplateListEditor.vue'
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables'
|
||||
|
||||
|
||||
@@ -215,118 +210,19 @@ function getSpecialSubtype(value) {
|
||||
</v-list-item>
|
||||
</v-col>
|
||||
<v-col cols="12" sm="6" class="config-input">
|
||||
<div class="w-100" v-if="!itemMeta?._special">
|
||||
<!-- Select input for JSON selector -->
|
||||
<v-select v-if="itemMeta?.options" v-model="createSelectorModel(itemKey).value"
|
||||
:items="(() => {
|
||||
const labels = getTranslatedLabels(itemMeta);
|
||||
return labels
|
||||
? itemMeta.options.map((value, index) => ({ title: labels[index] || value, value: value }))
|
||||
: itemMeta.options;
|
||||
})()"
|
||||
:disabled="itemMeta?.readonly" density="compact" variant="outlined"
|
||||
class="config-field" hide-details></v-select>
|
||||
|
||||
<!-- Code Editor for JSON selector -->
|
||||
<div v-else-if="itemMeta?.editor_mode" class="editor-container">
|
||||
<VueMonacoEditor :theme="itemMeta?.editor_theme || 'vs-light'"
|
||||
:language="itemMeta?.editor_language || 'json'"
|
||||
style="min-height: 100px; flex-grow: 1; border: 1px solid rgba(0, 0, 0, 0.1);"
|
||||
v-model:value="createSelectorModel(itemKey).value">
|
||||
</VueMonacoEditor>
|
||||
<v-btn icon size="small" variant="text" color="primary" class="editor-fullscreen-btn"
|
||||
@click="openEditorDialog(itemKey, iterable, itemMeta?.editor_theme, itemMeta?.editor_language)"
|
||||
:title="t('core.common.editor.fullscreen')">
|
||||
<v-icon>mdi-fullscreen</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- String input for JSON selector -->
|
||||
<v-text-field v-else-if="itemMeta?.type === 'string'" v-model="createSelectorModel(itemKey).value"
|
||||
density="compact" variant="outlined" class="config-field" hide-details></v-text-field>
|
||||
|
||||
<!-- Numeric input with optional slider for JSON selector -->
|
||||
<div v-else-if="itemMeta?.type === 'int' || itemMeta?.type === 'float'" class="d-flex align-center gap-3">
|
||||
<v-slider
|
||||
v-if="itemMeta?.slider"
|
||||
v-model.number="createSelectorModel(itemKey).value"
|
||||
:min="itemMeta?.slider?.min ?? 0"
|
||||
:max="itemMeta?.slider?.max ?? 100"
|
||||
:step="itemMeta?.slider?.step ?? 1"
|
||||
color="primary"
|
||||
density="compact"
|
||||
hide-details
|
||||
style="flex: 3"
|
||||
></v-slider>
|
||||
<v-text-field
|
||||
v-model.number="createSelectorModel(itemKey).value"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
style="flex: 2"
|
||||
type="number"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
</div>
|
||||
|
||||
<!-- Text area for JSON selector -->
|
||||
<v-textarea v-else-if="itemMeta?.type === 'text'" v-model="createSelectorModel(itemKey).value"
|
||||
variant="outlined" rows="3" class="config-field" hide-details></v-textarea>
|
||||
|
||||
<!-- Boolean switch for JSON selector -->
|
||||
<v-switch v-else-if="itemMeta?.type === 'bool'" v-model="createSelectorModel(itemKey).value"
|
||||
color="primary" inset density="compact" hide-details
|
||||
style="display: flex; justify-content: end;"></v-switch>
|
||||
|
||||
<!-- List item for JSON selector -->
|
||||
<ListConfigItem v-else-if="itemMeta?.type === 'list'" v-model="createSelectorModel(itemKey).value"
|
||||
button-text="修改" class="config-field" />
|
||||
|
||||
<!-- Object editor for JSON selector -->
|
||||
<ObjectEditor v-else-if="itemMeta?.type === 'dict'" v-model="createSelectorModel(itemKey).value"
|
||||
class="config-field" />
|
||||
|
||||
<!-- Fallback for JSON selector -->
|
||||
<v-text-field v-else v-model="createSelectorModel(itemKey).value" density="compact" variant="outlined"
|
||||
class="config-field" hide-details></v-text-field>
|
||||
</div>
|
||||
|
||||
<!-- Special handling for specific metadata types -->
|
||||
<div v-else-if="itemMeta?._special === 'select_provider'">
|
||||
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'chat_completion'" />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_provider_stt'">
|
||||
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'speech_to_text'" />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_provider_tts'">
|
||||
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'text_to_speech'" />
|
||||
</div>
|
||||
<div v-else-if="getSpecialName(itemMeta?._special) === 'select_agent_runner_provider'">
|
||||
<ProviderSelector
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:provider-type="'agent_runner'"
|
||||
:provider-subtype="getSpecialSubtype(itemMeta?._special)"
|
||||
/>
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'provider_pool'">
|
||||
<ProviderSelector v-model="createSelectorModel(itemKey).value" :provider-type="'chat_completion'"
|
||||
button-text="选择提供商池..." />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_persona'">
|
||||
<PersonaSelector v-model="createSelectorModel(itemKey).value" />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'persona_pool'">
|
||||
<PersonaSelector v-model="createSelectorModel(itemKey).value" button-text="选择人格池..." />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_knowledgebase'">
|
||||
<KnowledgeBaseSelector v-model="createSelectorModel(itemKey).value" />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 'select_plugin_set'">
|
||||
<PluginSetSelector v-model="createSelectorModel(itemKey).value" />
|
||||
</div>
|
||||
<div v-else-if="itemMeta?._special === 't2i_template'">
|
||||
<T2ITemplateEditor />
|
||||
</div>
|
||||
<TemplateListEditor
|
||||
v-if="itemMeta?.type === 'template_list'"
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:templates="itemMeta?.templates || {}"
|
||||
class="config-field"
|
||||
/>
|
||||
<ConfigItemRenderer
|
||||
v-else
|
||||
v-model="createSelectorModel(itemKey).value"
|
||||
:item-meta="itemMeta || null"
|
||||
:show-fullscreen-btn="!!itemMeta?.editor_mode"
|
||||
@open-fullscreen="openEditorDialog(itemKey, iterable, itemMeta?.editor_theme, itemMeta?.editor_language)"
|
||||
/>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
|
||||
@@ -110,9 +110,23 @@
|
||||
|
||||
<!-- 步骤1.5: 上传中 -->
|
||||
<div v-else-if="importStatus === 'uploading'" class="text-center py-8">
|
||||
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||
<v-icon size="64" color="primary" class="mb-4">mdi-cloud-upload</v-icon>
|
||||
<h3 class="mb-4">{{ t('features.settings.backup.import.uploading') }}</h3>
|
||||
<p class="text-grey">{{ t('features.settings.backup.import.uploadWait') }}</p>
|
||||
<p class="text-grey mb-2">
|
||||
{{ uploadProgress.message || t('features.settings.backup.import.uploadWait') }}
|
||||
</p>
|
||||
<p class="text-grey-darken-1 mb-4">
|
||||
{{ formatFileSize(uploadProgress.uploaded) }} / {{ formatFileSize(uploadProgress.total) }}
|
||||
({{ uploadProgress.percent }}%)
|
||||
</p>
|
||||
<v-progress-linear
|
||||
:model-value="uploadProgress.percent"
|
||||
:max="100"
|
||||
class="mt-2"
|
||||
color="primary"
|
||||
height="8"
|
||||
rounded
|
||||
></v-progress-linear>
|
||||
</div>
|
||||
|
||||
<!-- 步骤2: 确认导入 -->
|
||||
@@ -242,15 +256,38 @@
|
||||
:key="backup.filename"
|
||||
>
|
||||
<template v-slot:prepend>
|
||||
<v-icon color="primary">mdi-zip-box</v-icon>
|
||||
<v-icon :color="backup.type === 'uploaded' ? 'orange' : 'primary'">
|
||||
{{ backup.type === 'uploaded' ? 'mdi-upload' : 'mdi-zip-box' }}
|
||||
</v-icon>
|
||||
</template>
|
||||
|
||||
<v-list-item-title>{{ backup.filename }}</v-list-item-title>
|
||||
<v-list-item-subtitle>
|
||||
{{ formatFileSize(backup.size) }} · {{ formatDate(backup.created_at) }}
|
||||
<v-chip size="x-small" color="primary" variant="tonal" class="ml-2">
|
||||
v{{ backup.astrbot_version }}
|
||||
</v-chip>
|
||||
<v-chip v-if="backup.type === 'uploaded'" size="x-small" color="orange" variant="tonal" class="ml-1">
|
||||
{{ t('features.settings.backup.list.uploaded') }}
|
||||
</v-chip>
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-btn
|
||||
icon="mdi-restore"
|
||||
variant="text"
|
||||
size="small"
|
||||
color="success"
|
||||
:title="t('features.settings.backup.list.restore')"
|
||||
@click="restoreFromList(backup.filename)"
|
||||
></v-btn>
|
||||
<v-btn
|
||||
icon="mdi-pencil"
|
||||
variant="text"
|
||||
size="small"
|
||||
:title="t('features.settings.backup.list.rename')"
|
||||
@click="openRenameDialog(backup.filename)"
|
||||
></v-btn>
|
||||
<v-btn icon="mdi-download" variant="text" size="small" @click="downloadBackup(backup.filename)"></v-btn>
|
||||
<v-btn icon="mdi-delete" variant="text" size="small" color="error" @click="deleteBackup(backup.filename)"></v-btn>
|
||||
</template>
|
||||
@@ -263,6 +300,12 @@
|
||||
{{ t('features.settings.backup.list.refresh') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- 提示信息 -->
|
||||
<p class="text-caption text-grey text-center mt-4">
|
||||
<v-icon size="small" class="mr-1">mdi-information-outline</v-icon>
|
||||
{{ t('features.settings.backup.list.ftpHint') }}
|
||||
</p>
|
||||
</v-window-item>
|
||||
</v-window>
|
||||
</v-card-text>
|
||||
@@ -276,6 +319,50 @@
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 重命名对话框 -->
|
||||
<v-dialog v-model="renameDialogOpen" max-width="450" persistent>
|
||||
<v-card>
|
||||
<v-card-title>
|
||||
<v-icon class="mr-2">mdi-pencil</v-icon>
|
||||
{{ t('features.settings.backup.list.renameTitle') }}
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<v-text-field
|
||||
v-model="renameNewName"
|
||||
:label="t('features.settings.backup.list.newName')"
|
||||
:rules="[renameValidationRule]"
|
||||
:error-messages="renameError"
|
||||
variant="outlined"
|
||||
density="comfortable"
|
||||
autofocus
|
||||
@keyup.enter="confirmRename"
|
||||
>
|
||||
<template v-slot:append-inner>
|
||||
<span class="text-grey">.zip</span>
|
||||
</template>
|
||||
</v-text-field>
|
||||
<p class="text-caption text-grey mt-1">
|
||||
{{ t('features.settings.backup.list.renameHint') }}
|
||||
</p>
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="grey" variant="text" @click="closeRenameDialog">
|
||||
{{ t('core.common.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="flat"
|
||||
@click="confirmRename"
|
||||
:loading="renameLoading"
|
||||
:disabled="!renameNewName || !!renameError"
|
||||
>
|
||||
{{ t('core.common.confirm') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||
</template>
|
||||
|
||||
@@ -307,13 +394,33 @@ const importError = ref('')
|
||||
const uploadedFilename = ref('') // 已上传的文件名
|
||||
const checkResult = ref(null) // 预检查结果
|
||||
|
||||
// 分片上传状态
|
||||
const CONCURRENT_UPLOADS = 5 // 并发上传数
|
||||
const uploadId = ref('')
|
||||
const chunkSize = ref(0) // 分片大小(从后端获取)
|
||||
const uploadProgress = ref({
|
||||
uploaded: 0,
|
||||
total: 0,
|
||||
percent: 0,
|
||||
message: ''
|
||||
})
|
||||
|
||||
// 备份列表
|
||||
const loadingList = ref(false)
|
||||
const backupList = ref([])
|
||||
|
||||
// 重命名对话框状态
|
||||
const renameDialogOpen = ref(false)
|
||||
const renameOldFilename = ref('')
|
||||
const renameNewName = ref('')
|
||||
const renameLoading = ref(false)
|
||||
const renameError = ref('')
|
||||
|
||||
// 计算属性
|
||||
const isProcessing = computed(() => {
|
||||
return exportStatus.value === 'processing' || importStatus.value === 'processing'
|
||||
return exportStatus.value === 'processing' ||
|
||||
importStatus.value === 'processing' ||
|
||||
importStatus.value === 'uploading'
|
||||
})
|
||||
|
||||
// 版本检查相关的计算属性
|
||||
@@ -440,28 +547,127 @@ const resetExport = () => {
|
||||
exportError.value = ''
|
||||
}
|
||||
|
||||
/**
|
||||
* 并发上传分片
|
||||
*
|
||||
* 使用并发控制同时上传多个分片,提升上传速度。
|
||||
* 后端按分片索引命名文件(如 0.part, 1.part),合并时按顺序读取,
|
||||
* 因此分片到达顺序不影响最终结果。
|
||||
*/
|
||||
const uploadChunksInParallel = async (file, totalChunks, currentUploadId, currentChunkSize) => {
|
||||
// 跟踪已完成的字节数(使用原子操作避免并发问题)
|
||||
let completedBytes = 0
|
||||
const chunkSizes = []
|
||||
|
||||
// 预计算每个分片的大小(使用后端返回的 chunk_size)
|
||||
for (let i = 0; i < totalChunks; i++) {
|
||||
const start = i * currentChunkSize
|
||||
const end = Math.min(start + currentChunkSize, file.size)
|
||||
chunkSizes[i] = end - start
|
||||
}
|
||||
|
||||
// 上传单个分片的函数
|
||||
const uploadSingleChunk = async (chunkIndex) => {
|
||||
const start = chunkIndex * currentChunkSize
|
||||
const end = Math.min(start + currentChunkSize, file.size)
|
||||
const chunk = file.slice(start, end)
|
||||
|
||||
const formData = new FormData()
|
||||
formData.append('upload_id', currentUploadId)
|
||||
formData.append('chunk_index', chunkIndex.toString())
|
||||
formData.append('chunk', chunk)
|
||||
|
||||
const response = await axios.post('/api/backup/upload/chunk', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' }
|
||||
})
|
||||
|
||||
if (response.data.status !== 'ok') {
|
||||
throw new Error(response.data.message)
|
||||
}
|
||||
|
||||
// 更新进度(累加已完成字节)
|
||||
completedBytes += chunkSizes[chunkIndex]
|
||||
uploadProgress.value.uploaded = completedBytes
|
||||
uploadProgress.value.percent = Math.round((completedBytes / file.size) * 100)
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// 创建分片索引队列
|
||||
const pendingChunks = Array.from({ length: totalChunks }, (_, i) => i)
|
||||
const activePromises = []
|
||||
|
||||
// 处理队列中的分片
|
||||
while (pendingChunks.length > 0 || activePromises.length > 0) {
|
||||
// 填充并发槽位
|
||||
while (pendingChunks.length > 0 && activePromises.length < CONCURRENT_UPLOADS) {
|
||||
const chunkIndex = pendingChunks.shift()
|
||||
const promise = uploadSingleChunk(chunkIndex).then(() => {
|
||||
// 完成后从活动列表移除
|
||||
const idx = activePromises.indexOf(promise)
|
||||
if (idx > -1) activePromises.splice(idx, 1)
|
||||
})
|
||||
activePromises.push(promise)
|
||||
}
|
||||
|
||||
// 等待至少一个完成
|
||||
if (activePromises.length > 0) {
|
||||
await Promise.race(activePromises)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 上传并检查
|
||||
const uploadAndCheck = async () => {
|
||||
if (!importFile.value) return
|
||||
|
||||
importStatus.value = 'uploading'
|
||||
const file = importFile.value
|
||||
|
||||
try {
|
||||
// 步骤1: 上传文件
|
||||
const formData = new FormData()
|
||||
formData.append('file', importFile.value)
|
||||
|
||||
const uploadResponse = await axios.post('/api/backup/upload', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' }
|
||||
})
|
||||
|
||||
if (uploadResponse.data.status !== 'ok') {
|
||||
throw new Error(uploadResponse.data.message)
|
||||
// 初始化上传进度
|
||||
uploadProgress.value = {
|
||||
uploaded: 0,
|
||||
total: file.size,
|
||||
percent: 0,
|
||||
message: t('features.settings.backup.import.uploadInit')
|
||||
}
|
||||
|
||||
uploadedFilename.value = uploadResponse.data.data.filename
|
||||
// 步骤1: 初始化分片上传(后端计算并返回 chunk_size 和 total_chunks)
|
||||
const initResponse = await axios.post('/api/backup/upload/init', {
|
||||
filename: file.name,
|
||||
total_size: file.size
|
||||
})
|
||||
|
||||
if (initResponse.data.status !== 'ok') {
|
||||
throw new Error(initResponse.data.message)
|
||||
}
|
||||
|
||||
uploadId.value = initResponse.data.data.upload_id
|
||||
chunkSize.value = initResponse.data.data.chunk_size
|
||||
const totalChunks = initResponse.data.data.total_chunks
|
||||
|
||||
// 步骤2: 并行分片上传(5个并发连接)
|
||||
uploadProgress.value.message = t('features.settings.backup.import.uploadingChunks')
|
||||
|
||||
await uploadChunksInParallel(file, totalChunks, uploadId.value, chunkSize.value)
|
||||
|
||||
// 步骤3: 完成上传
|
||||
uploadProgress.value.message = t('features.settings.backup.import.uploadComplete')
|
||||
|
||||
const completeResponse = await axios.post('/api/backup/upload/complete', {
|
||||
upload_id: uploadId.value
|
||||
})
|
||||
|
||||
if (completeResponse.data.status !== 'ok') {
|
||||
throw new Error(completeResponse.data.message)
|
||||
}
|
||||
|
||||
uploadedFilename.value = completeResponse.data.data.filename
|
||||
|
||||
// 步骤4: 预检查
|
||||
uploadProgress.value.message = t('features.settings.backup.import.checking')
|
||||
|
||||
// 步骤2: 预检查
|
||||
const checkResponse = await axios.post('/api/backup/check', {
|
||||
filename: uploadedFilename.value
|
||||
})
|
||||
@@ -483,6 +689,17 @@ const uploadAndCheck = async () => {
|
||||
importStatus.value = 'confirm'
|
||||
|
||||
} catch (error) {
|
||||
// 上传失败时尝试清理已上传的分片
|
||||
if (uploadId.value) {
|
||||
try {
|
||||
await axios.post('/api/backup/upload/abort', {
|
||||
upload_id: uploadId.value
|
||||
})
|
||||
} catch (abortError) {
|
||||
console.error('Failed to abort upload:', abortError)
|
||||
}
|
||||
}
|
||||
|
||||
importStatus.value = 'failed'
|
||||
importError.value = error.response?.data?.message || error.message || 'Upload failed'
|
||||
}
|
||||
@@ -548,7 +765,18 @@ const pollImportProgress = async () => {
|
||||
}
|
||||
|
||||
// 重置导入状态
|
||||
const resetImport = () => {
|
||||
const resetImport = async () => {
|
||||
// 如果有进行中的上传,先取消
|
||||
if (uploadId.value && importStatus.value === 'uploading') {
|
||||
try {
|
||||
await axios.post('/api/backup/upload/abort', {
|
||||
upload_id: uploadId.value
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('Failed to abort upload:', error)
|
||||
}
|
||||
}
|
||||
|
||||
importStatus.value = 'idle'
|
||||
importFile.value = null
|
||||
importTaskId.value = null
|
||||
@@ -556,29 +784,61 @@ const resetImport = () => {
|
||||
importError.value = ''
|
||||
uploadedFilename.value = ''
|
||||
checkResult.value = null
|
||||
uploadId.value = ''
|
||||
chunkSize.value = 0
|
||||
uploadProgress.value = { uploaded: 0, total: 0, percent: 0, message: '' }
|
||||
}
|
||||
|
||||
// 下载备份
|
||||
const downloadBackup = async (filename) => {
|
||||
// 下载备份(使用浏览器原生下载,可显示下载进度)
|
||||
const downloadBackup = (filename) => {
|
||||
// 获取 token 用于鉴权(因为浏览器原生下载无法携带 Authorization header)
|
||||
const token = localStorage.getItem('token')
|
||||
if (!token) {
|
||||
alert(t('core.common.unauthorized'))
|
||||
return
|
||||
}
|
||||
|
||||
// 直接使用浏览器下载,这样可以看到原生下载进度条
|
||||
const downloadUrl = `/api/backup/download?filename=${encodeURIComponent(filename)}&token=${encodeURIComponent(token)}`
|
||||
|
||||
// 创建隐藏的 a 标签触发下载
|
||||
const link = document.createElement('a')
|
||||
link.href = downloadUrl
|
||||
link.download = filename
|
||||
link.style.display = 'none'
|
||||
document.body.appendChild(link)
|
||||
link.click()
|
||||
document.body.removeChild(link)
|
||||
}
|
||||
|
||||
// 从列表中恢复备份
|
||||
const restoreFromList = async (filename) => {
|
||||
// 切换到导入标签页并设置文件名
|
||||
uploadedFilename.value = filename
|
||||
|
||||
// 预检查
|
||||
try {
|
||||
const response = await axios.get('/api/backup/download', {
|
||||
params: { filename },
|
||||
responseType: 'blob'
|
||||
const checkResponse = await axios.post('/api/backup/check', {
|
||||
filename: filename
|
||||
})
|
||||
|
||||
if (checkResponse.data.status !== 'ok') {
|
||||
throw new Error(checkResponse.data.message)
|
||||
}
|
||||
|
||||
checkResult.value = checkResponse.data.data
|
||||
|
||||
// 创建 Blob URL 并触发下载
|
||||
const blob = new Blob([response.data], { type: 'application/zip' })
|
||||
const url = window.URL.createObjectURL(blob)
|
||||
const link = document.createElement('a')
|
||||
link.href = url
|
||||
link.download = filename
|
||||
document.body.appendChild(link)
|
||||
link.click()
|
||||
document.body.removeChild(link)
|
||||
window.URL.revokeObjectURL(url)
|
||||
if (!checkResult.value.valid) {
|
||||
alert(checkResult.value.error || t('features.settings.backup.import.invalidBackup'))
|
||||
return
|
||||
}
|
||||
|
||||
// 切换到导入标签页并显示确认
|
||||
activeTab.value = 'import'
|
||||
importStatus.value = 'confirm'
|
||||
|
||||
} catch (error) {
|
||||
console.error('Download failed:', error)
|
||||
alert(t('features.settings.backup.export.failed') + ': ' + (error.message || 'Unknown error'))
|
||||
alert(error.response?.data?.message || error.message || 'Check failed')
|
||||
}
|
||||
}
|
||||
|
||||
@@ -598,6 +858,68 @@ const deleteBackup = async (filename) => {
|
||||
}
|
||||
}
|
||||
|
||||
// 重命名相关函数
|
||||
const openRenameDialog = (filename) => {
|
||||
renameOldFilename.value = filename
|
||||
// 移除 .zip 后缀,只显示文件名部分
|
||||
renameNewName.value = filename.replace(/\.zip$/i, '')
|
||||
renameError.value = ''
|
||||
renameDialogOpen.value = true
|
||||
}
|
||||
|
||||
const closeRenameDialog = () => {
|
||||
renameDialogOpen.value = false
|
||||
renameOldFilename.value = ''
|
||||
renameNewName.value = ''
|
||||
renameError.value = ''
|
||||
}
|
||||
|
||||
// 文件名验证规则
|
||||
const renameValidationRule = (value) => {
|
||||
if (!value) return t('features.settings.backup.list.renameRequired')
|
||||
// 检查是否包含非法字符
|
||||
if (/[\\/:*?"<>|]/.test(value)) {
|
||||
return t('features.settings.backup.list.renameInvalidChars')
|
||||
}
|
||||
// 检查是否包含路径遍历字符
|
||||
if (value.includes('..')) {
|
||||
return t('features.settings.backup.list.renameInvalidChars')
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
const confirmRename = async () => {
|
||||
if (!renameNewName.value || renameError.value) return
|
||||
|
||||
// 前端验证
|
||||
const validationResult = renameValidationRule(renameNewName.value)
|
||||
if (validationResult !== true) {
|
||||
renameError.value = validationResult
|
||||
return
|
||||
}
|
||||
|
||||
renameLoading.value = true
|
||||
renameError.value = ''
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/backup/rename', {
|
||||
filename: renameOldFilename.value,
|
||||
new_name: renameNewName.value
|
||||
})
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
closeRenameDialog()
|
||||
loadBackupList()
|
||||
} else {
|
||||
renameError.value = response.data.message || t('features.settings.backup.list.renameFailed')
|
||||
}
|
||||
} catch (error) {
|
||||
renameError.value = error.response?.data?.message || error.message || t('features.settings.backup.list.renameFailed')
|
||||
} finally {
|
||||
renameLoading.value = false
|
||||
}
|
||||
}
|
||||
|
||||
// 格式化文件大小
|
||||
const formatFileSize = (bytes) => {
|
||||
if (bytes === 0) return '0 B'
|
||||
@@ -632,9 +954,9 @@ const restartAstrBot = () => {
|
||||
}
|
||||
|
||||
// 重置所有状态
|
||||
const resetAll = () => {
|
||||
const resetAll = async () => {
|
||||
resetExport()
|
||||
resetImport()
|
||||
await resetImport()
|
||||
activeTab.value = 'export'
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,332 @@
|
||||
<template>
|
||||
<div class="w-100">
|
||||
<!-- Special handling for specific metadata types -->
|
||||
<template v-if="itemMeta?._special === 'select_provider'">
|
||||
<ProviderSelector :model-value="modelValue" @update:model-value="emitUpdate" :provider-type="'chat_completion'" />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'select_provider_stt'">
|
||||
<ProviderSelector :model-value="modelValue" @update:model-value="emitUpdate" :provider-type="'speech_to_text'" />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'select_provider_tts'">
|
||||
<ProviderSelector :model-value="modelValue" @update:model-value="emitUpdate" :provider-type="'text_to_speech'" />
|
||||
</template>
|
||||
<template v-else-if="getSpecialName(itemMeta?._special) === 'select_agent_runner_provider'">
|
||||
<ProviderSelector
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
:provider-type="'agent_runner'"
|
||||
:provider-subtype="getSpecialSubtype(itemMeta?._special)"
|
||||
/>
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'provider_pool'">
|
||||
<ProviderSelector :model-value="modelValue" @update:model-value="emitUpdate" :provider-type="'chat_completion'"
|
||||
button-text="选择提供商池..." />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'select_persona'">
|
||||
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'persona_pool'">
|
||||
<PersonaSelector :model-value="modelValue" @update:model-value="emitUpdate" button-text="选择人格池..." />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'select_knowledgebase'">
|
||||
<KnowledgeBaseSelector :model-value="modelValue" @update:model-value="emitUpdate" />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'select_plugin_set'">
|
||||
<PluginSetSelector :model-value="modelValue" @update:model-value="emitUpdate" />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 't2i_template'">
|
||||
<T2ITemplateEditor />
|
||||
</template>
|
||||
<template v-else-if="itemMeta?._special === 'get_embedding_dim'">
|
||||
<div class="d-flex align-center gap-2">
|
||||
<v-text-field
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
type="number"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
size="small"
|
||||
@click="$emit('get-embedding-dim')"
|
||||
:loading="loading"
|
||||
class="ml-2"
|
||||
>
|
||||
自动检测
|
||||
</v-btn>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div
|
||||
v-else-if="itemMeta?.type === 'list' && itemMeta?.options && itemMeta?.render_type === 'checkbox'"
|
||||
class="d-flex flex-wrap gap-20"
|
||||
>
|
||||
<v-checkbox
|
||||
v-for="(option, optionIndex) in itemMeta.options"
|
||||
:key="optionIndex"
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
:label="getLabel(itemMeta, optionIndex, option)"
|
||||
:value="option"
|
||||
class="mr-2"
|
||||
color="primary"
|
||||
hide-details
|
||||
></v-checkbox>
|
||||
</div>
|
||||
|
||||
<v-combobox
|
||||
v-else-if="itemMeta?.type === 'list' && itemMeta?.options"
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
:items="itemMeta.options"
|
||||
:disabled="itemMeta?.readonly"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
chips
|
||||
multiple
|
||||
></v-combobox>
|
||||
|
||||
<v-select
|
||||
v-else-if="itemMeta?.options"
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
:items="getSelectItems(itemMeta)"
|
||||
:disabled="itemMeta?.readonly"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-select>
|
||||
|
||||
<div v-else-if="itemMeta?.editor_mode" class="editor-container">
|
||||
<VueMonacoEditor
|
||||
:theme="itemMeta?.editor_theme || 'vs-light'"
|
||||
:language="itemMeta?.editor_language || 'json'"
|
||||
style="min-height: 100px; flex-grow: 1; border: 1px solid rgba(0, 0, 0, 0.1);"
|
||||
:value="modelValue"
|
||||
@update:value="emitUpdate"
|
||||
>
|
||||
</VueMonacoEditor>
|
||||
<v-btn v-if="showFullscreenBtn" icon size="small" variant="text" color="primary" class="editor-fullscreen-btn"
|
||||
@click="$emit('open-fullscreen')"
|
||||
:title="t('core.common.editor.fullscreen')">
|
||||
<v-icon>mdi-fullscreen</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<v-text-field
|
||||
v-else-if="itemMeta?.type === 'string'"
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
|
||||
<div
|
||||
v-else-if="itemMeta?.type === 'int' || itemMeta?.type === 'float'"
|
||||
class="d-flex align-center gap-3"
|
||||
>
|
||||
<v-slider
|
||||
v-if="itemMeta?.slider"
|
||||
:model-value="toNumber(modelValue)"
|
||||
@update:model-value="val => emitUpdate(toNumber(val))"
|
||||
:min="itemMeta?.slider?.min ?? 0"
|
||||
:max="itemMeta?.slider?.max ?? 100"
|
||||
:step="itemMeta?.slider?.step ?? 1"
|
||||
color="primary"
|
||||
density="compact"
|
||||
hide-details
|
||||
style="flex: 1"
|
||||
></v-slider>
|
||||
<v-text-field
|
||||
:model-value="modelValue"
|
||||
@update:model-value="val => emitUpdate(toNumber(val))"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
type="number"
|
||||
hide-details
|
||||
style="flex: 1"
|
||||
></v-text-field>
|
||||
</div>
|
||||
|
||||
<v-textarea
|
||||
v-else-if="itemMeta?.type === 'text'"
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
variant="outlined"
|
||||
rows="3"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-textarea>
|
||||
|
||||
<v-switch
|
||||
v-else-if="itemMeta?.type === 'bool'"
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
color="primary"
|
||||
inset
|
||||
density="compact"
|
||||
hide-details
|
||||
></v-switch>
|
||||
|
||||
<ListConfigItem
|
||||
v-else-if="itemMeta?.type === 'list'"
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
class="config-field"
|
||||
/>
|
||||
|
||||
<ObjectEditor
|
||||
v-else-if="itemMeta?.type === 'dict'"
|
||||
:model-value="modelValue"
|
||||
:item-meta="itemMeta"
|
||||
@update:model-value="emitUpdate"
|
||||
class="config-field"
|
||||
/>
|
||||
|
||||
<v-text-field
|
||||
v-else
|
||||
:model-value="modelValue"
|
||||
@update:model-value="emitUpdate"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
class="config-field"
|
||||
hide-details
|
||||
></v-text-field>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import ListConfigItem from './ListConfigItem.vue'
|
||||
import ObjectEditor from './ObjectEditor.vue'
|
||||
import ProviderSelector from './ProviderSelector.vue'
|
||||
import PersonaSelector from './PersonaSelector.vue'
|
||||
import KnowledgeBaseSelector from './KnowledgeBaseSelector.vue'
|
||||
import PluginSetSelector from './PluginSetSelector.vue'
|
||||
import T2ITemplateEditor from './T2ITemplateEditor.vue'
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables'
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: [String, Number, Boolean, Array, Object],
|
||||
default: null
|
||||
},
|
||||
itemMeta: {
|
||||
type: Object,
|
||||
default: null
|
||||
},
|
||||
loading: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
showFullscreenBtn: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
}
|
||||
})
|
||||
|
||||
const emit = defineEmits(['update:modelValue', 'get-embedding-dim', 'open-fullscreen'])
|
||||
const { t } = useI18n()
|
||||
const { getRaw } = useModuleI18n('features/config-metadata')
|
||||
|
||||
function emitUpdate(val) {
|
||||
emit('update:modelValue', val)
|
||||
}
|
||||
|
||||
function toNumber(val) {
|
||||
const n = parseFloat(val)
|
||||
return isNaN(n) ? 0 : n
|
||||
}
|
||||
|
||||
function getLabel(itemMeta, index, option) {
|
||||
const labels = getTranslatedLabels(itemMeta)
|
||||
return labels ? labels[index] : option
|
||||
}
|
||||
|
||||
function getTranslatedLabels(itemMeta) {
|
||||
if (!itemMeta?.labels) return null
|
||||
if (typeof itemMeta.labels === 'string') {
|
||||
const translatedLabels = getRaw(itemMeta.labels)
|
||||
if (Array.isArray(translatedLabels)) {
|
||||
return translatedLabels
|
||||
}
|
||||
}
|
||||
if (Array.isArray(itemMeta.labels)) {
|
||||
return itemMeta.labels
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
function getSelectItems(itemMeta) {
|
||||
const labels = getTranslatedLabels(itemMeta)
|
||||
if (labels && itemMeta.options) {
|
||||
return itemMeta.options.map((value, index) => ({
|
||||
title: labels[index] || value,
|
||||
value: value
|
||||
}))
|
||||
}
|
||||
return itemMeta.options || []
|
||||
}
|
||||
|
||||
function parseSpecialValue(value) {
|
||||
if (!value || typeof value !== 'string') {
|
||||
return { name: '', subtype: '' }
|
||||
}
|
||||
const [name, ...rest] = value.split(':')
|
||||
return {
|
||||
name,
|
||||
subtype: rest.join(':') || ''
|
||||
}
|
||||
}
|
||||
|
||||
function getSpecialName(value) {
|
||||
return parseSpecialValue(value).name
|
||||
}
|
||||
|
||||
function getSpecialSubtype(value) {
|
||||
return parseSpecialValue(value).subtype
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.config-field {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.editor-container {
|
||||
position: relative;
|
||||
display: flex;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.editor-fullscreen-btn {
|
||||
position: absolute;
|
||||
top: 4px;
|
||||
right: 4px;
|
||||
z-index: 10;
|
||||
background-color: rgba(0, 0, 0, 0.3);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.editor-fullscreen-btn:hover {
|
||||
background-color: rgba(0, 0, 0, 0.5);
|
||||
}
|
||||
|
||||
.gap-20 {
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
:deep(.v-field__input) {
|
||||
font-size: 14px;
|
||||
}
|
||||
</style>
|
||||
@@ -145,9 +145,11 @@ const viewReadme = () => {
|
||||
}})</v-list-item-title>
|
||||
</v-list-item>
|
||||
|
||||
<v-list-item @click="updateExtension" :disabled="!extension?.has_update">
|
||||
<v-list-item @click="updateExtension">
|
||||
<v-list-item-title>
|
||||
{{ tm('card.actions.updateTo') }} {{ extension.online_version || extension.version }}
|
||||
{{ extension.has_update
|
||||
? tm('card.actions.updateTo') + ' ' + extension.online_version
|
||||
: tm('card.actions.reinstall') }}
|
||||
</v-list-item-title>
|
||||
</v-list-item>
|
||||
</template>
|
||||
|
||||
@@ -23,7 +23,7 @@
|
||||
</div>
|
||||
</div>
|
||||
<v-btn size="small" color="primary" variant="tonal" @click="openDialog">
|
||||
{{ preferSingleItem ? '添加更多' : (buttonText || t('core.common.list.modifyButton')) }}
|
||||
{{ preferSingleItem ? t('core.common.list.addMore') : (buttonText || t('core.common.list.modifyButton')) }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
@@ -48,6 +48,14 @@
|
||||
:placeholder="t('core.common.list.inputPlaceholder')"
|
||||
class="flex-grow-1">
|
||||
</v-text-field>
|
||||
<v-btn
|
||||
@click="addItem"
|
||||
variant="tonal"
|
||||
color="primary"
|
||||
size="small"
|
||||
:disabled="!newItem.trim()">
|
||||
{{ t('core.common.list.addButton') }}
|
||||
</v-btn>
|
||||
<v-btn
|
||||
@click="showBatchImport = true"
|
||||
variant="tonal"
|
||||
@@ -318,4 +326,4 @@ function cancelBatchImport() {
|
||||
.v-chip {
|
||||
margin: 2px;
|
||||
}
|
||||
</style>
|
||||
</style>
|
||||
|
||||
@@ -26,8 +26,9 @@
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text class="pa-4" style="max-height: 400px; overflow-y: auto;">
|
||||
<div v-if="localKeyValuePairs.length > 0">
|
||||
<div v-for="(pair, index) in localKeyValuePairs" :key="index" class="key-value-pair">
|
||||
<!-- Regular key-value pairs (non-template) -->
|
||||
<div v-if="nonTemplatePairs.length > 0">
|
||||
<div v-for="(pair, index) in nonTemplatePairs" :key="index" class="key-value-pair">
|
||||
<v-row no-gutters align="center" class="mb-2">
|
||||
<v-col cols="4">
|
||||
<v-text-field
|
||||
@@ -48,15 +49,29 @@
|
||||
hide-details
|
||||
placeholder="字符串值"
|
||||
></v-text-field>
|
||||
<v-text-field
|
||||
v-else-if="pair.type === 'number'"
|
||||
v-model.number="pair.value"
|
||||
type="number"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
placeholder="数值"
|
||||
></v-text-field>
|
||||
<div v-else-if="pair.type === 'number' || pair.type === 'float' || pair.type === 'int'" class="d-flex align-center gap-2 flex-grow-1">
|
||||
<v-slider
|
||||
v-if="pair.slider"
|
||||
:model-value="Number(pair.value) || 0"
|
||||
@update:model-value="pair.value = $event"
|
||||
:min="pair.slider.min"
|
||||
:max="pair.slider.max"
|
||||
:step="pair.slider.step"
|
||||
color="primary"
|
||||
density="compact"
|
||||
hide-details
|
||||
class="flex-grow-1"
|
||||
></v-slider>
|
||||
<v-text-field
|
||||
v-model.number="pair.value"
|
||||
type="number"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
placeholder="数值"
|
||||
:style="pair.slider ? 'max-width: 120px;' : ''"
|
||||
></v-text-field>
|
||||
</div>
|
||||
<v-switch
|
||||
v-else-if="pair.type === 'boolean'"
|
||||
v-model="pair.value"
|
||||
@@ -64,6 +79,16 @@
|
||||
hide-details
|
||||
color="primary"
|
||||
></v-switch>
|
||||
<v-text-field
|
||||
v-if="pair.type === 'json'"
|
||||
v-model="pair.value"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details="auto"
|
||||
placeholder="JSON"
|
||||
@blur="updateJSON(index, pair.value)"
|
||||
:error-messages="pair.jsonError"
|
||||
></v-text-field>
|
||||
</v-col>
|
||||
<v-col cols="1" class="pl-2">
|
||||
<v-btn
|
||||
@@ -71,7 +96,7 @@
|
||||
variant="text"
|
||||
size="small"
|
||||
color="error"
|
||||
@click="removeKeyValuePair(index)"
|
||||
@click="removeKeyValuePairByKey(pair.key)"
|
||||
>
|
||||
<v-icon>mdi-delete</v-icon>
|
||||
</v-btn>
|
||||
@@ -79,7 +104,79 @@
|
||||
</v-row>
|
||||
</div>
|
||||
</div>
|
||||
<div v-else class="text-center py-8">
|
||||
|
||||
<!-- Template schema fields -->
|
||||
<div v-if="hasTemplateSchema" class="mt-4">
|
||||
<v-divider class="mb-3"></v-divider>
|
||||
<div class="text-caption text-grey mb-2">预设</div>
|
||||
<div v-for="(template, templateKey) in templateSchema" :key="templateKey" class="template-field" :class="{ 'template-field-inactive': !isTemplateKeyAdded(templateKey) }">
|
||||
<v-row no-gutters align="center" class="mb-2">
|
||||
<v-col cols="4">
|
||||
<div class="d-flex flex-column">
|
||||
<span class="text-caption font-weight-medium">{{ template.name || template.description || templateKey }}</span>
|
||||
<span v-if="template.hint" class="text-caption text-grey" style="font-size: 0.7rem;">{{ template.hint }}</span>
|
||||
</div>
|
||||
</v-col>
|
||||
<v-col cols="7" class="pl-2 d-flex align-center justify-end">
|
||||
<v-text-field
|
||||
v-if="template.type === 'string'"
|
||||
:model-value="getTemplateValue(templateKey)"
|
||||
@update:model-value="updateTemplateValue(templateKey, $event)"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
placeholder="字符串值"
|
||||
></v-text-field>
|
||||
<div v-else-if="template.type === 'number' || template.type === 'float' || template.type === 'int'" class="d-flex align-center ga-4 flex-grow-1">
|
||||
<v-slider
|
||||
v-if="template.slider"
|
||||
:model-value="Number(getTemplateValue(templateKey)) || 0"
|
||||
@update:model-value="updateTemplateValue(templateKey, $event)"
|
||||
:min="template.slider.min"
|
||||
:max="template.slider.max"
|
||||
:step="template.slider.step"
|
||||
color="primary"
|
||||
density="compact"
|
||||
hide-details
|
||||
class="flex-grow-1"
|
||||
></v-slider>
|
||||
<v-text-field
|
||||
:model-value="getTemplateValue(templateKey)"
|
||||
@update:model-value="updateTemplateValue(templateKey, $event)"
|
||||
type="number"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
hide-details
|
||||
placeholder="数值"
|
||||
:style="template.slider ? 'max-width: 120px;' : ''"
|
||||
></v-text-field>
|
||||
</div>
|
||||
<v-switch
|
||||
v-else-if="template.type === 'boolean' || template.type === 'bool'"
|
||||
:model-value="getTemplateValue(templateKey)"
|
||||
@update:model-value="updateTemplateValue(templateKey, $event)"
|
||||
density="compact"
|
||||
hide-details
|
||||
color="primary"
|
||||
></v-switch>
|
||||
</v-col>
|
||||
<v-col cols="1" class="pl-2">
|
||||
<v-btn
|
||||
v-if="isTemplateKeyAdded(templateKey)"
|
||||
icon
|
||||
variant="text"
|
||||
size="small"
|
||||
color="error"
|
||||
@click="removeTemplateKey(templateKey)"
|
||||
>
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="localKeyValuePairs.length === 0 && !hasTemplateSchema" class="text-center py-8">
|
||||
<v-icon size="64" color="grey-lighten-1">mdi-code-json</v-icon>
|
||||
<p class="text-grey mt-4">暂无参数</p>
|
||||
</div>
|
||||
@@ -98,7 +195,7 @@
|
||||
></v-text-field>
|
||||
<v-select
|
||||
v-model="newValueType"
|
||||
:items="['string', 'number', 'boolean']"
|
||||
:items="['string', 'number', 'boolean', 'json']"
|
||||
label="值类型"
|
||||
density="compact"
|
||||
variant="outlined"
|
||||
@@ -132,6 +229,10 @@ const props = defineProps({
|
||||
type: Object,
|
||||
required: true
|
||||
},
|
||||
itemMeta: {
|
||||
type: Object,
|
||||
default: null
|
||||
},
|
||||
buttonText: {
|
||||
type: String,
|
||||
default: '修改'
|
||||
@@ -154,11 +255,25 @@ const originalKeyValuePairs = ref([])
|
||||
const newKey = ref('')
|
||||
const newValueType = ref('string')
|
||||
|
||||
// Template schema support
|
||||
const templateSchema = computed(() => {
|
||||
return props.itemMeta?.template_schema || {}
|
||||
})
|
||||
|
||||
const hasTemplateSchema = computed(() => {
|
||||
return Object.keys(templateSchema.value).length > 0
|
||||
})
|
||||
|
||||
// 计算要显示的键名
|
||||
const displayKeys = computed(() => {
|
||||
return Object.keys(props.modelValue).slice(0, props.maxDisplayItems)
|
||||
})
|
||||
|
||||
// 分离模板字段和普通字段
|
||||
const nonTemplatePairs = computed(() => {
|
||||
return localKeyValuePairs.value.filter(pair => !templateSchema.value[pair.key])
|
||||
})
|
||||
|
||||
// 监听 modelValue 变化,主要用于初始化
|
||||
watch(() => props.modelValue, (newValue) => {
|
||||
// This watch is primarily for initialization or external changes
|
||||
@@ -168,10 +283,26 @@ watch(() => props.modelValue, (newValue) => {
|
||||
function initializeLocalKeyValuePairs() {
|
||||
localKeyValuePairs.value = []
|
||||
for (const [key, value] of Object.entries(props.modelValue)) {
|
||||
let _type = (typeof value) === 'object' ? 'json':(typeof value)
|
||||
let _value = _type === 'json'?JSON.stringify(value):value
|
||||
|
||||
// Check if this key has a template schema
|
||||
const template = templateSchema.value[key]
|
||||
if (template) {
|
||||
// Use template type if available
|
||||
_type = template.type || _type
|
||||
// Use template default if value is missing
|
||||
if (_value === undefined || _value === null) {
|
||||
_value = template.default !== undefined ? template.default : _value
|
||||
}
|
||||
}
|
||||
|
||||
localKeyValuePairs.value.push({
|
||||
key: key,
|
||||
value: value,
|
||||
type: typeof value // Store the original type
|
||||
value: _value,
|
||||
type: _type,
|
||||
slider: template?.slider,
|
||||
template: template
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -201,6 +332,9 @@ function addKeyValuePair() {
|
||||
case 'boolean':
|
||||
defaultValue = false
|
||||
break
|
||||
case 'json':
|
||||
defaultValue = "{}"
|
||||
break
|
||||
default: // string
|
||||
defaultValue = ""
|
||||
break
|
||||
@@ -215,8 +349,20 @@ function addKeyValuePair() {
|
||||
}
|
||||
}
|
||||
|
||||
function removeKeyValuePair(index) {
|
||||
localKeyValuePairs.value.splice(index, 1)
|
||||
function updateJSON(index, newValue) {
|
||||
try {
|
||||
JSON.parse(newValue)
|
||||
localKeyValuePairs.value[index].jsonError = ''
|
||||
} catch (e) {
|
||||
localKeyValuePairs.value[index].jsonError = 'JSON 格式错误'
|
||||
}
|
||||
}
|
||||
|
||||
function removeKeyValuePairByKey(key) {
|
||||
const index = localKeyValuePairs.value.findIndex(pair => pair.key === key)
|
||||
if (index >= 0) {
|
||||
localKeyValuePairs.value.splice(index, 1)
|
||||
}
|
||||
}
|
||||
|
||||
function updateKey(index, newKey) {
|
||||
@@ -234,28 +380,110 @@ function updateKey(index, newKey) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查新键名是否有模板
|
||||
const template = templateSchema.value[newKey]
|
||||
if (template) {
|
||||
// 更新类型和默认值
|
||||
localKeyValuePairs.value[index].type = template.type || localKeyValuePairs.value[index].type
|
||||
if (localKeyValuePairs.value[index].value === undefined || localKeyValuePairs.value[index].value === null || localKeyValuePairs.value[index].value === '') {
|
||||
localKeyValuePairs.value[index].value = template.default !== undefined ? template.default : localKeyValuePairs.value[index].value
|
||||
}
|
||||
localKeyValuePairs.value[index].slider = template.slider
|
||||
localKeyValuePairs.value[index].template = template
|
||||
} else {
|
||||
// 清除模板信息
|
||||
localKeyValuePairs.value[index].slider = undefined
|
||||
localKeyValuePairs.value[index].template = undefined
|
||||
}
|
||||
|
||||
// 更新本地副本
|
||||
localKeyValuePairs.value[index].key = newKey
|
||||
}
|
||||
|
||||
function isTemplateKeyAdded(templateKey) {
|
||||
return localKeyValuePairs.value.some(pair => pair.key === templateKey)
|
||||
}
|
||||
|
||||
function getTemplateValue(templateKey) {
|
||||
const pair = localKeyValuePairs.value.find(pair => pair.key === templateKey)
|
||||
if (pair) {
|
||||
return pair.value
|
||||
}
|
||||
const template = templateSchema.value[templateKey]
|
||||
return template?.default !== undefined ? template.default : getDefaultValueForType(template?.type || 'string')
|
||||
}
|
||||
|
||||
function updateTemplateValue(templateKey, newValue) {
|
||||
const existingIndex = localKeyValuePairs.value.findIndex(pair => pair.key === templateKey)
|
||||
const template = templateSchema.value[templateKey]
|
||||
|
||||
if (existingIndex >= 0) {
|
||||
// 更新现有值
|
||||
localKeyValuePairs.value[existingIndex].value = newValue
|
||||
} else {
|
||||
// 添加新字段
|
||||
let valueType = template?.type || 'string'
|
||||
localKeyValuePairs.value.push({
|
||||
key: templateKey,
|
||||
value: newValue,
|
||||
type: valueType,
|
||||
slider: template?.slider,
|
||||
template: template
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
function removeTemplateKey(templateKey) {
|
||||
const index = localKeyValuePairs.value.findIndex(pair => pair.key === templateKey)
|
||||
if (index >= 0) {
|
||||
localKeyValuePairs.value.splice(index, 1)
|
||||
}
|
||||
}
|
||||
|
||||
function getDefaultValueForType(type) {
|
||||
switch (type) {
|
||||
case 'int':
|
||||
case 'float':
|
||||
case 'number':
|
||||
return 0
|
||||
case 'bool':
|
||||
case 'boolean':
|
||||
return false
|
||||
case 'json':
|
||||
return "{}"
|
||||
case 'string':
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
function confirmDialog() {
|
||||
const updatedValue = {}
|
||||
for (const pair of localKeyValuePairs.value) {
|
||||
if (pair.type === 'json' && pair.jsonError) return
|
||||
let convertedValue = pair.value
|
||||
// 根据声明的类型进行转换
|
||||
switch (pair.type) {
|
||||
case 'int':
|
||||
convertedValue = parseInt(pair.value) || 0
|
||||
break
|
||||
case 'float':
|
||||
case 'number':
|
||||
// 尝试转换为数字,如果失败则保持原值(或设为默认值0)
|
||||
convertedValue = Number(pair.value)
|
||||
// 可选:检查是否为有效数字,无效则设为0或报错
|
||||
// if (isNaN(convertedValue)) convertedValue = 0;
|
||||
break
|
||||
case 'bool':
|
||||
case 'boolean':
|
||||
// 布尔值通常由 v-switch 正确处理,但为保险起见可以显式转换
|
||||
// 注意:在 JavaScript 中,只有严格的 false, 0, "", null, undefined, NaN 会被转换为 false
|
||||
// 这里直接赋值 pair.value 应该是安全的,因为 v-model 绑定的就是布尔值
|
||||
// convertedValue = Boolean(pair.value)
|
||||
break
|
||||
case 'json':
|
||||
convertedValue = JSON.parse(pair.value)
|
||||
break
|
||||
case 'string':
|
||||
default:
|
||||
// 默认转换为字符串
|
||||
@@ -279,4 +507,12 @@ function cancelDialog() {
|
||||
.key-value-pair {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.template-field {
|
||||
transition: opacity 0.2s;
|
||||
}
|
||||
|
||||
.template-field-inactive {
|
||||
opacity: 0.8;
|
||||
}
|
||||
</style>
|
||||
@@ -0,0 +1,450 @@
|
||||
<template>
|
||||
<div class="template-list-editor">
|
||||
<div class="top-bar d-flex align-center justify-end mb-3">
|
||||
<v-menu transition="fade-transition">
|
||||
<template #activator="{ props: menuProps }">
|
||||
<v-btn
|
||||
color="primary"
|
||||
variant="tonal"
|
||||
size="small"
|
||||
v-bind="menuProps"
|
||||
prepend-icon="mdi-plus"
|
||||
>
|
||||
{{ addButtonText }}
|
||||
</v-btn>
|
||||
</template>
|
||||
<v-list density="compact">
|
||||
<v-list-item
|
||||
v-for="option in templateOptions"
|
||||
:key="option.value"
|
||||
@click="addEntry(option.value)"
|
||||
>
|
||||
<v-list-item-title>{{ option.label }}</v-list-item-title>
|
||||
<v-list-item-subtitle v-if="option.hint">{{ option.hint }}</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-menu>
|
||||
</div>
|
||||
|
||||
<v-alert
|
||||
v-if="!modelValue || modelValue.length === 0"
|
||||
type="info"
|
||||
variant="tonal"
|
||||
density="compact"
|
||||
class="mb-3"
|
||||
>
|
||||
{{ emptyHintText }}
|
||||
</v-alert>
|
||||
|
||||
<v-card
|
||||
v-for="(entry, entryIndex) in modelValue"
|
||||
:key="entryIndex"
|
||||
variant="outlined"
|
||||
class="mb-3"
|
||||
>
|
||||
<v-card-title
|
||||
class="d-flex align-center justify-space-between entry-header"
|
||||
@click="toggleEntry(entryIndex)"
|
||||
>
|
||||
<div class="d-flex align-center ga-2">
|
||||
<v-btn
|
||||
icon
|
||||
size="small"
|
||||
variant="text"
|
||||
:title="expandedEntries[entryIndex] ? (t('core.common.collapse') || '收起') : (t('core.common.expand') || '展开')"
|
||||
>
|
||||
<v-icon>{{ expandedEntries[entryIndex] ? 'mdi-chevron-down' : 'mdi-chevron-right' }}</v-icon>
|
||||
</v-btn>
|
||||
<div class="d-flex flex-column">
|
||||
<v-list-item-title class="property-name">{{ templateLabel(entry.__template_key) }}</v-list-item-title>
|
||||
<v-list-item-subtitle class="property-hint" v-if="getTemplate(entry)?.hint || getTemplate(entry)?.description">
|
||||
{{ getTemplate(entry)?.hint || getTemplate(entry)?.description }}
|
||||
</v-list-item-subtitle>
|
||||
</div>
|
||||
</div>
|
||||
<div class="d-flex align-center ga-1">
|
||||
<v-btn icon size="small" variant="text" color="error" @click.stop="removeEntry(entryIndex)">
|
||||
<v-icon>mdi-delete</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-card-title>
|
||||
<v-expand-transition>
|
||||
<v-card-text v-show="expandedEntries[entryIndex]" class="px-0 py-1">
|
||||
<div v-if="!getTemplate(entry)" class="px-4 py-2">
|
||||
<v-alert type="error" variant="tonal" density="compact">{{ t('core.common.templateList.missingTemplate') || '找不到对应模板,请删除后重新添加。' }}</v-alert>
|
||||
</div>
|
||||
<div v-else class="template-entry-body">
|
||||
<template v-for="(itemMeta, itemKey, metaIndex) in getTemplate(entry).items" :key="itemKey">
|
||||
<!-- Nested Object -->
|
||||
<div
|
||||
v-if="itemMeta?.type === 'object' && !itemMeta?.invisible && shouldShowItem(itemMeta, entry)"
|
||||
class="nested-container mx-4"
|
||||
>
|
||||
<div class="config-section mb-2">
|
||||
<v-list-item-title class="config-title">
|
||||
{{ itemMeta?.description || itemKey }}
|
||||
</v-list-item-title>
|
||||
<v-list-item-subtitle class="config-hint" v-if="itemMeta?.hint">
|
||||
{{ itemMeta.hint }}
|
||||
</v-list-item-subtitle>
|
||||
</div>
|
||||
<div v-for="(childMeta, childKey, childIndex) in itemMeta.items" :key="childKey">
|
||||
<template v-if="!childMeta?.invisible && shouldShowItem(childMeta, entry)">
|
||||
<v-row class="config-row">
|
||||
<v-col cols="12" sm="6" class="property-info">
|
||||
<v-list-item density="compact">
|
||||
<v-list-item-title class="property-name">
|
||||
{{ childMeta?.description || childKey }}
|
||||
</v-list-item-title>
|
||||
<v-list-item-subtitle class="property-hint">
|
||||
{{ childMeta?.hint }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</v-col>
|
||||
<v-col cols="12" sm="6" class="config-input">
|
||||
<ConfigItemRenderer
|
||||
v-model="entry[itemKey][childKey]"
|
||||
:item-meta="childMeta"
|
||||
/>
|
||||
</v-col>
|
||||
</v-row>
|
||||
<v-divider
|
||||
v-if="hasVisibleItemsAfter(Object.entries(itemMeta.items), childIndex, entry)"
|
||||
class="config-divider"
|
||||
></v-divider>
|
||||
</template>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Regular Property -->
|
||||
<template v-else-if="!itemMeta?.invisible && shouldShowItem(itemMeta, entry)">
|
||||
<v-row class="config-row">
|
||||
<v-col cols="12" sm="6" class="property-info">
|
||||
<v-list-item density="compact">
|
||||
<v-list-item-title class="property-name">
|
||||
<span v-if="itemMeta?.description">{{ itemMeta?.description }} <span class="property-key">({{ itemKey }})</span></span>
|
||||
<span v-else>{{ itemKey }}</span>
|
||||
</v-list-item-title>
|
||||
<v-list-item-subtitle class="property-hint">
|
||||
{{ itemMeta?.hint }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</v-col>
|
||||
<v-col cols="12" sm="6" class="config-input">
|
||||
<ConfigItemRenderer
|
||||
v-model="entry[itemKey]"
|
||||
:item-meta="itemMeta"
|
||||
/>
|
||||
</v-col>
|
||||
</v-row>
|
||||
<v-divider
|
||||
v-if="hasVisibleItemsAfter(Object.entries(getTemplate(entry).items), metaIndex, entry)"
|
||||
class="config-divider"
|
||||
></v-divider>
|
||||
</template>
|
||||
</template>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-expand-transition>
|
||||
</v-card>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup>
|
||||
import { computed, ref, watch } from 'vue'
|
||||
import ConfigItemRenderer from './ConfigItemRenderer.vue'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
|
||||
const props = defineProps({
|
||||
modelValue: {
|
||||
type: Array,
|
||||
default: () => []
|
||||
},
|
||||
templates: {
|
||||
type: Object,
|
||||
default: () => ({})
|
||||
}
|
||||
})
|
||||
|
||||
const emit = defineEmits(['update:modelValue'])
|
||||
const { t } = useI18n()
|
||||
|
||||
const expandedEntries = ref({})
|
||||
|
||||
const safeText = (val, fallback) => (val && typeof val === 'string' ? val : fallback)
|
||||
const addButtonText = computed(() => safeText(t('core.common.templateList.addEntry'), '添加条目'))
|
||||
const emptyHintText = computed(() => safeText(t('core.common.templateList.empty'), '暂无条目,请先选择模板并添加。'))
|
||||
const defaultValueMap = {
|
||||
int: 0,
|
||||
float: 0.0,
|
||||
bool: false,
|
||||
string: '',
|
||||
text: '',
|
||||
list: [],
|
||||
object: {},
|
||||
template_list: []
|
||||
}
|
||||
|
||||
const templateOptions = computed(() => {
|
||||
return Object.entries(props.templates || {}).map(([value, meta]) => ({
|
||||
label: meta?.name || value,
|
||||
value,
|
||||
hint: meta?.hint || meta?.description || ''
|
||||
}))
|
||||
})
|
||||
|
||||
function templateLabel(key) {
|
||||
if (!key) return t('core.common.templateList.unknownTemplate') || '未指定模板'
|
||||
return props.templates?.[key]?.name || key
|
||||
}
|
||||
|
||||
function buildDefaults(itemsMeta = {}) {
|
||||
const result = {}
|
||||
for (const [k, meta] of Object.entries(itemsMeta)) {
|
||||
if (!meta || !meta.type) continue
|
||||
const fallback = Object.prototype.hasOwnProperty.call(meta, 'default')
|
||||
? meta.default
|
||||
: defaultValueMap[meta.type]
|
||||
|
||||
if (meta.type === 'object') {
|
||||
result[k] = buildDefaults(meta.items || {})
|
||||
} else {
|
||||
result[k] = fallback
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
function applyDefaults(target, itemsMeta = {}) {
|
||||
let changed = false
|
||||
for (const [k, meta] of Object.entries(itemsMeta)) {
|
||||
if (!meta || !meta.type) continue
|
||||
const hasDefault = Object.prototype.hasOwnProperty.call(meta, 'default')
|
||||
const fallback = hasDefault ? meta.default : defaultValueMap[meta.type]
|
||||
|
||||
if (meta.type === 'object') {
|
||||
if (!target[k] || typeof target[k] !== 'object') {
|
||||
target[k] = buildDefaults(meta.items || {})
|
||||
changed = true
|
||||
} else {
|
||||
if (applyDefaults(target[k], meta.items || {})) {
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
} else if (!(k in target)) {
|
||||
target[k] = fallback
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
return changed
|
||||
}
|
||||
|
||||
function ensureEntryDefaults() {
|
||||
if (!Array.isArray(props.modelValue)) return
|
||||
|
||||
let totalChanged = false
|
||||
const nextValue = props.modelValue.map((entry, idx) => {
|
||||
const template = getTemplate(entry)
|
||||
if (!template || !template.items) return entry
|
||||
|
||||
// 我们必须克隆以避免就地修改
|
||||
const newEntry = JSON.parse(JSON.stringify(entry))
|
||||
let entryChanged = applyDefaults(newEntry, template.items)
|
||||
|
||||
if (!Object.prototype.hasOwnProperty.call(newEntry, '__template_key')) {
|
||||
newEntry.__template_key = ''
|
||||
entryChanged = true
|
||||
}
|
||||
|
||||
if (!(idx in expandedEntries.value)) {
|
||||
expandedEntries.value[idx] = false
|
||||
}
|
||||
|
||||
if (entryChanged) {
|
||||
totalChanged = true
|
||||
}
|
||||
return newEntry
|
||||
})
|
||||
|
||||
if (totalChanged) {
|
||||
emit('update:modelValue', nextValue)
|
||||
}
|
||||
}
|
||||
|
||||
watch(
|
||||
() => props.modelValue,
|
||||
() => ensureEntryDefaults(),
|
||||
{ immediate: true, deep: true }
|
||||
)
|
||||
|
||||
function addEntry(templateKey) {
|
||||
if (!templateKey) return
|
||||
const template = props.templates?.[templateKey]
|
||||
if (!template) return
|
||||
const newEntry = {
|
||||
__template_key: templateKey,
|
||||
...buildDefaults(template.items || {})
|
||||
}
|
||||
emit('update:modelValue', [...(props.modelValue || []), newEntry])
|
||||
expandedEntries.value[props.modelValue.length] = true
|
||||
}
|
||||
|
||||
function removeEntry(index) {
|
||||
const next = [...(props.modelValue || [])]
|
||||
next.splice(index, 1)
|
||||
const rebuilt = {}
|
||||
next.forEach((_, idx) => {
|
||||
const sourceIdx = idx >= index ? idx + 1 : idx
|
||||
rebuilt[idx] = expandedEntries.value[sourceIdx] ?? false
|
||||
})
|
||||
expandedEntries.value = rebuilt
|
||||
emit('update:modelValue', next)
|
||||
}
|
||||
|
||||
function toggleEntry(index) {
|
||||
expandedEntries.value[index] = !expandedEntries.value[index]
|
||||
}
|
||||
|
||||
function getTemplate(entry) {
|
||||
if (!entry) return null
|
||||
const key = entry.__template_key
|
||||
if (!key) return null
|
||||
return props.templates?.[key] || null
|
||||
}
|
||||
|
||||
function getValueBySelector(obj, selector) {
|
||||
const keys = selector.split('.')
|
||||
let current = obj
|
||||
for (const key of keys) {
|
||||
if (current && typeof current === 'object' && key in current) {
|
||||
current = current[key]
|
||||
} else {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
return current
|
||||
}
|
||||
|
||||
function shouldShowItem(itemMeta, entry) {
|
||||
if (!itemMeta?.condition) {
|
||||
return true
|
||||
}
|
||||
for (const [conditionKey, expectedValue] of Object.entries(itemMeta.condition)) {
|
||||
const actualValue = getValueBySelector(entry, conditionKey)
|
||||
if (actualValue !== expectedValue) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
function hasVisibleItemsAfter(entries, currentIndex, entry) {
|
||||
for (let i = currentIndex + 1; i < entries.length; i++) {
|
||||
const [k, meta] = entries[i]
|
||||
if (!meta?.invisible && shouldShowItem(meta, entry)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.template-list-editor {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.entry-header {
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.entry-header:hover {
|
||||
background-color: rgba(0, 0, 0, 0.02);
|
||||
}
|
||||
|
||||
.top-bar {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.config-section {
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.config-title {
|
||||
font-weight: 600;
|
||||
font-size: 1rem;
|
||||
color: var(--v-theme-primaryText);
|
||||
}
|
||||
|
||||
.config-hint {
|
||||
font-size: 0.75rem;
|
||||
color: var(--v-theme-secondaryText);
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.template-entry-body {
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.config-row {
|
||||
margin: 0;
|
||||
align-items: center;
|
||||
padding: 4px 8px;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.config-row:hover {
|
||||
background-color: rgba(0, 0, 0, 0.03);
|
||||
}
|
||||
|
||||
.property-info {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.property-name {
|
||||
font-size: 0.875rem;
|
||||
font-weight: 600;
|
||||
color: var(--v-theme-primaryText);
|
||||
}
|
||||
|
||||
.property-hint {
|
||||
font-size: 0.75rem;
|
||||
color: var(--v-theme-secondaryText);
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.property-key {
|
||||
font-size: 0.85em;
|
||||
opacity: 0.7;
|
||||
font-weight: normal;
|
||||
}
|
||||
|
||||
.config-input {
|
||||
padding: 4px 8px;
|
||||
}
|
||||
|
||||
.config-field {
|
||||
margin-bottom: 0;
|
||||
}
|
||||
|
||||
.config-divider {
|
||||
border-color: rgba(0, 0, 0, 0.05);
|
||||
margin: 0px 16px;
|
||||
}
|
||||
|
||||
.nested-container {
|
||||
border: 1px solid rgba(0, 0, 0, 0.1);
|
||||
border-radius: 8px;
|
||||
padding: 12px;
|
||||
margin: 12px 0;
|
||||
background-color: rgba(0, 0, 0, 0.02);
|
||||
box-shadow: 0 1px 2px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.editor-container {
|
||||
position: relative;
|
||||
display: flex;
|
||||
width: 100%;
|
||||
}
|
||||
</style>
|
||||
@@ -508,12 +508,24 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
const sourceId = editableProviderSource.value?.id || selectedProviderSource.value.id
|
||||
const newId = `${sourceId}/${modelName}`
|
||||
|
||||
const modalities = ['text']
|
||||
if (supportsImageInput(getModelMetadata(modelName))) {
|
||||
modalities.push('image')
|
||||
const metadata = getModelMetadata(modelName)
|
||||
let modalities: string[]
|
||||
|
||||
if (!metadata) {
|
||||
modalities = ['text', 'image', 'tool_use']
|
||||
} else {
|
||||
modalities = ['text']
|
||||
if (supportsImageInput(metadata)) {
|
||||
modalities.push('image')
|
||||
}
|
||||
if (supportsToolCall(metadata)) {
|
||||
modalities.push('tool_use')
|
||||
}
|
||||
}
|
||||
if (supportsToolCall(getModelMetadata(modelName))) {
|
||||
modalities.push('tool_use')
|
||||
|
||||
let max_context_tokens = 0
|
||||
if (metadata?.limit?.context && typeof metadata.limit.context === 'number') {
|
||||
max_context_tokens = metadata.limit.context
|
||||
}
|
||||
|
||||
const newProvider = {
|
||||
@@ -522,7 +534,8 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
provider_source_id: sourceId,
|
||||
model: modelName,
|
||||
modalities,
|
||||
custom_extra_body: {}
|
||||
custom_extra_body: {},
|
||||
max_context_tokens: max_context_tokens
|
||||
}
|
||||
|
||||
try {
|
||||
|
||||
@@ -65,9 +65,16 @@
|
||||
"fullscreen": "Fullscreen Edit",
|
||||
"editingTitle": "Editing Content"
|
||||
},
|
||||
"templateList": {
|
||||
"addEntry": "Add Entry",
|
||||
"empty": "No entries yet, pick a template to add",
|
||||
"missingTemplate": "Template not found, please remove and add again.",
|
||||
"unknownTemplate": "Template not specified"
|
||||
},
|
||||
"list": {
|
||||
"addItemPlaceholder": "Add new item, press Enter to confirm",
|
||||
"addButton": "Add",
|
||||
"addMore": "Add More",
|
||||
"batchImport": "Batch Import",
|
||||
"batchImportTitle": "Batch Import",
|
||||
"batchImportLabel": "One item per line",
|
||||
@@ -84,7 +91,6 @@
|
||||
"enabled": "Enabled",
|
||||
"disabled": "Disabled",
|
||||
"delete": "Delete",
|
||||
"copy": "Copy",
|
||||
"edit": "Edit",
|
||||
"copy": "Copy",
|
||||
"noData": "No data available"
|
||||
|
||||
@@ -11,7 +11,12 @@
|
||||
},
|
||||
"agent_runner_type": {
|
||||
"description": "Runner",
|
||||
"labels": ["Built-in Agent", "Dify", "Coze", "Alibaba Cloud Bailian Application"]
|
||||
"labels": [
|
||||
"Built-in Agent",
|
||||
"Dify",
|
||||
"Coze",
|
||||
"Alibaba Cloud Bailian Application"
|
||||
]
|
||||
},
|
||||
"coze_agent_runner_provider_id": {
|
||||
"description": "Coze Agent Runner Provider ID"
|
||||
@@ -128,6 +133,39 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"truncate_and_compress": {
|
||||
"description": "Context Management Strategy",
|
||||
"provider_settings": {
|
||||
"max_context_length": {
|
||||
"description": "Maximum Conversation Turns",
|
||||
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "Dequeue Conversation Turns",
|
||||
"hint": "Number of conversation turns to discard at once when maximum context length is exceeded"
|
||||
},
|
||||
"context_limit_reached_strategy": {
|
||||
"description": "Handling When Model Context Window is Exceeded",
|
||||
"labels": [
|
||||
"Truncate by Turns",
|
||||
"Compress by LLM"
|
||||
],
|
||||
"hint": "When 'Truncate by Turns' is selected, the oldest N conversation turns will be discarded based on the 'Dequeue Conversation Turns' setting above. When 'Compress by LLM' is selected, the specified model will be used for context compression."
|
||||
},
|
||||
"llm_compress_instruction": {
|
||||
"description": "Context Compression Instruction",
|
||||
"hint": "If empty, the default prompt will be used."
|
||||
},
|
||||
"llm_compress_keep_recent": {
|
||||
"description": "Keep Recent Turns When Compressing",
|
||||
"hint": "Always keep the most recent N turns of conversation when compressing context."
|
||||
},
|
||||
"llm_compress_provider_id": {
|
||||
"description": "Model Provider ID for Context Compression",
|
||||
"hint": "When left empty, will fall back to the 'Truncate by Turns' strategy."
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "Other Settings",
|
||||
"provider_settings": {
|
||||
@@ -161,15 +199,10 @@
|
||||
"unsupported_streaming_strategy": {
|
||||
"description": "Platforms Without Streaming Support",
|
||||
"hint": "Select the handling method for platforms that don't support streaming responses. Real-time segmented reply sends content immediately when the system detects segment points like punctuation during streaming reception",
|
||||
"labels": ["Real-time Segmented Reply", "Disable Streaming Response"]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "Maximum Conversation Rounds",
|
||||
"hint": "Discards the oldest parts when this count is exceeded. One conversation round counts as 1, -1 means unlimited"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "Dequeue Conversation Rounds",
|
||||
"hint": "Number of conversation rounds to discard at once when maximum context length is exceeded"
|
||||
"labels": [
|
||||
"Real-time Segmented Reply",
|
||||
"Disable Streaming Response"
|
||||
]
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "Additional LLM Chat Wake Prefix",
|
||||
@@ -387,7 +420,10 @@
|
||||
},
|
||||
"split_mode": {
|
||||
"description": "Split Mode",
|
||||
"labels": ["Regex", "Words List"]
|
||||
"labels": [
|
||||
"Regex",
|
||||
"Words List"
|
||||
]
|
||||
},
|
||||
"regex": {
|
||||
"description": "Segmentation Regular Expression"
|
||||
@@ -488,4 +524,4 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -145,6 +145,11 @@
|
||||
"message": "This plugin has been flagged as containing security risks, including unsafe code or functionalities that may cause system malfunctions or data loss. Do you wish to proceed with the installation?",
|
||||
"confirm": "Continue",
|
||||
"cancel": "Cancel"
|
||||
},
|
||||
"forceUpdate": {
|
||||
"title": "No New Version Detected",
|
||||
"message": "No new version detected for this plugin. Do you want to force reinstall? This will pull the latest code from the remote repository.",
|
||||
"confirm": "Force Update"
|
||||
}
|
||||
},
|
||||
"messages": {
|
||||
@@ -185,7 +190,8 @@
|
||||
"reloadPlugin": "Reload Extension",
|
||||
"togglePlugin": "Extension",
|
||||
"viewHandlers": "View Handlers",
|
||||
"updateTo": "Update to"
|
||||
"updateTo": "Update to",
|
||||
"reinstall": "Reinstall"
|
||||
},
|
||||
"status": {
|
||||
"hasUpdate": "New version available",
|
||||
@@ -207,4 +213,4 @@
|
||||
"goToManage": "Go to Manage",
|
||||
"later": "Later"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -129,6 +129,7 @@
|
||||
"manualDialogPreviewLabel": "Display ID (auto generated)",
|
||||
"manualDialogPreviewHint": "Generated as sourceId/modelId",
|
||||
"manualModelRequired": "Please enter a model ID",
|
||||
"manualModelExists": "Model already exists"
|
||||
"manualModelExists": "Model already exists",
|
||||
"configure": "Configure"
|
||||
}
|
||||
}
|
||||
@@ -64,6 +64,10 @@
|
||||
"uploadAndCheck": "Upload & Check",
|
||||
"uploading": "Uploading...",
|
||||
"uploadWait": "Please wait, uploading backup file...",
|
||||
"uploadInit": "Initializing upload...",
|
||||
"uploadingChunks": "Uploading chunks...",
|
||||
"uploadComplete": "Upload complete, merging file...",
|
||||
"checking": "Checking backup file...",
|
||||
"invalidBackup": "Invalid backup file",
|
||||
"backupContents": "Backup Contents",
|
||||
"tables": "tables",
|
||||
@@ -93,7 +97,17 @@
|
||||
"list": {
|
||||
"empty": "No backup files",
|
||||
"refresh": "Refresh List",
|
||||
"confirmDelete": "Are you sure you want to delete this backup file? This action cannot be undone."
|
||||
"confirmDelete": "Are you sure you want to delete this backup file? This action cannot be undone.",
|
||||
"uploaded": "Uploaded",
|
||||
"restore": "Restore this backup",
|
||||
"rename": "Rename",
|
||||
"renameTitle": "Rename Backup File",
|
||||
"newName": "New Filename",
|
||||
"renameHint": "Filename can only contain letters, numbers, underscores, hyphens and dots",
|
||||
"renameRequired": "Please enter a filename",
|
||||
"renameInvalidChars": "Filename contains invalid characters",
|
||||
"renameFailed": "Rename failed",
|
||||
"ftpHint": "For large backup files, you can also upload directly to the data/backups directory via FTP/SFTP"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -65,9 +65,16 @@
|
||||
"fullscreen": "全屏编辑",
|
||||
"editingTitle": "编辑内容"
|
||||
},
|
||||
"templateList": {
|
||||
"addEntry": "添加条目",
|
||||
"empty": "暂无条目,请选择模板添加",
|
||||
"missingTemplate": "找不到对应模板,请删除后重新添加。",
|
||||
"unknownTemplate": "未指定模板"
|
||||
},
|
||||
"list": {
|
||||
"addItemPlaceholder": "添加新项,按回车确认添加",
|
||||
"addButton": "添加",
|
||||
"addMore": "添加更多",
|
||||
"batchImport": "批量导入",
|
||||
"batchImportTitle": "批量导入",
|
||||
"batchImportLabel": "每行一个项目",
|
||||
@@ -88,4 +95,4 @@
|
||||
"copy": "复制",
|
||||
"noData": "暂无数据"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -133,6 +133,36 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"truncate_and_compress": {
|
||||
"description": "上下文管理策略",
|
||||
"provider_settings": {
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
|
||||
},
|
||||
"context_limit_reached_strategy": {
|
||||
"description": "超出模型上下文窗口时的处理方式",
|
||||
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
|
||||
"hint": "当按对话轮数截断时,会根据上面\"丢弃对话轮数\"的配置丢弃最旧的 N 轮对话。当由 LLM 压缩上下文时,会使用指定的模型进行上下文压缩。"
|
||||
},
|
||||
"llm_compress_instruction": {
|
||||
"description": "上下文压缩提示词",
|
||||
"hint": "如果为空则使用默认提示词。"
|
||||
},
|
||||
"llm_compress_keep_recent": {
|
||||
"description": "压缩时保留最近对话轮数",
|
||||
"hint": "始终保留的最近 N 轮对话。"
|
||||
},
|
||||
"llm_compress_provider_id": {
|
||||
"description": "用于上下文压缩的模型提供商 ID",
|
||||
"hint": "留空时将降级为\"按对话轮数截断\"的策略。"
|
||||
}
|
||||
}
|
||||
},
|
||||
"others": {
|
||||
"description": "其他配置",
|
||||
"provider_settings": {
|
||||
@@ -171,14 +201,7 @@
|
||||
"关闭流式回复"
|
||||
]
|
||||
},
|
||||
"max_context_length": {
|
||||
"description": "最多携带对话轮数",
|
||||
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制"
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话轮数",
|
||||
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数"
|
||||
},
|
||||
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求"
|
||||
|
||||
@@ -145,6 +145,11 @@
|
||||
"message": "该插件可能包含不安全的代码或功能,可能导致系统异常或数据损失等。请确认是否继续安装?",
|
||||
"confirm": "继续",
|
||||
"cancel": "取消"
|
||||
},
|
||||
"forceUpdate": {
|
||||
"title": "未检测到新版本",
|
||||
"message": "当前插件未检测到新版本,是否强制重新安装?这将从远程仓库拉取最新代码。",
|
||||
"confirm": "强制更新"
|
||||
}
|
||||
},
|
||||
"messages": {
|
||||
@@ -185,7 +190,8 @@
|
||||
"reloadPlugin": "重载插件",
|
||||
"togglePlugin": "插件",
|
||||
"viewHandlers": "查看行为",
|
||||
"updateTo": "更新到"
|
||||
"updateTo": "更新到",
|
||||
"reinstall": "重新安装"
|
||||
},
|
||||
"status": {
|
||||
"hasUpdate": "有新版本可用",
|
||||
|
||||
@@ -130,6 +130,7 @@
|
||||
"manualDialogPreviewLabel": "显示 ID(自动生成)",
|
||||
"manualDialogPreviewHint": "生成规则:源ID/模型ID",
|
||||
"manualModelRequired": "请输入模型 ID",
|
||||
"manualModelExists": "该模型已存在"
|
||||
"manualModelExists": "该模型已存在",
|
||||
"configure": "配置"
|
||||
}
|
||||
}
|
||||
@@ -64,6 +64,10 @@
|
||||
"uploadAndCheck": "上传并检查",
|
||||
"uploading": "正在上传...",
|
||||
"uploadWait": "请稍候,正在上传备份文件...",
|
||||
"uploadInit": "正在初始化上传...",
|
||||
"uploadingChunks": "正在上传分片...",
|
||||
"uploadComplete": "上传完成,正在合并文件...",
|
||||
"checking": "正在检查备份文件...",
|
||||
"invalidBackup": "无效的备份文件",
|
||||
"backupContents": "备份内容",
|
||||
"tables": "个数据表",
|
||||
@@ -93,7 +97,17 @@
|
||||
"list": {
|
||||
"empty": "暂无备份文件",
|
||||
"refresh": "刷新列表",
|
||||
"confirmDelete": "确定要删除这个备份文件吗?此操作不可撤销。"
|
||||
"confirmDelete": "确定要删除这个备份文件吗?此操作不可撤销。",
|
||||
"uploaded": "已上传",
|
||||
"restore": "恢复此备份",
|
||||
"rename": "重命名",
|
||||
"renameTitle": "重命名备份文件",
|
||||
"newName": "新文件名",
|
||||
"renameHint": "文件名只能包含字母、数字、下划线、连字符和点",
|
||||
"renameRequired": "请输入文件名",
|
||||
"renameInvalidChars": "文件名包含非法字符",
|
||||
"renameFailed": "重命名失败",
|
||||
"ftpHint": "对于较大的备份文件,也可以通过 FTP/SFTP 等方式直接上传到 data/backups 目录"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
<script setup>
|
||||
import { ref, shallowRef, onMounted, onUnmounted } from 'vue';
|
||||
import { ref, shallowRef, onMounted, onUnmounted, watch } from 'vue';
|
||||
import { useCustomizerStore } from '../../../stores/customizer';
|
||||
import { useI18n } from '@/i18n/composables';
|
||||
import sidebarItems from './sidebarItem';
|
||||
@@ -12,6 +12,10 @@ const { t } = useI18n();
|
||||
const customizer = useCustomizerStore();
|
||||
const sidebarMenu = shallowRef(sidebarItems);
|
||||
|
||||
// 侧边栏分组展开状态持久化
|
||||
const openedItems = ref(JSON.parse(localStorage.getItem('sidebar_openedItems') || '[]'));
|
||||
watch(openedItems, (val) => localStorage.setItem('sidebar_openedItems', JSON.stringify(val)), { deep: true });
|
||||
|
||||
// Apply customization on mount and listen for storage changes
|
||||
const handleStorageChange = (e) => {
|
||||
if (e.key === 'astrbot_sidebar_customization') {
|
||||
@@ -243,7 +247,7 @@ function openChangelogDialog() {
|
||||
:rail="customizer.mini_sidebar"
|
||||
>
|
||||
<div class="sidebar-container">
|
||||
<v-list class="pa-4 listitem flex-grow-1">
|
||||
<v-list class="pa-4 listitem flex-grow-1" v-model:opened="openedItems" :open-strategy="'multiple'">
|
||||
<template v-for="(item, i) in sidebarMenu" :key="i">
|
||||
<NavItem :item="item" class="leftPadding" />
|
||||
</template>
|
||||
|
||||
@@ -77,14 +77,26 @@ const readmeDialog = reactive({
|
||||
repoUrl: null
|
||||
});
|
||||
|
||||
// 强制更新确认对话框
|
||||
const forceUpdateDialog = reactive({
|
||||
show: false,
|
||||
extensionName: ''
|
||||
});
|
||||
|
||||
// 新增变量支持列表视图
|
||||
const isListView = ref(false);
|
||||
// 从 localStorage 恢复显示模式,默认为 false(卡片视图)
|
||||
const getInitialListViewMode = () => {
|
||||
if (typeof window !== 'undefined' && window.localStorage) {
|
||||
return localStorage.getItem('pluginListViewMode') === 'true';
|
||||
}
|
||||
return false;
|
||||
};
|
||||
const isListView = ref(getInitialListViewMode());
|
||||
const pluginSearch = ref("");
|
||||
const loading_ = ref(false);
|
||||
|
||||
// 分页相关
|
||||
const currentPage = ref(1);
|
||||
const itemsPerPage = ref(6); // 每页显示6个卡片 (2行 x 3列,避免滚动)
|
||||
|
||||
// 危险插件确认对话框
|
||||
const dangerConfirmDialog = ref(false);
|
||||
@@ -113,7 +125,6 @@ const uploadTab = ref('file');
|
||||
const showPluginFullName = ref(false);
|
||||
const marketSearch = ref("");
|
||||
const debouncedMarketSearch = ref("");
|
||||
const filterKeys = ['name', 'desc', 'author'];
|
||||
const refreshingMarket = ref(false);
|
||||
const sortBy = ref('default'); // default, stars, author, updated
|
||||
const sortOrder = ref('desc'); // desc (降序) or asc (升序)
|
||||
@@ -162,18 +173,6 @@ const pluginHeaders = computed(() => [
|
||||
]);
|
||||
|
||||
|
||||
// 插件市场表头
|
||||
const pluginMarketHeaders = computed(() => [
|
||||
{ title: tm('table.headers.name'), key: 'name', maxWidth: '200px' },
|
||||
{ title: tm('table.headers.description'), key: 'desc', maxWidth: '250px' },
|
||||
{ title: tm('table.headers.author'), key: 'author', maxWidth: '90px' },
|
||||
{ title: tm('table.headers.stars'), key: 'stars', maxWidth: '80px' },
|
||||
{ title: tm('table.headers.lastUpdate'), key: 'updated_at', maxWidth: '100px' },
|
||||
{ title: tm('table.headers.tags'), key: 'tags', maxWidth: '100px' },
|
||||
{ title: tm('table.headers.actions'), key: 'actions', sortable: false }
|
||||
]);
|
||||
|
||||
|
||||
// 过滤要显示的插件
|
||||
const filteredExtensions = computed(() => {
|
||||
const data = Array.isArray(extension_data?.data) ? extension_data.data : [];
|
||||
@@ -197,9 +196,6 @@ const filteredPlugins = computed(() => {
|
||||
});
|
||||
});
|
||||
|
||||
const pinnedPlugins = computed(() => {
|
||||
return pluginMarketData.value.filter(plugin => plugin?.pinned);
|
||||
});
|
||||
|
||||
// 过滤后的插件市场数据(带搜索)
|
||||
const filteredMarketPlugins = computed(() => {
|
||||
@@ -385,7 +381,17 @@ const handleUninstallConfirm = (options) => {
|
||||
}
|
||||
};
|
||||
|
||||
const updateExtension = async (extension_name) => {
|
||||
const updateExtension = async (extension_name, forceUpdate = false) => {
|
||||
// 查找插件信息
|
||||
const ext = extension_data.data?.find(e => e.name === extension_name);
|
||||
|
||||
// 如果没有检测到更新且不是强制更新,则弹窗确认
|
||||
if (!ext?.has_update && !forceUpdate) {
|
||||
forceUpdateDialog.extensionName = extension_name;
|
||||
forceUpdateDialog.show = true;
|
||||
return;
|
||||
}
|
||||
|
||||
loadingDialog.title = tm('status.loading');
|
||||
loadingDialog.show = true;
|
||||
try {
|
||||
@@ -417,6 +423,14 @@ const updateExtension = async (extension_name) => {
|
||||
}
|
||||
};
|
||||
|
||||
// 确认强制更新
|
||||
const confirmForceUpdate = () => {
|
||||
const name = forceUpdateDialog.extensionName;
|
||||
forceUpdateDialog.show = false;
|
||||
forceUpdateDialog.extensionName = '';
|
||||
updateExtension(name, true);
|
||||
};
|
||||
|
||||
const updateAllExtensions = async () => {
|
||||
if (updatingAll.value || updatableExtensions.value.length === 0) return;
|
||||
updatingAll.value = true;
|
||||
@@ -552,14 +566,6 @@ const viewReadme = (plugin) => {
|
||||
readmeDialog.show = true;
|
||||
};
|
||||
|
||||
|
||||
|
||||
const open = (link) => {
|
||||
if (link) {
|
||||
window.open(link, '_blank');
|
||||
}
|
||||
};
|
||||
|
||||
// 为表格视图创建一个处理安装插件的函数
|
||||
const handleInstallPlugin = async (plugin) => {
|
||||
if (plugin.tags && plugin.tags.includes('danger')) {
|
||||
@@ -918,6 +924,13 @@ watch(marketSearch, (newVal) => {
|
||||
}, 300); // 300ms 防抖延迟
|
||||
});
|
||||
|
||||
// 监听显示模式变化并保存到 localStorage
|
||||
watch(isListView, (newVal) => {
|
||||
if (typeof window !== 'undefined' && window.localStorage) {
|
||||
localStorage.setItem('pluginListViewMode', String(newVal));
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
</script>
|
||||
|
||||
@@ -1037,8 +1050,21 @@ watch(marketSearch, (newVal) => {
|
||||
|
||||
<template v-slot:item.name="{ item }">
|
||||
<div class="d-flex align-center py-2">
|
||||
<div v-if="item.logo" class="mr-3" style="flex-shrink: 0;">
|
||||
<img :src="item.logo" :alt="item.name"
|
||||
style="height: 40px; width: 40px; border-radius: 8px; object-fit: cover;" />
|
||||
</div>
|
||||
<div v-else class="mr-3" style="flex-shrink: 0;">
|
||||
<img :src="defaultPluginIcon" :alt="item.name"
|
||||
style="height: 40px; width: 40px; border-radius: 8px; object-fit: cover;" />
|
||||
</div>
|
||||
<div>
|
||||
<div class="text-subtitle-1 font-weight-medium">{{ item.name }}</div>
|
||||
<div class="text-subtitle-1 font-weight-medium">
|
||||
{{ item.display_name && item.display_name.length ? item.display_name : item.name }}
|
||||
</div>
|
||||
<div v-if="item.display_name && item.display_name.length" class="text-caption text-medium-emphasis mt-1">
|
||||
{{ item.name }}
|
||||
</div>
|
||||
<div v-if="item.reserved" class="d-flex align-center mt-1">
|
||||
<v-chip color="primary" size="x-small" class="font-weight-medium">{{ tm('status.system')
|
||||
}}</v-chip>
|
||||
@@ -1048,7 +1074,7 @@ watch(marketSearch, (newVal) => {
|
||||
</template>
|
||||
|
||||
<template v-slot:item.desc="{ item }">
|
||||
<div class="text-body-2 text-medium-emphasis">{{ item.desc }}</div>
|
||||
<div class="text-body-2 text-medium-emphasis mt-2 mb-2" style="display: -webkit-box; -webkit-line-clamp: 3; line-clamp: 3; -webkit-box-orient: vertical; overflow: hidden; text-overflow: ellipsis;">{{ item.desc }}</div>
|
||||
</template>
|
||||
|
||||
<template v-slot:item.version="{ item }">
|
||||
@@ -1084,7 +1110,7 @@ watch(marketSearch, (newVal) => {
|
||||
<v-tooltip activator="parent" location="top">{{ tm('tooltips.disable') }}</v-tooltip>
|
||||
</v-btn>
|
||||
|
||||
<v-btn icon size="small" color="info" @click="reloadPlugin(item.name)">
|
||||
<v-btn icon size="small" @click="reloadPlugin(item.name)">
|
||||
<v-icon>mdi-refresh</v-icon>
|
||||
<v-tooltip activator="parent" location="top">{{ tm('tooltips.reload') }}</v-tooltip>
|
||||
</v-btn>
|
||||
@@ -1104,8 +1130,7 @@ watch(marketSearch, (newVal) => {
|
||||
<v-tooltip activator="parent" location="top">{{ tm('tooltips.viewDocs') }}</v-tooltip>
|
||||
</v-btn>
|
||||
|
||||
<v-btn icon size="small" color="warning" @click="updateExtension(item.name)"
|
||||
:v-show="item.has_update">
|
||||
<v-btn icon size="small" @click="updateExtension(item.name)">
|
||||
<v-icon>mdi-update</v-icon>
|
||||
<v-tooltip activator="parent" location="top">{{ tm('tooltips.update') }}</v-tooltip>
|
||||
</v-btn>
|
||||
@@ -1772,6 +1797,24 @@ watch(marketSearch, (newVal) => {
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 强制更新确认对话框 -->
|
||||
<v-dialog v-model="forceUpdateDialog.show" max-width="420">
|
||||
<v-card class="rounded-lg">
|
||||
<v-card-title class="text-h6 d-flex align-center">
|
||||
<v-icon color="info" class="mr-2">mdi-information-outline</v-icon>
|
||||
{{ tm('dialogs.forceUpdate.title') }}
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
{{ tm('dialogs.forceUpdate.message') }}
|
||||
</v-card-text>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn variant="text" @click="forceUpdateDialog.show = false">{{ tm('buttons.cancel') }}</v-btn>
|
||||
<v-btn color="primary" variant="flat" @click="confirmForceUpdate">{{ tm('dialogs.forceUpdate.confirm') }}</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<style scoped>
|
||||
|
||||
@@ -230,7 +230,7 @@ export default {
|
||||
save_message: "",
|
||||
save_message_success: "success",
|
||||
|
||||
showConsole: false,
|
||||
showConsole: localStorage.getItem('platformPage_showConsole') === 'true',
|
||||
|
||||
showWebhookDialog: false,
|
||||
currentWebhookUuid: '',
|
||||
@@ -248,6 +248,10 @@ export default {
|
||||
},
|
||||
|
||||
watch: {
|
||||
showConsole(newValue) {
|
||||
localStorage.setItem('platformPage_showConsole', newValue.toString());
|
||||
},
|
||||
|
||||
showIdConflictDialog(newValue) {
|
||||
if (!newValue && this.idConflictResolve) {
|
||||
this.idConflictResolve(false);
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "AstrBot"
|
||||
version = "4.10.3"
|
||||
version = "4.11.0"
|
||||
description = "Easy-to-use multi-platform LLM chatbot and development framework"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@@ -0,0 +1,774 @@
|
||||
"""Comprehensive tests for ContextManager."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add parent directory to path to avoid circular import issues
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
|
||||
from astrbot.core.agent.context.config import ContextConfig
|
||||
from astrbot.core.agent.context.manager import ContextManager
|
||||
from astrbot.core.agent.message import Message, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
|
||||
class MockProvider:
|
||||
"""模拟 Provider"""
|
||||
|
||||
def __init__(self):
|
||||
self.provider_config = {
|
||||
"id": "test_provider",
|
||||
"model": "gpt-4",
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
}
|
||||
|
||||
async def text_chat(self, **kwargs):
|
||||
"""模拟 LLM 调用,返回摘要"""
|
||||
messages = kwargs.get("messages", [])
|
||||
# 简单的摘要逻辑:返回消息数量统计
|
||||
return LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=f"历史对话包含 {len(messages) - 1} 条消息,主要讨论了技术话题。",
|
||||
)
|
||||
|
||||
def get_model(self):
|
||||
return "gpt-4"
|
||||
|
||||
def meta(self):
|
||||
return MagicMock(id="test_provider", type="openai")
|
||||
|
||||
|
||||
class TestContextManager:
|
||||
"""Test suite for ContextManager."""
|
||||
|
||||
def create_message(
|
||||
self, role: Literal["system", "user", "assistant", "tool"], content: str
|
||||
) -> Message:
|
||||
"""Helper to create a simple text message."""
|
||||
return Message(role=role, content=content)
|
||||
|
||||
def create_messages(self, count: int) -> list[Message]:
|
||||
"""Helper to create alternating user/assistant messages."""
|
||||
messages = []
|
||||
for i in range(count):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
messages.append(self.create_message(role, f"Message {i}"))
|
||||
return messages
|
||||
|
||||
# ==================== Basic Initialization Tests ====================
|
||||
|
||||
def test_init_with_minimal_config(self):
|
||||
"""Test initialization with minimal configuration."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
assert manager.config == config
|
||||
assert manager.token_counter is not None
|
||||
assert manager.truncator is not None
|
||||
assert manager.compressor is not None
|
||||
|
||||
def test_init_with_llm_compressor(self):
|
||||
"""Test initialization with LLM-based compression."""
|
||||
mock_provider = MockProvider()
|
||||
config = ContextConfig(
|
||||
llm_compress_provider=mock_provider, # type: ignore
|
||||
llm_compress_keep_recent=5,
|
||||
llm_compress_instruction="Summarize the conversation",
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
from astrbot.core.agent.context.compressor import LLMSummaryCompressor
|
||||
|
||||
assert isinstance(manager.compressor, LLMSummaryCompressor)
|
||||
|
||||
def test_init_with_truncate_compressor(self):
|
||||
"""Test initialization with truncate-based compression (default)."""
|
||||
config = ContextConfig(truncate_turns=3)
|
||||
manager = ContextManager(config)
|
||||
|
||||
from astrbot.core.agent.context.compressor import TruncateByTurnsCompressor
|
||||
|
||||
assert isinstance(manager.compressor, TruncateByTurnsCompressor)
|
||||
|
||||
# ==================== Empty and Edge Cases ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_empty_messages(self):
|
||||
"""Test processing an empty message list."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
result = await manager.process([])
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_single_message(self):
|
||||
"""Test processing a single message."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "Hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_with_no_limits(self):
|
||||
"""Test processing when no limits are set (no truncation or compression)."""
|
||||
config = ContextConfig(max_context_tokens=0, enforce_max_turns=-1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 20
|
||||
assert result == messages
|
||||
|
||||
# ==================== Enforce Max Turns Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_basic(self):
|
||||
"""Test basic enforce_max_turns functionality."""
|
||||
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create 10 turns (20 messages)
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should keep only 3 most recent turns (6 messages)
|
||||
assert len(result) <= 8 # May vary due to truncation logic
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_zero(self):
|
||||
"""Test enforce_max_turns with value 0 (should keep nothing)."""
|
||||
config = ContextConfig(enforce_max_turns=0, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(10)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should result in empty or minimal message list
|
||||
assert len(result) <= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_negative(self):
|
||||
"""Test enforce_max_turns with -1 (no limit)."""
|
||||
config = ContextConfig(enforce_max_turns=-1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 20
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enforce_max_turns_with_system_messages(self):
|
||||
"""Test enforce_max_turns preserves system messages."""
|
||||
config = ContextConfig(enforce_max_turns=2, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
self.create_message("system", "System instruction"),
|
||||
*self.create_messages(10),
|
||||
]
|
||||
result = await manager.process(messages)
|
||||
|
||||
# System message should be preserved
|
||||
system_msgs = [m for m in result if m.role == "system"]
|
||||
assert len(system_msgs) >= 1
|
||||
assert system_msgs[0].content == "System instruction"
|
||||
|
||||
# ==================== Token-based Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_not_triggered_below_threshold(self):
|
||||
"""Test that compression is not triggered below threshold."""
|
||||
config = ContextConfig(max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that total less than threshold
|
||||
messages = [self.create_message("user", "Hi" * 50)] # ~100 tokens
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "should_compress", return_value=False
|
||||
) as mock_should_compress:
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# should_compress should be called
|
||||
mock_should_compress.assert_called_once()
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_triggered_above_threshold(self):
|
||||
"""Test that compression is triggered above threshold."""
|
||||
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that exceed threshold (0.82 * 100 = 82 tokens)
|
||||
# 300 chars * 0.3 = 90 tokens > 82 threshold
|
||||
long_text = "x" * 300 # ~90 tokens, above threshold
|
||||
messages = [self.create_message("user", long_text)]
|
||||
|
||||
# Mock compressor to return smaller result
|
||||
compressed = [self.create_message("user", "short")]
|
||||
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
|
||||
# Mock should_compress to return True first time, False after
|
||||
call_count = 0
|
||||
|
||||
def mock_should_compress(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return call_count == 1
|
||||
|
||||
mock_compressor.should_compress = mock_should_compress
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should be called
|
||||
mock_compressor.assert_called_once()
|
||||
# Result should be the compressed version
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_with_zero_max_tokens(self):
|
||||
"""Test that compression is skipped when max_context_tokens is 0."""
|
||||
config = ContextConfig(max_context_tokens=0)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "x" * 10000)]
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called when max_context_tokens is 0
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_compression_with_negative_max_tokens(self):
|
||||
"""Test that compression is skipped when max_context_tokens is negative."""
|
||||
config = ContextConfig(max_context_tokens=-100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "x" * 10000)]
|
||||
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", new_callable=AsyncMock
|
||||
) as mock_compress:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should not be called
|
||||
mock_compress.assert_not_called()
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_check_after_compression(self):
|
||||
"""Test that halving is applied if still over threshold after compression."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that would still be over threshold after compression
|
||||
long_messages = [self.create_message("user", "x" * 200) for _ in range(10)]
|
||||
|
||||
# Mock compressor to return messages still over threshold
|
||||
async def mock_compress(msgs):
|
||||
return msgs # Return same messages (still over limit)
|
||||
|
||||
# Mock should_compress to return True twice (before and after compression)
|
||||
with patch.object(manager.compressor, "should_compress", return_value=True):
|
||||
with patch.object(manager.compressor, "__call__", new=mock_compress):
|
||||
with patch.object(
|
||||
manager.truncator,
|
||||
"truncate_by_halving",
|
||||
return_value=long_messages[:5],
|
||||
) as mock_halving:
|
||||
_ = await manager.process(long_messages)
|
||||
|
||||
# Halving should be called
|
||||
mock_halving.assert_called_once()
|
||||
|
||||
# ==================== Combined Truncation and Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combined_enforce_turns_and_token_limit(self):
|
||||
"""Test combining enforce_max_turns and token limit."""
|
||||
config = ContextConfig(
|
||||
enforce_max_turns=5, max_context_tokens=500, truncate_turns=1
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create many messages
|
||||
messages = self.create_messages(30)
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should be truncated by both mechanisms
|
||||
assert len(result) < 30
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sequential_processing_order(self):
|
||||
"""Test that enforce_max_turns happens before token compression."""
|
||||
config = ContextConfig(enforce_max_turns=5, max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
|
||||
# Mock the truncator to track calls
|
||||
with patch.object(
|
||||
manager.truncator,
|
||||
"truncate_by_turns",
|
||||
wraps=manager.truncator.truncate_by_turns,
|
||||
) as mock_truncate:
|
||||
await manager.process(messages)
|
||||
|
||||
# Truncator should be called first
|
||||
mock_truncate.assert_called_once()
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_returns_original_messages(self):
|
||||
"""Test that errors during processing return original messages."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
|
||||
# Make compressor raise an exception
|
||||
with patch.object(
|
||||
manager.compressor, "__call__", side_effect=Exception("Test error")
|
||||
):
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should return original messages despite error
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_logs_exception(self):
|
||||
"""Test that errors are logged."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression (> 82 tokens)
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens
|
||||
|
||||
# Replace compressor with one that raises an exception
|
||||
mock_compressor = AsyncMock(side_effect=Exception("Test error"))
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.should_compress = MagicMock(return_value=True)
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
with patch("astrbot.core.agent.context.manager.logger") as mock_logger:
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Logger error method should be called
|
||||
assert mock_logger.error.called
|
||||
# Should return original messages on error
|
||||
assert result == messages
|
||||
|
||||
# ==================== Multi-modal Content Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_messages_with_textpart_content(self):
|
||||
"""Test processing messages with TextPart content."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
Message(role="user", content=[TextPart(text="Hello")]),
|
||||
Message(role="assistant", content=[TextPart(text="Hi there")]),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result == messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_counting_with_multimodal_content(self):
|
||||
"""Test token counting works with multi-modal content."""
|
||||
config = ContextConfig(max_context_tokens=50)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Need enough tokens to exceed threshold: 50 * 0.82 = 41 tokens
|
||||
# 150 chars * 0.3 = 45 tokens > 41
|
||||
messages = [
|
||||
Message(role="user", content=[TextPart(text="x" * 150)]),
|
||||
]
|
||||
|
||||
# Should trigger compression due to token count
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 50)
|
||||
|
||||
assert tokens > 0 # Tokens should be counted
|
||||
assert needs_compression # Should trigger compression
|
||||
|
||||
# ==================== Tool Calls Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_messages_with_tool_calls(self):
|
||||
"""Test processing messages with tool calls."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [
|
||||
Message(
|
||||
role="assistant",
|
||||
content="Let me search for that",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
Message(role="tool", content="Search result", tool_call_id="call_1"),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
assert len(result) == 2
|
||||
|
||||
# ==================== Compressor should_compress Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_compress_empty_messages(self):
|
||||
"""Test should_compress with empty messages."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Compressor's should_compress should handle empty gracefully
|
||||
needs_compression = manager.compressor.should_compress([], 0, 100)
|
||||
assert not needs_compression
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_compress_below_threshold(self):
|
||||
"""Test should_compress when below compression threshold."""
|
||||
config = ContextConfig(max_context_tokens=1000)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 1000)
|
||||
assert not needs_compression
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_compress_above_threshold(self):
|
||||
"""Test should_compress when above compression threshold."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create message with many tokens
|
||||
messages = [self.create_message("user", "这是测试" * 50)]
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
|
||||
# Should need compression if tokens > 82 (0.82 * 100)
|
||||
assert needs_compression == (tokens > 82)
|
||||
|
||||
# ==================== Truncator Halving Tests ====================
|
||||
|
||||
def test_truncate_by_halving_basic(self):
|
||||
"""Test truncate_by_halving removes middle 50%."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(10)
|
||||
result = manager.truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should keep roughly half
|
||||
assert len(result) < len(messages)
|
||||
|
||||
def test_truncate_by_halving_empty_list(self):
|
||||
"""Test truncate_by_halving with empty list."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
result = manager.truncator.truncate_by_halving([])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_truncate_by_halving_single_message(self):
|
||||
"""Test truncate_by_halving with single message."""
|
||||
config = ContextConfig()
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = manager.truncator.truncate_by_halving(messages)
|
||||
|
||||
assert len(result) <= 1
|
||||
|
||||
# ==================== Complex Scenarios ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_compression_cycles(self):
|
||||
"""Test that compression can be triggered multiple times in sequence."""
|
||||
config = ContextConfig(max_context_tokens=50, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Process messages multiple times
|
||||
messages = self.create_messages(10)
|
||||
|
||||
result1 = await manager.process(messages)
|
||||
result2 = await manager.process(result1)
|
||||
result3 = await manager.process(result2)
|
||||
|
||||
# Each cycle should maintain or reduce message count
|
||||
assert len(result3) <= len(result2) <= len(result1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alternating_roles_preserved(self):
|
||||
"""Test that user/assistant alternation is preserved after processing."""
|
||||
config = ContextConfig(enforce_max_turns=3, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(20)
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Check that roles still alternate (excluding system messages)
|
||||
non_system = [m for m in result if m.role != "system"]
|
||||
if len(non_system) >= 2:
|
||||
# Should start with user
|
||||
assert non_system[0].role == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compression_threshold_default(self):
|
||||
"""Test that compression threshold is used correctly."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Verify the default threshold is 0.82
|
||||
assert manager.compressor.compression_threshold == 0.82
|
||||
|
||||
# Test threshold logic
|
||||
messages = [self.create_message("user", "x" * 81)] # ~24 tokens
|
||||
tokens = manager.token_counter.count_tokens(messages)
|
||||
|
||||
needs_compression = manager.compressor.should_compress(messages, tokens, 100)
|
||||
# Should not compress if below threshold
|
||||
assert needs_compression == (tokens > 82)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_batch_processing(self):
|
||||
"""Test processing a large batch of messages."""
|
||||
config = ContextConfig(
|
||||
enforce_max_turns=10, max_context_tokens=1000, truncate_turns=2
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create 100 messages (50 turns)
|
||||
messages = self.create_messages(100)
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should be significantly reduced
|
||||
assert len(result) < 100
|
||||
assert len(result) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_persistence(self):
|
||||
"""Test that config settings are respected throughout processing."""
|
||||
config = ContextConfig(
|
||||
max_context_tokens=500,
|
||||
enforce_max_turns=5,
|
||||
truncate_turns=2,
|
||||
llm_compress_keep_recent=3,
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Verify config is stored
|
||||
assert manager.config.max_context_tokens == 500
|
||||
assert manager.config.enforce_max_turns == 5
|
||||
assert manager.config.truncate_turns == 2
|
||||
assert manager.config.llm_compress_keep_recent == 3
|
||||
|
||||
# ==================== Run Compression Tests ====================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_calls_compressor(self):
|
||||
"""Test _run_compression calls compressor."""
|
||||
config = ContextConfig(max_context_tokens=100)
|
||||
manager = ContextManager(config)
|
||||
|
||||
messages = self.create_messages(5)
|
||||
compressed = self.create_messages(3)
|
||||
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
mock_compressor.should_compress = MagicMock(return_value=False)
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager._run_compression(messages, prev_tokens=100)
|
||||
|
||||
# Compressor __call__ should be invoked
|
||||
mock_compressor.assert_called_once_with(messages)
|
||||
assert result == compressed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_compression_applies_compressor_through_process(self):
|
||||
"""Test _run_compression calls compressor when needed through process()."""
|
||||
config = ContextConfig(max_context_tokens=100, truncate_turns=1)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression
|
||||
messages = [self.create_message("user", "x" * 300)] # ~90 tokens > 82 threshold
|
||||
compressed = [self.create_message("user", "short")] # Much smaller
|
||||
|
||||
# Create a mock compressor
|
||||
mock_compressor = AsyncMock()
|
||||
mock_compressor.compression_threshold = 0.82
|
||||
mock_compressor.return_value = compressed
|
||||
|
||||
# Mock should_compress to return True first time, False after
|
||||
call_count = 0
|
||||
|
||||
def mock_should_compress(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return call_count == 1
|
||||
|
||||
mock_compressor.should_compress = mock_should_compress
|
||||
manager.compressor = mock_compressor
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Compressor should have been called
|
||||
mock_compressor.assert_called_once()
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_compression_with_mock_provider(self):
|
||||
"""Test LLM compression using MockProvider."""
|
||||
mock_provider = MockProvider()
|
||||
config = ContextConfig(
|
||||
llm_compress_provider=mock_provider, # type: ignore
|
||||
llm_compress_keep_recent=3,
|
||||
llm_compress_instruction="请总结对话内容",
|
||||
max_context_tokens=100,
|
||||
)
|
||||
manager = ContextManager(config)
|
||||
|
||||
# Create messages that will trigger compression
|
||||
messages = [
|
||||
self.create_message("user", "x" * 100),
|
||||
self.create_message("assistant", "y" * 100),
|
||||
self.create_message("user", "z" * 100),
|
||||
]
|
||||
|
||||
result = await manager.process(messages)
|
||||
|
||||
# Should have been compressed
|
||||
assert len(result) <= len(messages)
|
||||
|
||||
# ==================== split_history Tests ====================
|
||||
|
||||
def test_split_history_ensures_user_start(self):
|
||||
"""Test split_history ensures recent_messages starts with user message."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
# Create alternating messages: user, assistant, user, assistant, user, assistant
|
||||
messages = [
|
||||
self.create_message("system", "System prompt"),
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("user", "msg3"),
|
||||
self.create_message("assistant", "msg4"),
|
||||
self.create_message("user", "msg5"),
|
||||
self.create_message("assistant", "msg6"),
|
||||
]
|
||||
|
||||
# Keep recent 3 messages - should adjust to start with user
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=3)
|
||||
|
||||
# recent_messages should start with user message
|
||||
assert len(recent) > 0
|
||||
assert recent[0].role == "user"
|
||||
|
||||
# messages_to_summarize should end with assistant (complete turn)
|
||||
if len(to_summarize) > 0:
|
||||
assert to_summarize[-1].role == "assistant"
|
||||
|
||||
def test_split_history_handles_assistant_at_split_point(self):
|
||||
"""Test split_history when assistant message is at the intended split point."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
messages = [
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("user", "msg3"),
|
||||
self.create_message("assistant", "msg4"), # <- intended split here
|
||||
self.create_message("user", "msg5"),
|
||||
self.create_message("assistant", "msg6"),
|
||||
]
|
||||
|
||||
# keep_recent=2 would normally split at index 4 (assistant msg4)
|
||||
# Should move back to include from msg5 (user)
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=2)
|
||||
|
||||
# recent should start with user message
|
||||
assert recent[0].role == "user"
|
||||
assert recent[0].content == "msg5"
|
||||
|
||||
def test_split_history_all_assistant_messages(self):
|
||||
"""Test split_history when there are consecutive assistant messages."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
messages = [
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("assistant", "msg3"),
|
||||
self.create_message("assistant", "msg4"),
|
||||
]
|
||||
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=2)
|
||||
|
||||
# Should find the user message and keep from there
|
||||
if len(recent) > 0:
|
||||
# Find first user message backwards
|
||||
assert any(m.role == "user" for m in messages)
|
||||
|
||||
def test_split_history_with_system_messages(self):
|
||||
"""Test split_history preserves system messages separately."""
|
||||
from astrbot.core.agent.context.compressor import split_history
|
||||
|
||||
messages = [
|
||||
self.create_message("system", "System 1"),
|
||||
self.create_message("system", "System 2"),
|
||||
self.create_message("user", "msg1"),
|
||||
self.create_message("assistant", "msg2"),
|
||||
self.create_message("user", "msg3"),
|
||||
]
|
||||
|
||||
system, to_summarize, recent = split_history(messages, keep_recent=2)
|
||||
|
||||
# System messages should be separate
|
||||
assert len(system) == 2
|
||||
assert all(m.role == "system" for m in system)
|
||||
|
||||
# Recent should start with user
|
||||
if len(recent) > 0:
|
||||
assert recent[0].role == "user"
|
||||
@@ -0,0 +1,423 @@
|
||||
"""Tests for ContextTruncator."""
|
||||
|
||||
from astrbot.core.agent.context.truncator import ContextTruncator
|
||||
from astrbot.core.agent.message import Message
|
||||
|
||||
|
||||
class TestContextTruncator:
|
||||
"""Test suite for ContextTruncator."""
|
||||
|
||||
def create_message(self, role: str, content: str = "test content") -> Message:
|
||||
"""Helper to create a simple test message."""
|
||||
return Message(role=role, content=content)
|
||||
|
||||
def create_messages(
|
||||
self, count: int, include_system: bool = False
|
||||
) -> list[Message]:
|
||||
"""Helper to create alternating user/assistant messages.
|
||||
|
||||
Args:
|
||||
count: Number of messages to create
|
||||
include_system: Whether to include a system message at the start
|
||||
|
||||
Returns:
|
||||
List of messages
|
||||
"""
|
||||
messages = []
|
||||
if include_system:
|
||||
messages.append(self.create_message("system", "System prompt"))
|
||||
|
||||
for i in range(count):
|
||||
role = "user" if i % 2 == 0 else "assistant"
|
||||
messages.append(self.create_message(role, f"Message {i}"))
|
||||
return messages
|
||||
|
||||
# ==================== fix_messages Tests ====================
|
||||
|
||||
def test_fix_messages_empty_list(self):
|
||||
"""Test fix_messages with an empty list."""
|
||||
truncator = ContextTruncator()
|
||||
result = truncator.fix_messages([])
|
||||
assert result == []
|
||||
|
||||
def test_fix_messages_normal_messages(self):
|
||||
"""Test fix_messages with normal user/assistant messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Hello"),
|
||||
self.create_message("assistant", "Hi"),
|
||||
self.create_message("user", "How are you?"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_tool_with_valid_context(self):
|
||||
"""Test fix_messages with tool message after user+assistant."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 3
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_tool_without_context(self):
|
||||
"""Test fix_messages with tool message without enough context."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
# Tool message without context should be removed
|
||||
assert len(result) == 0
|
||||
|
||||
def test_fix_messages_tool_with_only_one_message(self):
|
||||
"""Test fix_messages with tool message after only one message."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Hello"),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
# Tool message without enough context should be removed
|
||||
assert len(result) == 0
|
||||
|
||||
def test_fix_messages_multiple_tools(self):
|
||||
"""Test fix_messages with multiple tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool 1 result"),
|
||||
self.create_message("tool", "Tool 2 result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
def test_fix_messages_mixed_system_tool(self):
|
||||
"""Test fix_messages with system message and tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("system", "System prompt"),
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
]
|
||||
result = truncator.fix_messages(messages)
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
# ==================== truncate_by_turns Tests ====================
|
||||
|
||||
def test_truncate_by_turns_no_limit(self):
|
||||
"""Test truncate_by_turns with -1 (no limit)."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
|
||||
assert len(result) == 20
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_basic(self):
|
||||
"""Test basic truncate_by_turns functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 10 messages = 5 turns (user/assistant pairs)
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# Should keep 3 most recent turns (6 messages)
|
||||
assert len(result) <= 8 # (3-1+1)*2 = 6, but may adjust for correct format
|
||||
|
||||
def test_truncate_by_turns_with_system_message(self):
|
||||
"""Test truncate_by_turns preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10, include_system=True)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=2, drop_turns=1
|
||||
)
|
||||
|
||||
# System message should always be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_turns_zero_keep(self):
|
||||
"""Test truncate_by_turns with keep_most_recent_turns=0."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=0, drop_turns=1
|
||||
)
|
||||
|
||||
# Should result in empty or minimal list
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_turns_below_threshold(self):
|
||||
"""Test truncate_by_turns when messages are below threshold."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=1
|
||||
)
|
||||
|
||||
# No truncation should happen
|
||||
assert len(result) == 4
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_exact_threshold(self):
|
||||
"""Test truncate_by_turns when messages exactly match threshold."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 6 messages = 3 turns
|
||||
messages = self.create_messages(6)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# No truncation should happen
|
||||
assert len(result) == 6
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_turns_ensures_user_first(self):
|
||||
"""Test that truncate_by_turns ensures user message comes first."""
|
||||
truncator = ContextTruncator()
|
||||
# Create scenario where truncation might start with assistant
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=3, drop_turns=1
|
||||
)
|
||||
|
||||
# First non-system message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_turns_multiple_drop(self):
|
||||
"""Test truncate_by_turns with multiple turns dropped at once."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=3
|
||||
)
|
||||
|
||||
# Should drop 3 turns when limit exceeded
|
||||
assert len(result) < len(messages)
|
||||
|
||||
# ==================== truncate_by_dropping_oldest_turns Tests ====================
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_zero(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with drop_turns=0."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_negative(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with negative drop_turns."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_basic(self):
|
||||
"""Test basic truncate_by_dropping_oldest_turns functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 10 messages = 5 turns
|
||||
messages = self.create_messages(10)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# Should drop 2 oldest turns (4 messages)
|
||||
assert len(result) == 6
|
||||
# Should start with user message
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_with_system(self):
|
||||
"""Test truncate_by_dropping_oldest_turns preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(10, include_system=True)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# System message should be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_drop_all(self):
|
||||
"""Test truncate_by_dropping_oldest_turns dropping all turns."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)
|
||||
|
||||
# Should drop all turns
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
|
||||
"""Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 4 messages = 2 turns
|
||||
messages = self.create_messages(4)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)
|
||||
|
||||
# Should result in empty list
|
||||
assert len(result) == 0
|
||||
|
||||
def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
|
||||
"""Test that result starts with user message after dropping."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)
|
||||
|
||||
# First message should be user
|
||||
if len(result) > 0:
|
||||
assert result[0].role == "user"
|
||||
|
||||
# ==================== truncate_by_halving Tests ====================
|
||||
|
||||
def test_truncate_by_halving_empty(self):
|
||||
"""Test truncate_by_halving with empty list."""
|
||||
truncator = ContextTruncator()
|
||||
result = truncator.truncate_by_halving([])
|
||||
assert result == []
|
||||
|
||||
def test_truncate_by_halving_single_message(self):
|
||||
"""Test truncate_by_halving with single message."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [self.create_message("user", "Hello")]
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
# Should not truncate if <= 2 messages
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_halving_two_messages(self):
|
||||
"""Test truncate_by_halving with two messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(2)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
# Should not truncate if <= 2 messages
|
||||
assert result == messages
|
||||
|
||||
def test_truncate_by_halving_basic(self):
|
||||
"""Test basic truncate_by_halving functionality."""
|
||||
truncator = ContextTruncator()
|
||||
# Create 20 messages
|
||||
messages = self.create_messages(20)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should delete 50% = 10 messages, keep 10
|
||||
assert len(result) == 10
|
||||
# First message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_with_system_message(self):
|
||||
"""Test truncate_by_halving preserves system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(20, include_system=True)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# System message should be preserved
|
||||
assert result[0].role == "system"
|
||||
assert result[0].content == "System prompt"
|
||||
|
||||
def test_truncate_by_halving_odd_count(self):
|
||||
"""Test truncate_by_halving with odd number of messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(11)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should delete floor(11/2) = 5 messages, keep 6
|
||||
# But after ensuring user first, may be 5
|
||||
assert len(result) >= 5
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_ensures_user_first(self):
|
||||
"""Test that result starts with user message."""
|
||||
truncator = ContextTruncator()
|
||||
# Create messages starting with user
|
||||
messages = self.create_messages(30)
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# First message should be user
|
||||
assert result[0].role == "user"
|
||||
|
||||
def test_truncate_by_halving_preserves_recent_messages(self):
|
||||
"""Test that truncate_by_halving keeps the most recent 50%."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Message 0"),
|
||||
self.create_message("assistant", "Message 1"),
|
||||
self.create_message("user", "Message 2"),
|
||||
self.create_message("assistant", "Message 3"),
|
||||
]
|
||||
result = truncator.truncate_by_halving(messages)
|
||||
|
||||
# Should keep last 2 messages
|
||||
assert len(result) == 2
|
||||
assert result[0].content == "Message 2"
|
||||
assert result[1].content == "Message 3"
|
||||
|
||||
# ==================== Integration Tests ====================
|
||||
|
||||
def test_truncate_with_tool_messages(self):
|
||||
"""Test truncation with tool messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("user", "Run tool"),
|
||||
self.create_message("assistant", "Running..."),
|
||||
self.create_message("tool", "Tool result"),
|
||||
self.create_message("user", "Thanks"),
|
||||
self.create_message("assistant", "Welcome"),
|
||||
]
|
||||
|
||||
result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)
|
||||
|
||||
# First turn (user+assistant+tool) should be dropped
|
||||
# Tool message should be cleaned up by fix_messages
|
||||
assert len(result) <= 2
|
||||
|
||||
def test_chain_multiple_truncations(self):
|
||||
"""Test chaining multiple truncation methods."""
|
||||
truncator = ContextTruncator()
|
||||
messages = self.create_messages(40, include_system=True)
|
||||
|
||||
# First: truncate by turns
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=10, drop_turns=2
|
||||
)
|
||||
# Then: halve
|
||||
result = truncator.truncate_by_halving(result)
|
||||
|
||||
# Should have system message + truncated content
|
||||
assert result[0].role == "system"
|
||||
assert len(result) < len(messages)
|
||||
|
||||
def test_empty_after_system_message(self):
|
||||
"""Test truncation when only system message exists."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [self.create_message("system", "System prompt")]
|
||||
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=5, drop_turns=1
|
||||
)
|
||||
|
||||
# Should keep system message
|
||||
assert len(result) == 1
|
||||
assert result[0].role == "system"
|
||||
|
||||
def test_all_system_messages(self):
|
||||
"""Test truncation with only system messages."""
|
||||
truncator = ContextTruncator()
|
||||
messages = [
|
||||
self.create_message("system", "System 1"),
|
||||
self.create_message("system", "System 2"),
|
||||
]
|
||||
|
||||
result = truncator.truncate_by_turns(
|
||||
messages, keep_most_recent_turns=0, drop_turns=1
|
||||
)
|
||||
|
||||
# System messages should be preserved, but since there are no non-system
|
||||
# messages and keep_most_recent_turns=0, result should be system messages only
|
||||
assert len(result) >= 0 # May keep system messages or clear all
|
||||
if len(result) > 0:
|
||||
assert all(msg.role == "system" for msg in result)
|
||||
+14
-4
@@ -195,6 +195,7 @@ class TestAstrBotExporter:
|
||||
|
||||
assert manifest["version"] == BACKUP_MANIFEST_VERSION
|
||||
assert manifest["astrbot_version"] == VERSION
|
||||
assert manifest["origin"] == "exported" # 验证备份来源标记
|
||||
assert "exported_at" in manifest
|
||||
assert "tables" in manifest
|
||||
assert "statistics" in manifest
|
||||
@@ -412,11 +413,19 @@ class TestSecureFilename:
|
||||
def test_generate_unique_filename(self):
|
||||
"""测试生成唯一文件名"""
|
||||
result = generate_unique_filename("backup.zip")
|
||||
# 应包含 uploaded_ 前缀和时间戳
|
||||
assert result.startswith("uploaded_")
|
||||
assert result.endswith("_backup.zip")
|
||||
# 应包含原文件名和时间戳后缀
|
||||
assert result.startswith("backup_")
|
||||
assert result.endswith(".zip")
|
||||
# 应包含时间戳格式 YYYYMMDD_HHMMSS
|
||||
assert re.search(r"uploaded_\d{8}_\d{6}_backup\.zip", result)
|
||||
assert re.search(r"backup_\d{8}_\d{6}\.zip", result)
|
||||
|
||||
def test_generate_unique_filename_with_complex_name(self):
|
||||
"""测试复杂文件名生成唯一文件名"""
|
||||
result = generate_unique_filename("my_backup_file.zip")
|
||||
# 应在原文件名后添加时间戳
|
||||
assert result.startswith("my_backup_file_")
|
||||
assert result.endswith(".zip")
|
||||
assert re.search(r"my_backup_file_\d{8}_\d{6}\.zip", result)
|
||||
|
||||
|
||||
class TestVersionComparison:
|
||||
@@ -750,6 +759,7 @@ class TestBackupIntegration:
|
||||
# 读取 manifest
|
||||
manifest = json.loads(zf.read("manifest.json"))
|
||||
assert manifest["astrbot_version"] == VERSION
|
||||
assert manifest["origin"] == "exported" # 验证备份来源标记
|
||||
|
||||
# 读取配置
|
||||
config = json.loads(zf.read("config/cmd_config.json"))
|
||||
|
||||
Reference in New Issue
Block a user