Compare commits
91 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3e3599835e | |||
| 5255388e2d | |||
| fbdd60b64c | |||
| bd1b0a2836 | |||
| 19541d9d07 | |||
| 2a5d574394 | |||
| f2924fbd1b | |||
| 703e208947 | |||
| 9a5cc977c2 | |||
| aa38fe776a | |||
| 701399c00c | |||
| eaee98d4b8 | |||
| 76c66000a7 | |||
| 4b365143c0 | |||
| 6e4e5011e2 | |||
| d853bfde84 | |||
| a0e856f80f | |||
| 8c94a0010c | |||
| a44fdaaec0 | |||
| 60105c76f5 | |||
| bcf87d3ce4 | |||
| 4d7c8c8453 | |||
| a064a9115f | |||
| 6ef99e1553 | |||
| c0dbe5cf65 | |||
| 3598c51eff | |||
| b5cdb8f650 | |||
| fc5b520f9b | |||
| 904f56b32f | |||
| 2f15fd019c | |||
| 82330b8d10 | |||
| 3ee6af7027 | |||
| 6e20ebe901 | |||
| 4d6150fd6d | |||
| 544e52191b | |||
| f2c2a6da4a | |||
| dd3df425ee | |||
| 40b4a27a3d | |||
| 9d991c7468 | |||
| ad6a8b5c94 | |||
| 1b4bfcbd72 | |||
| 9d3cc593a1 | |||
| f0dee35ba9 | |||
| 4135bd84d5 | |||
| f6da614e5d | |||
| 5f531c9be5 | |||
| 94591d965b | |||
| 8a0f865af1 | |||
| 4aced976a8 | |||
| 0299aa6e4c | |||
| e8b54a019e | |||
| 98ce796275 | |||
| b87dcf2275 | |||
| 591a228431 | |||
| f52f375154 | |||
| 975c685a17 | |||
| 6db80d36a8 | |||
| 4651bd2807 | |||
| 94ada3793e | |||
| fd05b0bf09 | |||
| 4d046f8490 | |||
| 58e32b7b70 | |||
| 903dd0f9f7 | |||
| 1acac0cac2 | |||
| 80b89fd2ea | |||
| 26f863ba81 | |||
| f78a90218e | |||
| a3ecebd2aa | |||
| 67c33b842d | |||
| 5431c9f46e | |||
| 764b91a5f7 | |||
| c20c1b84bf | |||
| fd66a0ac00 | |||
| aaee283367 | |||
| 4a5b7d1976 | |||
| 08244548ab | |||
| b486de6a98 | |||
| e2f928a7e5 | |||
| b8e4068c75 | |||
| 0916177a57 | |||
| 02cd5e396b | |||
| 56673ad78f | |||
| 9a4d05e2b6 | |||
| b2e9dab233 | |||
| 45110200ea | |||
| c3f45449e8 | |||
| 65da469deb | |||
| 16df64c405 | |||
| 6b73b19e54 | |||
| a70088b799 | |||
| bb45d9cb54 |
@@ -15,7 +15,6 @@ Always reference these instructions first and fallback to search or bash command
|
||||
### Running the Application
|
||||
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
||||
|
||||
### Dashboard Build (Vue.js/Node.js)
|
||||
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||
@@ -35,7 +34,7 @@ Always reference these instructions first and fallback to search or bash command
|
||||
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||
|
||||
### Plugin Development
|
||||
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugins load from `astrbot/builtin_stars/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugin system supports function tools and message handlers
|
||||
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Archive production artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: dist-without-markdown
|
||||
path: |
|
||||
|
||||
+2
-2
@@ -24,9 +24,9 @@ configs/session
|
||||
configs/config.yaml
|
||||
cmd_config.json
|
||||
|
||||
# Plugins and packages
|
||||
# Plugins
|
||||
addons/plugins
|
||||
packages/python_interpreter/workplace
|
||||
astrbot/builtin_stars/python_interpreter/workplace
|
||||
tests/astrbot_plugin_openai
|
||||
|
||||
# Dashboard
|
||||
|
||||
+26
-1
@@ -33,6 +33,20 @@
|
||||
- 请使用英文描述您的 PR。
|
||||
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
||||
|
||||
#### 代码规范
|
||||
|
||||
##### Core
|
||||
|
||||
我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
|
||||
如果您使用 VSCode,可以安装 `Ruff` 插件。
|
||||
|
||||
|
||||
## Contributing Guide
|
||||
|
||||
First off, thanks for taking the time to contribute! ❤️
|
||||
@@ -62,4 +76,15 @@ We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features.
|
||||
|
||||
#### PR Description
|
||||
- Please use English to describe your PR.
|
||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||
|
||||
#### Code Style
|
||||
|
||||
##### Core
|
||||
|
||||
We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
|
||||
@@ -243,4 +243,10 @@ pre-commit install
|
||||
|
||||
</details>
|
||||
|
||||
<div align="center">
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
</div
|
||||
|
||||
|
||||
+44
-16
@@ -7,6 +7,7 @@ from astrbot.api import logger, sp, star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.message_components import Image, Reply
|
||||
from astrbot.api.provider import Provider, ProviderRequest
|
||||
from astrbot.core.agent.message import TextPart
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
|
||||
|
||||
@@ -85,7 +86,9 @@ class ProcessLLMRequest:
|
||||
req.image_urls,
|
||||
)
|
||||
if caption:
|
||||
req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}"
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(text=f"<image_caption>{caption}</image_caption>")
|
||||
)
|
||||
req.image_urls = []
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片描述失败: {e}")
|
||||
@@ -129,13 +132,14 @@ class ProcessLLMRequest:
|
||||
else:
|
||||
req.prompt = prefix + req.prompt
|
||||
|
||||
# 收集系统提醒信息
|
||||
system_parts = []
|
||||
|
||||
# user identifier
|
||||
if cfg.get("identifier"):
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
req.prompt = (
|
||||
f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}"
|
||||
)
|
||||
system_parts.append(f"User ID: {user_id}, Nickname: {user_nickname}")
|
||||
|
||||
# group name identifier
|
||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||
@@ -146,7 +150,7 @@ class ProcessLLMRequest:
|
||||
return
|
||||
group_name = event.message_obj.group.group_name
|
||||
if group_name:
|
||||
req.system_prompt += f"\nGroup name: {group_name}\n"
|
||||
system_parts.append(f"Group name: {group_name}")
|
||||
|
||||
# time info
|
||||
if cfg.get("datetime_system_prompt"):
|
||||
@@ -162,7 +166,7 @@ class ProcessLLMRequest:
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
||||
system_parts.append(f"Current datetime: {current_time}")
|
||||
|
||||
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
||||
if req.conversation:
|
||||
@@ -181,37 +185,61 @@ class ProcessLLMRequest:
|
||||
quote = comp
|
||||
break
|
||||
if quote:
|
||||
sender_info = ""
|
||||
if quote.sender_nickname:
|
||||
sender_info = f"(Sent by {quote.sender_nickname})"
|
||||
message_str = quote.message_str or "[Empty Text]"
|
||||
req.system_prompt += (
|
||||
f"\nUser is quoting a message{sender_info}.\n"
|
||||
f"Here are the information of the quoted message: Text Content: {message_str}.\n"
|
||||
content_parts = []
|
||||
|
||||
# 1. 处理引用的文本
|
||||
sender_info = (
|
||||
f"({quote.sender_nickname}): " if quote.sender_nickname else ""
|
||||
)
|
||||
message_str = quote.message_str or "[Empty Text]"
|
||||
content_parts.append(f"{sender_info}{message_str}")
|
||||
|
||||
# 2. 处理引用的图片 (保留原有逻辑,但改变输出目标)
|
||||
image_seg = None
|
||||
if quote.chain:
|
||||
for comp in quote.chain:
|
||||
if isinstance(comp, Image):
|
||||
image_seg = comp
|
||||
break
|
||||
|
||||
if image_seg:
|
||||
try:
|
||||
# 找到可以生成图片描述的 provider
|
||||
prov = None
|
||||
if img_cap_prov_id:
|
||||
prov = self.ctx.get_provider_by_id(img_cap_prov_id)
|
||||
if prov is None:
|
||||
prov = self.ctx.get_using_provider(event.unified_msg_origin)
|
||||
|
||||
# 调用 provider 生成图片描述
|
||||
if prov and isinstance(prov, Provider):
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt="Please describe the image content.",
|
||||
image_urls=[await image_seg.convert_to_file_path()],
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
req.system_prompt += (
|
||||
f"Image Caption: {llm_resp.completion_text}\n"
|
||||
# 将图片描述作为文本添加到 content_parts
|
||||
content_parts.append(
|
||||
f"[Image Caption in quoted message]: {llm_resp.completion_text}"
|
||||
)
|
||||
else:
|
||||
logger.warning("No provider found for image captioning.")
|
||||
logger.warning(
|
||||
"No provider found for image captioning in quote."
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"处理引用图片失败: {e}")
|
||||
|
||||
# 3. 将所有部分组合成文本并添加到 extra_user_content_parts 中
|
||||
# 确保引用内容被正确的标签包裹
|
||||
quoted_content = "\n".join(content_parts)
|
||||
# 确保所有内容都在<Quoted Message>标签内
|
||||
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
|
||||
|
||||
req.extra_user_content_parts.append(TextPart(text=quoted_text))
|
||||
|
||||
# 统一包裹所有系统提醒
|
||||
if system_parts:
|
||||
system_content = (
|
||||
"<system_reminder>" + "\n".join(system_parts) + "</system_reminder>"
|
||||
)
|
||||
req.extra_user_content_parts.append(TextPart(text=system_content))
|
||||
+1
@@ -71,6 +71,7 @@ class AdminCommands:
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
"""更新管理面板"""
|
||||
await event.send(MessageChain().message("正在尝试更新管理面板..."))
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
await event.send(MessageChain().message("管理面板更新完成。"))
|
||||
@@ -0,0 +1,88 @@
|
||||
import aiohttp
|
||||
|
||||
from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.star import command_management
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
|
||||
|
||||
class HelpCommand:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
"https://astrbot.app/notice.json",
|
||||
timeout=2,
|
||||
) as resp:
|
||||
return (await resp.json())["notice"]
|
||||
except BaseException:
|
||||
return ""
|
||||
|
||||
async def _build_reserved_command_lines(self) -> list[str]:
|
||||
"""
|
||||
使用实时指令配置生成内置指令清单,确保重命名/禁用后与实际生效状态保持一致。
|
||||
"""
|
||||
try:
|
||||
commands = await command_management.list_commands()
|
||||
except BaseException:
|
||||
return []
|
||||
|
||||
lines: list[str] = []
|
||||
hidden_commands = {"set", "unset", "websearch"}
|
||||
|
||||
def walk(items: list[dict], indent: int = 0):
|
||||
for item in items:
|
||||
if not item.get("reserved") or not item.get("enabled"):
|
||||
continue
|
||||
# 仅展示顶级指令或指令组
|
||||
if item.get("type") == "sub_command":
|
||||
continue
|
||||
if item.get("parent_signature"):
|
||||
continue
|
||||
|
||||
effective = (
|
||||
item.get("effective_command")
|
||||
or item.get("original_command")
|
||||
or item.get("handler_name")
|
||||
)
|
||||
if not effective:
|
||||
continue
|
||||
if effective in hidden_commands:
|
||||
continue
|
||||
|
||||
description = item.get("description") or ""
|
||||
desc_text = f" - {description}" if description else ""
|
||||
indent_prefix = " " * indent
|
||||
lines.append(f"{indent_prefix}/{effective}{desc_text}")
|
||||
|
||||
walk(commands)
|
||||
return lines
|
||||
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
"""查看帮助"""
|
||||
notice = ""
|
||||
try:
|
||||
notice = await self._query_astrbot_notice()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
dashboard_version = await get_dashboard_version()
|
||||
command_lines = await self._build_reserved_command_lines()
|
||||
commands_section = (
|
||||
"\n".join(command_lines) if command_lines else "暂无启用的内置指令"
|
||||
)
|
||||
|
||||
msg_parts = [
|
||||
f"AstrBot v{VERSION}(WebUI: {dashboard_version})",
|
||||
"内置指令:",
|
||||
commands_section,
|
||||
]
|
||||
if notice:
|
||||
msg_parts.append(notice)
|
||||
msg = "\n".join(msg_parts)
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
+6
-4
@@ -184,7 +184,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -198,7 +199,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -209,8 +211,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif isinstance(idx, int):
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -49,7 +49,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("tool")
|
||||
def tool(self):
|
||||
pass
|
||||
"""函数工具管理"""
|
||||
|
||||
@tool.command("ls")
|
||||
async def tool_ls(self, event: AstrMessageEvent):
|
||||
@@ -73,7 +73,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("plugin")
|
||||
def plugin(self):
|
||||
pass
|
||||
"""插件管理"""
|
||||
|
||||
@plugin.command("ls")
|
||||
async def plugin_ls(self, event: AstrMessageEvent):
|
||||
@@ -219,6 +219,7 @@ class Main(star.Star):
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dashboard_update")
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
"""更新管理面板"""
|
||||
await self.admin_c.update_dashboard(event)
|
||||
|
||||
@filter.command("set")
|
||||
+1
-1
@@ -249,7 +249,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("pi")
|
||||
def pi(self):
|
||||
pass
|
||||
"""代码执行器配置"""
|
||||
|
||||
@pi.command("absdir")
|
||||
async def pi_absdir(self, event: AstrMessageEvent, path: str = ""):
|
||||
@@ -179,7 +179,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("reminder")
|
||||
def reminder(self):
|
||||
"""The command group of the reminder."""
|
||||
"""待办提醒"""
|
||||
|
||||
async def get_upcoming_reminders(self, unified_msg_origin: str):
|
||||
"""Get upcoming reminders."""
|
||||
@@ -185,6 +185,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command("websearch")
|
||||
async def websearch(self, event: AstrMessageEvent, oper: str | None = None):
|
||||
"""网页搜索指令(已废弃)"""
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。",
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.9.0"
|
||||
__version__ = "4.10.3"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
@@ -122,10 +122,12 @@ class ToolCall(BaseModel):
|
||||
extra_content: dict[str, Any] | None = None
|
||||
"""Extra metadata for the tool call."""
|
||||
|
||||
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.extra_content is None:
|
||||
kwargs.setdefault("exclude", set()).add("extra_content")
|
||||
return super().model_dump(**kwargs)
|
||||
data.pop("extra_content", None)
|
||||
return data
|
||||
|
||||
|
||||
class ToolCallPart(BaseModel):
|
||||
@@ -167,6 +169,15 @@ class Message(BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.tool_calls is None:
|
||||
data.pop("tool_calls", None)
|
||||
if self.tool_call_id is None:
|
||||
data.pop("tool_call_id", None)
|
||||
return data
|
||||
|
||||
|
||||
class AssistantMessageSegment(Message):
|
||||
"""A message segment from the assistant."""
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import TokenUsage
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
@@ -12,3 +13,23 @@ class AgentResponseData(T.TypedDict):
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStats:
|
||||
token_usage: TokenUsage = field(default_factory=TokenUsage)
|
||||
start_time: float = 0.0
|
||||
end_time: float = 0.0
|
||||
time_to_first_token: float = 0.0
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
return self.end_time - self.start_time
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"token_usage": self.token_usage.__dict__,
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"time_to_first_token": self.time_to_first_token,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ from .message import Message
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
@dataclass
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as T
|
||||
|
||||
@@ -12,6 +13,7 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
)
|
||||
@@ -24,7 +26,7 @@ from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from .base import AgentResponse, AgentState, BaseAgentRunner
|
||||
@@ -69,14 +71,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
self.run_context.messages = messages
|
||||
|
||||
self.stats = AgentStats()
|
||||
self.stats.start_time = time.time()
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
payload = {
|
||||
"contexts": self.run_context.messages, # list[Message]
|
||||
"func_tool": self.req.func_tool,
|
||||
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
||||
"session_id": self.req.session_id,
|
||||
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
|
||||
}
|
||||
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
stream = self.provider.text_chat_stream(**payload)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
yield await self.provider.text_chat(**payload)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
@@ -98,6 +111,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
if self.stats.time_to_first_token == 0:
|
||||
self.stats.time_to_first_token = time.time() - self.stats.start_time
|
||||
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
@@ -121,6 +138,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
|
||||
if not llm_response.is_chunk and llm_response.usage:
|
||||
# only count the token usage of the final response for computation purpose
|
||||
self.stats.token_usage += llm_response.usage
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
@@ -132,6 +153,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
if llm_resp.role == "err":
|
||||
# 如果 LLM 响应错误,转换到错误状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self.stats.end_time = time.time()
|
||||
self._transition_state(AgentState.ERROR)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
@@ -146,11 +168,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
# record the final assistant message
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=llm_resp.completion_text or "",
|
||||
content=llm_resp.completion_text or "*No response*",
|
||||
),
|
||||
)
|
||||
try:
|
||||
@@ -175,22 +198,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果有工具调用,还需处理工具调用
|
||||
if llm_resp.tools_call_name:
|
||||
tool_call_result_blocks = []
|
||||
for tool_call_name in llm_resp.tools_call_name:
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain(type="tool_call").message(
|
||||
f"🔨 调用工具: {tool_call_name}"
|
||||
),
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
result.type = "tool_call_result"
|
||||
if result.type is None:
|
||||
# should not happen
|
||||
continue
|
||||
if result.type == "tool_direct_result":
|
||||
ar_type = "tool_call_result"
|
||||
else:
|
||||
ar_type = result.type
|
||||
yield AgentResponse(
|
||||
type="tool_call_result",
|
||||
type=ar_type,
|
||||
data=AgentResponseData(chain=result),
|
||||
)
|
||||
# 将结果添加到上下文中
|
||||
@@ -218,6 +238,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
# 如果循环结束了但是 agent 还没有完成,说明是达到了 max_step
|
||||
if not self.done():
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
# 拔掉所有工具
|
||||
if self.req:
|
||||
self.req.func_tool = None
|
||||
# 注入提示词
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
# 再执行最后一步
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
@@ -233,6 +272,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
yield MessageChain(
|
||||
type="tool_call",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"name": func_tool_name,
|
||||
"args": func_tool_args,
|
||||
"ts": time.time(),
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
try:
|
||||
if not req.func_tool:
|
||||
return
|
||||
@@ -306,7 +358,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content=res.content[0].text,
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -328,7 +379,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content=resource.text,
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
@@ -352,20 +402,34 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content="返回的数据类型不受支持",
|
||||
),
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
# 发送消息逻辑在 ToolExecutor 中处理了。
|
||||
logger.warning(
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户。"
|
||||
)
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具没有返回值或者将结果直接发送给了用户*",
|
||||
),
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。",
|
||||
f"Tool 返回了不支持的类型: {type(resp)}。",
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具返回了不支持的类型,请告诉用户检查这个工具的定义和实现。*",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -387,6 +451,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
)
|
||||
|
||||
# yield the last tool call result
|
||||
if tool_call_result_blocks:
|
||||
last_tcr_content = str(tool_call_result_blocks[-1].content)
|
||||
yield MessageChain(
|
||||
type="tool_call_result",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"ts": time.time(),
|
||||
"result": last_tcr_content,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# 处理函数调用响应
|
||||
if tool_call_result_blocks:
|
||||
yield tool_call_result_blocks
|
||||
|
||||
@@ -6,8 +6,10 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
@dataclass
|
||||
class AstrAgentContext:
|
||||
__pydantic_config__ = {"arbitrary_types_allowed": True}
|
||||
|
||||
context: Context
|
||||
"""The star context instance"""
|
||||
event: AstrMessageEvent
|
||||
|
||||
@@ -2,8 +2,10 @@ import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
@@ -23,8 +25,25 @@ async def run_agent(
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step:
|
||||
while step_idx < max_step + 1:
|
||||
step_idx += 1
|
||||
|
||||
if step_idx == max_step + 1:
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
if not agent_runner.done():
|
||||
# 拔掉所有工具
|
||||
if agent_runner.req:
|
||||
agent_runner.req.func_tool = None
|
||||
# 注入提示词
|
||||
agent_runner.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
@@ -33,16 +52,27 @@ async def run_agent(
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
await astr_event.send(resp.data["chain"])
|
||||
await astr_event.send(msg_chain)
|
||||
continue
|
||||
if astr_event.get_platform_id() == "webchat":
|
||||
await astr_event.send(msg_chain)
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if show_tool_use:
|
||||
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(resp.data["chain"])
|
||||
elif show_tool_use:
|
||||
json_comp = resp.data["chain"].chain[0]
|
||||
if isinstance(json_comp, Json):
|
||||
m = f"🔨 调用工具: {json_comp.data.get('name')}"
|
||||
else:
|
||||
m = "🔨 调用工具..."
|
||||
chain = MessageChain(type="tool_call").message(m)
|
||||
await astr_event.send(chain)
|
||||
continue
|
||||
|
||||
if stream_to_general and resp.type == "streaming_delta":
|
||||
@@ -69,6 +99,15 @@ async def run_agent(
|
||||
continue
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
# send agent stats to webchat
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(
|
||||
MessageChain(
|
||||
type="agent_stats",
|
||||
chain=[Json(data=agent_runner.stats.to_dict())],
|
||||
)
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -209,12 +209,42 @@ async def call_local_llm_tool(
|
||||
else:
|
||||
raise ValueError(f"未知的方法名: {method_name}")
|
||||
except ValueError as e:
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
raise Exception(f"Tool execution ValueError: {e}") from e
|
||||
except TypeError as e:
|
||||
# 获取函数的签名(包括类型),除了第一个 event/context 参数。
|
||||
try:
|
||||
sig = inspect.signature(handler)
|
||||
params = list(sig.parameters.values())
|
||||
# 跳过第一个参数(event 或 context)
|
||||
if params:
|
||||
params = params[1:]
|
||||
|
||||
param_strs = []
|
||||
for param in params:
|
||||
param_str = param.name
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
# 获取类型注解的字符串表示
|
||||
if isinstance(param.annotation, type):
|
||||
type_str = param.annotation.__name__
|
||||
else:
|
||||
type_str = str(param.annotation)
|
||||
param_str += f": {type_str}"
|
||||
if param.default != inspect.Parameter.empty:
|
||||
param_str += f" = {param.default!r}"
|
||||
param_strs.append(param_str)
|
||||
|
||||
handler_param_str = (
|
||||
", ".join(param_strs) if param_strs else "(no additional parameters)"
|
||||
)
|
||||
except Exception:
|
||||
handler_param_str = "(unable to inspect signature)"
|
||||
|
||||
raise Exception(
|
||||
f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
trace_ = traceback.format_exc()
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||||
raise Exception(f"Tool execution error: {e}. Traceback: {trace_}") from e
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""AstrBot 备份与恢复模块
|
||||
|
||||
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。
|
||||
"""
|
||||
|
||||
# 从 constants 模块导入共享常量
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
# 导入导出器和导入器
|
||||
from .exporter import AstrBotExporter
|
||||
from .importer import AstrBotImporter, ImportPreCheckResult
|
||||
|
||||
__all__ = [
|
||||
"AstrBotExporter",
|
||||
"AstrBotImporter",
|
||||
"ImportPreCheckResult",
|
||||
"MAIN_DB_MODELS",
|
||||
"KB_METADATA_MODELS",
|
||||
"get_backup_directories",
|
||||
"BACKUP_MANIFEST_VERSION",
|
||||
]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""AstrBot 备份模块共享常量
|
||||
|
||||
此文件定义了导出器和导入器共享的常量,确保两端配置一致。
|
||||
"""
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
PlatformSession,
|
||||
PlatformStat,
|
||||
Preference,
|
||||
)
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_plugin_data_path,
|
||||
get_astrbot_plugin_path,
|
||||
get_astrbot_t2i_templates_path,
|
||||
get_astrbot_temp_path,
|
||||
get_astrbot_webchat_path,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# 共享常量 - 确保导出和导入端配置一致
|
||||
# ============================================================
|
||||
|
||||
# 主数据库模型类映射
|
||||
MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
|
||||
"platform_stats": PlatformStat,
|
||||
"conversations": ConversationV2,
|
||||
"personas": Persona,
|
||||
"preferences": Preference,
|
||||
"platform_message_history": PlatformMessageHistory,
|
||||
"platform_sessions": PlatformSession,
|
||||
"attachments": Attachment,
|
||||
"command_configs": CommandConfig,
|
||||
"command_conflicts": CommandConflict,
|
||||
}
|
||||
|
||||
# 知识库元数据模型类映射
|
||||
KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
|
||||
"knowledge_bases": KnowledgeBase,
|
||||
"kb_documents": KBDocument,
|
||||
"kb_media": KBMedia,
|
||||
}
|
||||
|
||||
|
||||
def get_backup_directories() -> dict[str, str]:
|
||||
"""获取需要备份的目录列表
|
||||
|
||||
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。
|
||||
|
||||
Returns:
|
||||
dict: 键为备份文件中的目录名称,值为目录的绝对路径
|
||||
"""
|
||||
return {
|
||||
"plugins": get_astrbot_plugin_path(), # 插件本体
|
||||
"plugin_data": get_astrbot_plugin_data_path(), # 插件数据
|
||||
"config": get_astrbot_config_path(), # 配置目录
|
||||
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
|
||||
"webchat": get_astrbot_webchat_path(), # WebChat 数据
|
||||
"temp": get_astrbot_temp_path(), # 临时文件
|
||||
}
|
||||
|
||||
|
||||
# 备份清单版本号
|
||||
BACKUP_MANIFEST_VERSION = "1.1"
|
||||
@@ -0,0 +1,476 @@
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
负责将所有数据导出为 ZIP 备份文件。
|
||||
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_backups_path,
|
||||
get_astrbot_data_path,
|
||||
)
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
|
||||
|
||||
class AstrBotExporter:
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
导出内容:
|
||||
- 主数据库所有表(data/data_v4.db)
|
||||
- 知识库元数据(data/knowledge_base/kb.db)
|
||||
- 每个知识库的向量文档数据
|
||||
- 配置文件(data/cmd_config.json)
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self._checksums: dict[str, str] = {}
|
||||
|
||||
async def export_all(
|
||||
self,
|
||||
output_dir: str | None = None,
|
||||
progress_callback: Any | None = None,
|
||||
) -> str:
|
||||
"""导出所有数据到 ZIP 文件
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
str: 生成的 ZIP 文件路径
|
||||
"""
|
||||
if output_dir is None:
|
||||
output_dir = get_astrbot_backups_path()
|
||||
|
||||
# 确保输出目录存在
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
zip_filename = f"astrbot_backup_{timestamp}.zip"
|
||||
zip_path = os.path.join(output_dir, zip_filename)
|
||||
|
||||
logger.info(f"开始导出备份到 {zip_path}")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
# 1. 导出主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
|
||||
main_data = await self._export_main_database()
|
||||
main_db_json = json.dumps(
|
||||
main_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/main_db.json", main_db_json)
|
||||
self._add_checksum("databases/main_db.json", main_db_json)
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导出完成")
|
||||
|
||||
# 2. 导出知识库数据
|
||||
kb_meta_data: dict[str, Any] = {
|
||||
"knowledge_bases": [],
|
||||
"kb_documents": [],
|
||||
"kb_media": [],
|
||||
}
|
||||
if self.kb_manager:
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 0, 100, "正在导出知识库元数据..."
|
||||
)
|
||||
kb_meta_data = await self._export_kb_metadata()
|
||||
kb_meta_json = json.dumps(
|
||||
kb_meta_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/kb_metadata.json", kb_meta_json)
|
||||
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 100, 100, "知识库元数据导出完成"
|
||||
)
|
||||
|
||||
# 导出每个知识库的文档数据
|
||||
kb_insts = self.kb_manager.kb_insts
|
||||
total_kbs = len(kb_insts)
|
||||
for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()):
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents",
|
||||
idx,
|
||||
total_kbs,
|
||||
f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...",
|
||||
)
|
||||
doc_data = await self._export_kb_documents(kb_helper)
|
||||
doc_json = json.dumps(
|
||||
doc_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
zf.writestr(doc_path, doc_json)
|
||||
self._add_checksum(doc_path, doc_json)
|
||||
|
||||
# 导出 FAISS 索引文件
|
||||
await self._export_faiss_index(zf, kb_helper, kb_id)
|
||||
|
||||
# 导出知识库多媒体文件
|
||||
await self._export_kb_media_files(zf, kb_helper, kb_id)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
|
||||
)
|
||||
|
||||
# 3. 导出配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导出配置文件...")
|
||||
if os.path.exists(self.config_path):
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
config_content = f.read()
|
||||
zf.writestr("config/cmd_config.json", config_content)
|
||||
self._add_checksum("config/cmd_config.json", config_content)
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导出完成")
|
||||
|
||||
# 4. 导出附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导出附件...")
|
||||
await self._export_attachments(zf, main_data.get("attachments", []))
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导出完成")
|
||||
|
||||
# 5. 导出插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导出插件和数据目录..."
|
||||
)
|
||||
dir_stats = await self._export_directories(zf)
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导出完成")
|
||||
|
||||
# 6. 生成 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 0, 100, "正在生成清单...")
|
||||
manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats)
|
||||
manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2)
|
||||
zf.writestr("manifest.json", manifest_json)
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 100, 100, "清单生成完成")
|
||||
|
||||
logger.info(f"备份导出完成: {zip_path}")
|
||||
return zip_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"备份导出失败: {e}")
|
||||
# 清理失败的文件
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
raise
|
||||
|
||||
async def _export_main_database(self) -> dict[str, list[dict]]:
|
||||
"""导出主数据库所有表"""
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_metadata(self) -> dict[str, list[dict]]:
|
||||
"""导出知识库元数据库"""
|
||||
if not self.kb_manager:
|
||||
return {"knowledge_bases": [], "kb_documents": [], "kb_media": []}
|
||||
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]:
|
||||
"""导出知识库的文档块数据"""
|
||||
try:
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
|
||||
vec_db: FaissVecDB = kb_helper.vec_db
|
||||
if not vec_db or not vec_db.document_storage:
|
||||
return {"documents": []}
|
||||
|
||||
# 获取所有文档
|
||||
docs = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={},
|
||||
offset=0,
|
||||
limit=None, # 获取全部
|
||||
)
|
||||
|
||||
return {"documents": docs}
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库文档失败: {e}")
|
||||
return {"documents": []}
|
||||
|
||||
async def _export_faiss_index(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_helper: Any,
|
||||
kb_id: str,
|
||||
) -> None:
|
||||
"""导出 FAISS 索引文件"""
|
||||
try:
|
||||
index_path = kb_helper.kb_dir / "index.faiss"
|
||||
if index_path.exists():
|
||||
archive_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
zf.write(str(index_path), archive_path)
|
||||
logger.debug(f"导出 FAISS 索引: {archive_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"导出 FAISS 索引失败: {e}")
|
||||
|
||||
async def _export_kb_media_files(
|
||||
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
|
||||
) -> None:
|
||||
"""导出知识库的多媒体文件"""
|
||||
try:
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if not media_dir.exists():
|
||||
return
|
||||
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
file_path = Path(root) / file
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(kb_helper.kb_dir)
|
||||
archive_path = f"files/kb_media/{kb_id}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库媒体文件失败: {e}")
|
||||
|
||||
async def _export_directories(
|
||||
self, zf: zipfile.ZipFile
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""导出插件和其他数据目录
|
||||
|
||||
Returns:
|
||||
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
|
||||
"""
|
||||
stats: dict[str, dict[str, int]] = {}
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name, dir_path in backup_directories.items():
|
||||
full_path = Path(dir_path)
|
||||
if not full_path.exists():
|
||||
logger.debug(f"目录不存在,跳过: {full_path}")
|
||||
continue
|
||||
|
||||
file_count = 0
|
||||
total_size = 0
|
||||
|
||||
try:
|
||||
for root, dirs, files in os.walk(full_path):
|
||||
# 跳过 __pycache__ 目录
|
||||
dirs[:] = [d for d in dirs if d != "__pycache__"]
|
||||
|
||||
for file in files:
|
||||
# 跳过 .pyc 文件
|
||||
if file.endswith(".pyc"):
|
||||
continue
|
||||
|
||||
file_path = Path(root) / file
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(full_path)
|
||||
archive_path = f"directories/{dir_name}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
file_count += 1
|
||||
total_size += file_path.stat().st_size
|
||||
except Exception as e:
|
||||
logger.warning(f"导出文件 {file_path} 失败: {e}")
|
||||
|
||||
stats[dir_name] = {"files": file_count, "size": total_size}
|
||||
logger.debug(
|
||||
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出目录 {dir_path} 失败: {e}")
|
||||
stats[dir_name] = {"files": 0, "size": 0}
|
||||
|
||||
return stats
|
||||
|
||||
async def _export_attachments(
|
||||
self, zf: zipfile.ZipFile, attachments: list[dict]
|
||||
) -> None:
|
||||
"""导出附件文件"""
|
||||
for attachment in attachments:
|
||||
try:
|
||||
file_path = attachment.get("path", "")
|
||||
if file_path and os.path.exists(file_path):
|
||||
# 使用 attachment_id 作为文件名
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
archive_path = f"files/attachments/{attachment_id}{ext}"
|
||||
zf.write(file_path, archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出附件失败: {e}")
|
||||
|
||||
def _model_to_dict(self, record: Any) -> dict:
|
||||
"""将 SQLModel 实例转换为字典
|
||||
|
||||
这是数据库无关的序列化方式,支持未来迁移到其他数据库。
|
||||
"""
|
||||
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
|
||||
if hasattr(record, "model_dump"):
|
||||
data = record.model_dump(mode="python")
|
||||
# 处理 datetime 类型
|
||||
for key, value in data.items():
|
||||
if isinstance(value, datetime):
|
||||
data[key] = value.isoformat()
|
||||
return data
|
||||
|
||||
# 回退到手动提取
|
||||
data = {}
|
||||
# 使用 inspect 获取表信息
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
mapper = sa_inspect(record.__class__)
|
||||
for column in mapper.columns:
|
||||
value = getattr(record, column.name)
|
||||
# 处理 datetime 类型 - 统一转为 ISO 格式字符串
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
data[column.name] = value
|
||||
return data
|
||||
|
||||
def _add_checksum(self, path: str, content: str | bytes) -> None:
|
||||
"""计算并添加文件校验和"""
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
checksum = hashlib.sha256(content).hexdigest()
|
||||
self._checksums[path] = f"sha256:{checksum}"
|
||||
|
||||
def _generate_manifest(
|
||||
self,
|
||||
main_data: dict[str, list[dict]],
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
dir_stats: dict[str, dict[str, int]] | None = None,
|
||||
) -> dict:
|
||||
"""生成备份清单"""
|
||||
if dir_stats is None:
|
||||
dir_stats = {}
|
||||
# 收集知识库 ID
|
||||
kb_document_tables = {}
|
||||
if self.kb_manager:
|
||||
for kb_id in self.kb_manager.kb_insts.keys():
|
||||
kb_document_tables[kb_id] = "documents"
|
||||
|
||||
# 收集附件文件列表
|
||||
attachment_files = []
|
||||
for attachment in main_data.get("attachments", []):
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
path = attachment.get("path", "")
|
||||
if attachment_id and path:
|
||||
ext = os.path.splitext(path)[1]
|
||||
attachment_files.append(f"{attachment_id}{ext}")
|
||||
|
||||
# 收集知识库媒体文件
|
||||
kb_media_files: dict[str, list[str]] = {}
|
||||
if self.kb_manager:
|
||||
for kb_id, kb_helper in self.kb_manager.kb_insts.items():
|
||||
media_files: list[str] = []
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if media_dir.exists():
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
media_files.append(file)
|
||||
if media_files:
|
||||
kb_media_files[kb_id] = media_files
|
||||
|
||||
manifest = {
|
||||
"version": BACKUP_MANIFEST_VERSION,
|
||||
"astrbot_version": VERSION,
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"schema_version": {
|
||||
"main_db": "v4",
|
||||
"kb_db": "v1",
|
||||
},
|
||||
"tables": {
|
||||
"main_db": list(main_data.keys()),
|
||||
"kb_metadata": list(kb_meta_data.keys()),
|
||||
"kb_documents": kb_document_tables,
|
||||
},
|
||||
"files": {
|
||||
"attachments": attachment_files,
|
||||
"kb_media": kb_media_files,
|
||||
},
|
||||
"directories": list(dir_stats.keys()),
|
||||
"checksums": self._checksums,
|
||||
"statistics": {
|
||||
"main_db": {
|
||||
table: len(records) for table, records in main_data.items()
|
||||
},
|
||||
"kb_metadata": {
|
||||
table: len(records) for table, records in kb_meta_data.items()
|
||||
},
|
||||
"directories": dir_stats,
|
||||
},
|
||||
}
|
||||
|
||||
return manifest
|
||||
@@ -0,0 +1,761 @@
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
负责从 ZIP 备份文件恢复所有数据。
|
||||
导入时进行版本校验:
|
||||
- 主版本(前两位)不同时直接拒绝导入
|
||||
- 小版本(第三位)不同时提示警告,用户可选择强制导入
|
||||
- 版本匹配时也需要用户确认
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import delete
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_data_path,
|
||||
get_astrbot_knowledge_base_path,
|
||||
)
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
|
||||
def _get_major_version(version_str: str) -> str:
|
||||
"""提取版本的主版本部分(前两位)
|
||||
|
||||
Args:
|
||||
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
|
||||
|
||||
Returns:
|
||||
主版本字符串,如 "4.9", "4.10"
|
||||
"""
|
||||
if not version_str:
|
||||
return "0.0"
|
||||
# 移除 v 前缀和预发布标签
|
||||
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
|
||||
parts = [p for p in version.split(".") if p] # 过滤空字符串
|
||||
if len(parts) >= 2:
|
||||
return f"{parts[0]}.{parts[1]}"
|
||||
elif len(parts) == 1 and parts[0]:
|
||||
return f"{parts[0]}.0"
|
||||
return "0.0"
|
||||
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
KB_PATH = get_astrbot_knowledge_base_path()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportPreCheckResult:
|
||||
"""导入预检查结果
|
||||
|
||||
用于在实际导入前检查备份文件的版本兼容性,
|
||||
并返回确认信息让用户决定是否继续导入。
|
||||
"""
|
||||
|
||||
# 检查是否通过(文件有效且版本可导入)
|
||||
valid: bool = False
|
||||
# 是否可以导入(版本兼容)
|
||||
can_import: bool = False
|
||||
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
|
||||
version_status: str = ""
|
||||
# 备份文件中的 AstrBot 版本
|
||||
backup_version: str = ""
|
||||
# 当前运行的 AstrBot 版本
|
||||
current_version: str = VERSION
|
||||
# 备份创建时间
|
||||
backup_time: str = ""
|
||||
# 确认消息(显示给用户)
|
||||
confirm_message: str = ""
|
||||
# 警告消息列表
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
# 错误消息(如果检查失败)
|
||||
error: str = ""
|
||||
# 备份包含的内容摘要
|
||||
backup_summary: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"valid": self.valid,
|
||||
"can_import": self.can_import,
|
||||
"version_status": self.version_status,
|
||||
"backup_version": self.backup_version,
|
||||
"current_version": self.current_version,
|
||||
"backup_time": self.backup_time,
|
||||
"confirm_message": self.confirm_message,
|
||||
"warnings": self.warnings,
|
||||
"error": self.error,
|
||||
"backup_summary": self.backup_summary,
|
||||
}
|
||||
|
||||
|
||||
class ImportResult:
|
||||
"""导入结果"""
|
||||
|
||||
def __init__(self):
|
||||
self.success = True
|
||||
self.imported_tables: dict[str, int] = {}
|
||||
self.imported_files: dict[str, int] = {}
|
||||
self.imported_directories: dict[str, int] = {}
|
||||
self.warnings: list[str] = []
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_warning(self, msg: str) -> None:
|
||||
self.warnings.append(msg)
|
||||
logger.warning(msg)
|
||||
|
||||
def add_error(self, msg: str) -> None:
|
||||
self.errors.append(msg)
|
||||
self.success = False
|
||||
logger.error(msg)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"success": self.success,
|
||||
"imported_tables": self.imported_tables,
|
||||
"imported_files": self.imported_files,
|
||||
"imported_directories": self.imported_directories,
|
||||
"warnings": self.warnings,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class AstrBotImporter:
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
导入备份文件中的所有数据,包括:
|
||||
- 主数据库所有表
|
||||
- 知识库元数据和文档
|
||||
- 配置文件
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
kb_root_dir: str = KB_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self.kb_root_dir = kb_root_dir
|
||||
|
||||
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
|
||||
"""预检查备份文件
|
||||
|
||||
在实际导入前检查备份文件的有效性和版本兼容性。
|
||||
返回检查结果供前端显示确认对话框。
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
|
||||
Returns:
|
||||
ImportPreCheckResult: 预检查结果
|
||||
"""
|
||||
result = ImportPreCheckResult()
|
||||
result.current_version = VERSION
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.error = f"备份文件不存在: {zip_path}"
|
||||
return result
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 读取 manifest
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.error = f"manifest.json 格式错误: {e}"
|
||||
return result
|
||||
|
||||
# 提取基本信息
|
||||
result.backup_version = manifest.get("astrbot_version", "未知")
|
||||
result.backup_time = manifest.get("exported_at", "未知")
|
||||
result.valid = True
|
||||
|
||||
# 构建备份摘要
|
||||
result.backup_summary = {
|
||||
"tables": list(manifest.get("tables", {}).keys()),
|
||||
"has_knowledge_bases": manifest.get("has_knowledge_bases", False),
|
||||
"has_config": manifest.get("has_config", False),
|
||||
"directories": manifest.get("directories", []),
|
||||
}
|
||||
|
||||
# 检查版本兼容性
|
||||
version_check = self._check_version_compatibility(result.backup_version)
|
||||
result.version_status = version_check["status"]
|
||||
result.can_import = version_check["can_import"]
|
||||
|
||||
# 版本信息由前端根据 version_status 和 i18n 生成显示
|
||||
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
|
||||
# warnings 列表保留用于其他非版本相关的警告
|
||||
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.error = "无效的 ZIP 文件"
|
||||
return result
|
||||
except Exception as e:
|
||||
result.error = f"检查备份文件失败: {e}"
|
||||
return result
|
||||
|
||||
def _check_version_compatibility(self, backup_version: str) -> dict:
|
||||
"""检查版本兼容性
|
||||
|
||||
规则:
|
||||
- 主版本(前两位,如 4.9)必须一致,否则拒绝
|
||||
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
|
||||
|
||||
Returns:
|
||||
dict: {status, can_import, message}
|
||||
"""
|
||||
if not backup_version:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": "备份文件缺少版本信息",
|
||||
}
|
||||
|
||||
# 提取主版本(前两位)进行比较
|
||||
backup_major = _get_major_version(backup_version)
|
||||
current_major = _get_major_version(VERSION)
|
||||
|
||||
# 比较主版本
|
||||
if VersionComparator.compare_version(backup_major, current_major) != 0:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": (
|
||||
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
|
||||
),
|
||||
}
|
||||
|
||||
# 比较完整版本
|
||||
version_cmp = VersionComparator.compare_version(backup_version, VERSION)
|
||||
if version_cmp != 0:
|
||||
return {
|
||||
"status": "minor_diff",
|
||||
"can_import": True,
|
||||
"message": (
|
||||
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
),
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "match",
|
||||
"can_import": True,
|
||||
"message": "版本匹配",
|
||||
}
|
||||
|
||||
async def import_all(
|
||||
self,
|
||||
zip_path: str,
|
||||
mode: str = "replace", # "replace" 清空后导入
|
||||
progress_callback: Any | None = None,
|
||||
) -> ImportResult:
|
||||
"""从 ZIP 文件导入所有数据
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
mode: 导入模式,目前仅支持 "replace"(清空后导入)
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
ImportResult: 导入结果
|
||||
"""
|
||||
result = ImportResult()
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.add_error(f"备份文件不存在: {zip_path}")
|
||||
return result
|
||||
|
||||
logger.info(f"开始从 {zip_path} 导入备份")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 1. 读取并验证 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 0, 100, "正在验证备份文件...")
|
||||
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.add_error("备份文件缺少 manifest.json")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.add_error(f"manifest.json 格式错误: {e}")
|
||||
return result
|
||||
|
||||
# 版本校验
|
||||
try:
|
||||
self._validate_version(manifest)
|
||||
except ValueError as e:
|
||||
result.add_error(str(e))
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 100, 100, "验证完成")
|
||||
|
||||
# 2. 导入主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导入主数据库...")
|
||||
|
||||
try:
|
||||
main_data_content = zf.read("databases/main_db.json")
|
||||
main_data = json.loads(main_data_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_main_db()
|
||||
|
||||
imported = await self._import_main_database(main_data)
|
||||
result.imported_tables.update(imported)
|
||||
except Exception as e:
|
||||
result.add_error(f"导入主数据库失败: {e}")
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导入完成")
|
||||
|
||||
# 3. 导入知识库
|
||||
if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 0, 100, "正在导入知识库...")
|
||||
|
||||
try:
|
||||
kb_meta_content = zf.read("databases/kb_metadata.json")
|
||||
kb_meta_data = json.loads(kb_meta_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_kb_data()
|
||||
|
||||
await self._import_knowledge_bases(zf, kb_meta_data, result)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 100, 100, "知识库导入完成")
|
||||
|
||||
# 4. 导入配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导入配置文件...")
|
||||
|
||||
if "config/cmd_config.json" in zf.namelist():
|
||||
try:
|
||||
config_content = zf.read("config/cmd_config.json")
|
||||
# 备份现有配置
|
||||
if os.path.exists(self.config_path):
|
||||
backup_path = f"{self.config_path}.bak"
|
||||
shutil.copy2(self.config_path, backup_path)
|
||||
|
||||
with open(self.config_path, "wb") as f:
|
||||
f.write(config_content)
|
||||
result.imported_files["config"] = 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入配置文件失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导入完成")
|
||||
|
||||
# 5. 导入附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导入附件...")
|
||||
|
||||
attachment_count = await self._import_attachments(
|
||||
zf, main_data.get("attachments", [])
|
||||
)
|
||||
result.imported_files["attachments"] = attachment_count
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导入完成")
|
||||
|
||||
# 6. 导入插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导入插件和数据目录..."
|
||||
)
|
||||
|
||||
dir_stats = await self._import_directories(zf, manifest, result)
|
||||
result.imported_directories = dir_stats
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导入完成")
|
||||
|
||||
logger.info(f"备份导入完成: {result.to_dict()}")
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.add_error("无效的 ZIP 文件")
|
||||
return result
|
||||
except Exception as e:
|
||||
result.add_error(f"导入失败: {e}")
|
||||
return result
|
||||
|
||||
def _validate_version(self, manifest: dict) -> None:
|
||||
"""验证版本兼容性 - 仅允许相同主版本导入
|
||||
|
||||
注意:此方法仅在 import_all 中调用,用于双重校验。
|
||||
前端应先调用 pre_check 获取详细的版本信息并让用户确认。
|
||||
"""
|
||||
backup_version = manifest.get("astrbot_version")
|
||||
if not backup_version:
|
||||
raise ValueError("备份文件缺少版本信息")
|
||||
|
||||
# 使用新的版本兼容性检查
|
||||
version_check = self._check_version_compatibility(backup_version)
|
||||
|
||||
if version_check["status"] == "major_diff":
|
||||
raise ValueError(version_check["message"])
|
||||
|
||||
# minor_diff 和 match 都允许导入
|
||||
if version_check["status"] == "minor_diff":
|
||||
logger.warning(f"版本差异警告: {version_check['message']}")
|
||||
|
||||
async def _clear_main_db(self) -> None:
|
||||
"""清空主数据库所有表"""
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空表 {table_name} 失败: {e}")
|
||||
|
||||
async def _clear_kb_data(self) -> None:
|
||||
"""清空知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 清空知识库元数据表
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空知识库表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
|
||||
|
||||
# 删除知识库文件目录
|
||||
for kb_id in list(self.kb_manager.kb_insts.keys()):
|
||||
try:
|
||||
kb_helper = self.kb_manager.kb_insts[kb_id]
|
||||
await kb_helper.terminate()
|
||||
if kb_helper.kb_dir.exists():
|
||||
shutil.rmtree(kb_helper.kb_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理知识库 {kb_id} 失败: {e}")
|
||||
|
||||
self.kb_manager.kb_insts.clear()
|
||||
|
||||
async def _import_main_database(
|
||||
self, data: dict[str, list[dict]]
|
||||
) -> dict[str, int]:
|
||||
"""导入主数据库数据"""
|
||||
imported: dict[str, int] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in data.items():
|
||||
model_class = MAIN_DB_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
logger.warning(f"未知的表: {table_name}")
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
# 转换 datetime 字符串为 datetime 对象
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入记录到 {table_name} 失败: {e}")
|
||||
|
||||
imported[table_name] = count
|
||||
logger.debug(f"导入表 {table_name}: {count} 条记录")
|
||||
|
||||
return imported
|
||||
|
||||
async def _import_knowledge_bases(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
result: ImportResult,
|
||||
) -> None:
|
||||
"""导入知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 1. 导入知识库元数据
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in kb_meta_data.items():
|
||||
model_class = KB_METADATA_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
|
||||
|
||||
result.imported_tables[f"kb_{table_name}"] = count
|
||||
|
||||
# 2. 导入每个知识库的文档和文件
|
||||
for kb_data in kb_meta_data.get("knowledge_bases", []):
|
||||
kb_id = kb_data.get("kb_id")
|
||||
if not kb_id:
|
||||
continue
|
||||
|
||||
# 创建知识库目录
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
kb_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 导入文档数据
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
if doc_path in zf.namelist():
|
||||
try:
|
||||
doc_content = zf.read(doc_path)
|
||||
doc_data = json.loads(doc_content)
|
||||
|
||||
# 导入到文档存储数据库
|
||||
await self._import_kb_documents(kb_id, doc_data)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")
|
||||
|
||||
# 导入 FAISS 索引
|
||||
faiss_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
if faiss_path in zf.namelist():
|
||||
try:
|
||||
target_path = kb_dir / "index.faiss"
|
||||
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
|
||||
|
||||
# 导入媒体文件
|
||||
media_prefix = f"files/kb_media/{kb_id}/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(media_prefix):
|
||||
try:
|
||||
rel_path = name[len(media_prefix) :]
|
||||
target_path = kb_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
|
||||
|
||||
# 3. 重新加载知识库实例
|
||||
await self.kb_manager.load_kbs()
|
||||
|
||||
async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
|
||||
"""导入知识库文档到向量数据库"""
|
||||
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
|
||||
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
doc_db_path = kb_dir / "doc.db"
|
||||
|
||||
# 初始化文档存储
|
||||
doc_storage = DocumentStorage(str(doc_db_path))
|
||||
await doc_storage.initialize()
|
||||
|
||||
try:
|
||||
documents = doc_data.get("documents", [])
|
||||
for doc in documents:
|
||||
try:
|
||||
await doc_storage.insert_document(
|
||||
doc_id=doc.get("doc_id", ""),
|
||||
text=doc.get("text", ""),
|
||||
metadata=json.loads(doc.get("metadata", "{}")),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导入文档块失败: {e}")
|
||||
finally:
|
||||
await doc_storage.close()
|
||||
|
||||
async def _import_attachments(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
attachments: list[dict],
|
||||
) -> int:
|
||||
"""导入附件文件"""
|
||||
count = 0
|
||||
|
||||
attachments_dir = Path(self.config_path).parent / "attachments"
|
||||
attachments_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
attachment_prefix = "files/attachments/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(attachment_prefix) and name != attachment_prefix:
|
||||
try:
|
||||
# 从附件记录中找到原始路径
|
||||
attachment_id = os.path.splitext(os.path.basename(name))[0]
|
||||
original_path = None
|
||||
for att in attachments:
|
||||
if att.get("attachment_id") == attachment_id:
|
||||
original_path = att.get("path")
|
||||
break
|
||||
|
||||
if original_path:
|
||||
target_path = Path(original_path)
|
||||
else:
|
||||
target_path = attachments_dir / os.path.basename(name)
|
||||
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入附件 {name} 失败: {e}")
|
||||
|
||||
return count
|
||||
|
||||
async def _import_directories(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
manifest: dict,
|
||||
result: ImportResult,
|
||||
) -> dict[str, int]:
|
||||
"""导入插件和其他数据目录
|
||||
|
||||
Args:
|
||||
zf: ZIP 文件对象
|
||||
manifest: 备份清单
|
||||
result: 导入结果对象
|
||||
|
||||
Returns:
|
||||
dict: 每个目录导入的文件数量
|
||||
"""
|
||||
dir_stats: dict[str, int] = {}
|
||||
|
||||
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
|
||||
backup_version = manifest.get("version", "1.0")
|
||||
if VersionComparator.compare_version(backup_version, "1.1") < 0:
|
||||
logger.info("备份版本不支持目录备份,跳过目录导入")
|
||||
return dir_stats
|
||||
|
||||
backed_up_dirs = manifest.get("directories", [])
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name in backed_up_dirs:
|
||||
if dir_name not in backup_directories:
|
||||
result.add_warning(f"未知的目录类型: {dir_name}")
|
||||
continue
|
||||
|
||||
target_dir = Path(backup_directories[dir_name])
|
||||
archive_prefix = f"directories/{dir_name}/"
|
||||
|
||||
file_count = 0
|
||||
|
||||
try:
|
||||
# 获取该目录下的所有文件
|
||||
dir_files = [
|
||||
name
|
||||
for name in zf.namelist()
|
||||
if name.startswith(archive_prefix) and name != archive_prefix
|
||||
]
|
||||
|
||||
if not dir_files:
|
||||
continue
|
||||
|
||||
# 备份现有目录(如果存在)
|
||||
if target_dir.exists():
|
||||
backup_path = Path(f"{target_dir}.bak")
|
||||
if backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
shutil.move(str(target_dir), str(backup_path))
|
||||
logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}")
|
||||
|
||||
# 创建目标目录
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 解压文件
|
||||
for name in dir_files:
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = name[len(archive_prefix) :]
|
||||
if not rel_path: # 跳过目录条目
|
||||
continue
|
||||
|
||||
target_path = target_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
file_count += 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入文件 {name} 失败: {e}")
|
||||
|
||||
dir_stats[dir_name] = file_count
|
||||
logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")
|
||||
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入目录 {dir_name} 失败: {e}")
|
||||
dir_stats[dir_name] = 0
|
||||
|
||||
return dir_stats
|
||||
|
||||
def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
|
||||
"""转换 datetime 字符串字段为 datetime 对象"""
|
||||
result = row.copy()
|
||||
|
||||
# 获取模型的 datetime 字段
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
try:
|
||||
mapper = sa_inspect(model_class)
|
||||
for column in mapper.columns:
|
||||
if column.name in result and result[column.name] is not None:
|
||||
# 检查是否是 datetime 类型的列
|
||||
from sqlalchemy import DateTime
|
||||
|
||||
if isinstance(column.type, DateTime):
|
||||
value = result[column.name]
|
||||
if isinstance(value, str):
|
||||
# 解析 ISO 格式的日期时间字符串
|
||||
result[column.name] = datetime.fromisoformat(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
+168
-212
@@ -1,10 +1,11 @@
|
||||
"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""
|
||||
|
||||
import os
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.9.0"
|
||||
VERSION = "4.10.3"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -61,7 +62,8 @@ DEFAULT_CONFIG = {
|
||||
"ignore_bot_self_message": False,
|
||||
"ignore_at_all": False,
|
||||
},
|
||||
"provider": [],
|
||||
"provider_sources": [], # provider sources
|
||||
"provider": [], # models from provider_sources
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"default_provider_id": "",
|
||||
@@ -108,6 +110,7 @@ DEFAULT_CONFIG = {
|
||||
"provider_id": "",
|
||||
"dual_output": False,
|
||||
"use_file_service": False,
|
||||
"trigger_probability": 1.0,
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
@@ -170,6 +173,22 @@ DEFAULT_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
class ChatProviderTemplate(TypedDict):
|
||||
id: str
|
||||
provider_source_id: str
|
||||
model: str
|
||||
modalities: list
|
||||
custom_extra_body: dict[str, Any]
|
||||
|
||||
|
||||
CHAT_PROVIDER_TEMPLATE = {
|
||||
"id": "",
|
||||
"provide_source_id": "",
|
||||
"model": "",
|
||||
"modalities": [],
|
||||
"custom_extra_body": {},
|
||||
}
|
||||
|
||||
"""
|
||||
AstrBot v3 时代的配置元数据,目前仅承担以下功能:
|
||||
|
||||
@@ -208,7 +227,7 @@ CONFIG_METADATA_2 = {
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6196,
|
||||
},
|
||||
"QQ 个人号(OneBot v11)": {
|
||||
"OneBot v11": {
|
||||
"id": "default",
|
||||
"type": "aiocqhttp",
|
||||
"enable": False,
|
||||
@@ -843,6 +862,7 @@ CONFIG_METADATA_2 = {
|
||||
"metadata": {
|
||||
"provider": {
|
||||
"type": "list",
|
||||
# provider sources templates
|
||||
"config_template": {
|
||||
"OpenAI": {
|
||||
"id": "openai",
|
||||
@@ -853,107 +873,10 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
"id": "azure",
|
||||
"provider": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"xai_native_search": False,
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Anthropic": {
|
||||
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
|
||||
"id": "claude",
|
||||
"provider": "anthropic",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "claude-3-5-sonnet-latest",
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Ollama": {
|
||||
"hint": "启用前请确保已正确安装并运行 Ollama 服务端,Ollama默认不带鉴权,无需修改key",
|
||||
"id": "ollama_default",
|
||||
"provider": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"provider": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://localhost:1234/v1",
|
||||
"model_config": {
|
||||
"model": "llama-3.1-8b",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Gemini(OpenAI兼容)": {
|
||||
"id": "gemini_default",
|
||||
"provider": "google",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Gemini": {
|
||||
"id": "gemini_default",
|
||||
"Google Gemini": {
|
||||
"id": "google_gemini",
|
||||
"provider": "google",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -961,10 +884,6 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_native_search": False,
|
||||
"gm_native_coderunner": False,
|
||||
@@ -975,13 +894,43 @@ CONFIG_METADATA_2 = {
|
||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"budget": 0,
|
||||
},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"gm_thinking_config": {"budget": 0, "level": "HIGH"},
|
||||
},
|
||||
"Anthropic": {
|
||||
"id": "anthropic",
|
||||
"provider": "anthropic",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
},
|
||||
"Moonshot": {
|
||||
"id": "moonshot",
|
||||
"provider": "moonshot",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
"xai_native_search": False,
|
||||
},
|
||||
"DeepSeek": {
|
||||
"id": "deepseek_default",
|
||||
"id": "deepseek",
|
||||
"provider": "deepseek",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -989,13 +938,75 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"Zhipu": {
|
||||
"id": "zhipu",
|
||||
"provider": "zhipu",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
"id": "azure_openai",
|
||||
"provider": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Ollama": {
|
||||
"id": "ollama",
|
||||
"provider": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://127.0.0.1:11434/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"provider": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://127.0.0.1:1234/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"ModelStack": {
|
||||
"id": "modelstack",
|
||||
"provider": "modelstack",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://modelstack.app/v1",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Gemini_OpenAI_API": {
|
||||
"id": "google_gemini_openai",
|
||||
"provider": "google",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Groq": {
|
||||
"id": "groq_default",
|
||||
"id": "groq",
|
||||
"provider": "groq",
|
||||
"type": "groq_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -1003,13 +1014,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.groq.com/openai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "openai/gpt-oss-20b",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"302.AI": {
|
||||
"id": "302ai",
|
||||
@@ -1020,12 +1025,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.302.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"硅基流动": {
|
||||
"SiliconFlow": {
|
||||
"id": "siliconflow",
|
||||
"provider": "siliconflow",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1034,15 +1036,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.siliconflow.cn/v1",
|
||||
"model_config": {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"PPIO派欧云": {
|
||||
"PPIO": {
|
||||
"id": "ppio",
|
||||
"provider": "ppio",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1051,14 +1047,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.ppinfra.com/v3/openai",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"小马算力": {
|
||||
"TokenPony": {
|
||||
"id": "tokenpony",
|
||||
"provider": "tokenpony",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1067,14 +1058,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.tokenpony.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "kimi-k2-instruct-0905",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"优云智算": {
|
||||
"Compshare": {
|
||||
"id": "compshare",
|
||||
"provider": "compshare",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1083,42 +1069,18 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.modelverse.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "moonshotai/Kimi-K2-Instruct",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Kimi": {
|
||||
"id": "moonshot",
|
||||
"provider": "moonshot",
|
||||
"ModelScope": {
|
||||
"id": "modelscope",
|
||||
"provider": "modelscope",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"智谱 AI": {
|
||||
"id": "zhipu_default",
|
||||
"provider": "zhipu",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Dify": {
|
||||
"id": "dify_app_default",
|
||||
@@ -1133,7 +1095,6 @@ CONFIG_METADATA_2 = {
|
||||
"dify_query_input_key": "astrbot_text_query",
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||
},
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
@@ -1164,20 +1125,6 @@ CONFIG_METADATA_2 = {
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
},
|
||||
"ModelScope": {
|
||||
"id": "modelscope",
|
||||
"provider": "modelscope",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"FastGPT": {
|
||||
"id": "fastgpt",
|
||||
"provider": "fastgpt",
|
||||
@@ -1201,7 +1148,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "whisper-1",
|
||||
},
|
||||
"Whisper(Local)": {
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"provider": "openai",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"provider_type": "speech_to_text",
|
||||
@@ -1210,7 +1156,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "tiny",
|
||||
},
|
||||
"SenseVoice(Local)": {
|
||||
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"provider": "sensevoice",
|
||||
"provider_type": "speech_to_text",
|
||||
@@ -1232,7 +1177,6 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": "20",
|
||||
},
|
||||
"Edge TTS": {
|
||||
"hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||
"id": "edge_tts",
|
||||
"provider": "microsoft",
|
||||
"type": "edge_tts",
|
||||
@@ -1448,6 +1392,10 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"provider_source_id": {
|
||||
"invisible": True,
|
||||
"type": "string",
|
||||
},
|
||||
"xai_native_search": {
|
||||
"description": "启用原生搜索功能",
|
||||
"type": "bool",
|
||||
@@ -1818,13 +1766,24 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"description": "Gemini思考设置",
|
||||
"description": "Thinking Config",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"budget": {
|
||||
"description": "思考预算",
|
||||
"description": "Thinking Budget",
|
||||
"type": "int",
|
||||
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
|
||||
"hint": "Guides the model on the specific number of thinking tokens to use for reasoning. See: https://ai.google.dev/gemini-api/docs/thinking#set-budget",
|
||||
},
|
||||
"level": {
|
||||
"description": "Thinking Level",
|
||||
"type": "string",
|
||||
"hint": "Recommended for Gemini 3 models and onwards, lets you control reasoning behavior.See: https://ai.google.dev/gemini-api/docs/thinking#thinking-levels",
|
||||
"options": [
|
||||
"MINIMAL",
|
||||
"LOW",
|
||||
"MEDIUM",
|
||||
"HIGH",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -2005,7 +1964,6 @@ CONFIG_METADATA_2 = {
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
"hint": "模型提供商名字。",
|
||||
},
|
||||
"type": {
|
||||
"description": "模型提供商种类",
|
||||
@@ -2025,29 +1983,15 @@ CONFIG_METADATA_2 = {
|
||||
"description": "API Key",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "提供商 API Key。",
|
||||
},
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
},
|
||||
"model_config": {
|
||||
"description": "模型配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"model": {
|
||||
"description": "模型名称",
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_tokens": {
|
||||
"description": "模型最大输出长度(tokens)",
|
||||
"type": "int",
|
||||
},
|
||||
"temperature": {"description": "温度", "type": "float"},
|
||||
"top_p": {"description": "Top P值", "type": "float"},
|
||||
},
|
||||
"model": {
|
||||
"description": "模型 ID",
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
@@ -2209,6 +2153,9 @@ CONFIG_METADATA_2 = {
|
||||
"use_file_service": {
|
||||
"type": "bool",
|
||||
},
|
||||
"trigger_probability": {
|
||||
"type": "float",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
@@ -2419,6 +2366,14 @@ CONFIG_METADATA_3 = {
|
||||
"provider_tts_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_tts_settings.trigger_probability": {
|
||||
"description": "TTS 触发概率",
|
||||
"type": "float",
|
||||
"slider": {"min": 0, "max": 1, "step": 0.05},
|
||||
"condition": {
|
||||
"provider_tts_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.image_caption_prompt": {
|
||||
"description": "图片转述提示词",
|
||||
"type": "text",
|
||||
@@ -2986,6 +2941,7 @@ CONFIG_METADATA_3 = {
|
||||
"description": "回复概率",
|
||||
"type": "float",
|
||||
"hint": "0.0-1.0 之间的数值",
|
||||
"slider": {"min": 0, "max": 1, "step": 0.05},
|
||||
"condition": {
|
||||
"provider_ltm_settings.active_reply.enable": True,
|
||||
},
|
||||
|
||||
@@ -79,6 +79,7 @@ class ConfigMetadataI18n:
|
||||
"_special",
|
||||
"invisible",
|
||||
"options",
|
||||
"slider",
|
||||
]:
|
||||
if attr in field_data:
|
||||
field_result[attr] = field_data[attr]
|
||||
|
||||
@@ -33,6 +33,7 @@ from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.llm_metadata import update_llm_metadata
|
||||
from astrbot.core.utils.migra_helper import migra
|
||||
|
||||
from . import astrbot_config, html_renderer
|
||||
@@ -185,6 +186,8 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化关闭控制面板的事件
|
||||
self.dashboard_shutdown_event = asyncio.Event()
|
||||
|
||||
asyncio.create_task(update_llm_metadata())
|
||||
|
||||
def _load(self) -> None:
|
||||
"""加载事件总线和任务并初始化."""
|
||||
# 创建一个异步任务来执行事件总线的 dispatch() 方法
|
||||
|
||||
@@ -9,6 +9,8 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
@@ -314,6 +316,76 @@ class BaseDatabase(abc.ABC):
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_command_configs(self) -> list[CommandConfig]:
|
||||
"""Get all stored command configurations."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
|
||||
"""Fetch a single command configuration by handler."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
"""Create or update a command configuration."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||
"""Delete a single command configuration."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||
"""Bulk delete command configurations."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_command_conflicts(
|
||||
self,
|
||||
status: str | None = None,
|
||||
) -> list[CommandConflict]:
|
||||
"""List recorded command conflict entries."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert_command_conflict(
|
||||
self,
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
"""Create or update a conflict record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||
"""Delete conflict records."""
|
||||
...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def insert_llm_message(
|
||||
# self,
|
||||
|
||||
@@ -234,6 +234,65 @@ class Attachment(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class CommandConfig(SQLModel, table=True):
|
||||
"""Per-command configuration overrides for dashboard management."""
|
||||
|
||||
__tablename__ = "command_configs" # type: ignore
|
||||
|
||||
handler_full_name: str = Field(
|
||||
primary_key=True,
|
||||
max_length=512,
|
||||
)
|
||||
plugin_name: str = Field(nullable=False, max_length=255)
|
||||
module_path: str = Field(nullable=False, max_length=255)
|
||||
original_command: str = Field(nullable=False, max_length=255)
|
||||
resolved_command: str | None = Field(default=None, max_length=255)
|
||||
enabled: bool = Field(default=True, nullable=False)
|
||||
keep_original_alias: bool = Field(default=False, nullable=False)
|
||||
conflict_key: str | None = Field(default=None, max_length=255)
|
||||
resolution_strategy: str | None = Field(default=None, max_length=64)
|
||||
note: str | None = Field(default=None, sa_type=Text)
|
||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||
auto_managed: bool = Field(default=False, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
class CommandConflict(SQLModel, table=True):
|
||||
"""Conflict tracking for duplicated command names."""
|
||||
|
||||
__tablename__ = "command_conflicts" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
conflict_key: str = Field(nullable=False, max_length=255)
|
||||
handler_full_name: str = Field(nullable=False, max_length=512)
|
||||
plugin_name: str = Field(nullable=False, max_length=255)
|
||||
status: str = Field(default="pending", max_length=32)
|
||||
resolution: str | None = Field(default=None, max_length=64)
|
||||
resolved_command: str | None = Field(default=None, max_length=255)
|
||||
note: str | None = Field(default=None, sa_type=Text)
|
||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||
auto_generated: bool = Field(default=False, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"conflict_key",
|
||||
"handler_full_name",
|
||||
name="uix_conflict_handler",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话类
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import typing as T
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import CursorResult
|
||||
@@ -10,6 +11,8 @@ from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
@@ -26,6 +29,7 @@ from astrbot.core.db.po import (
|
||||
)
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
TxResult = T.TypeVar("TxResult")
|
||||
|
||||
|
||||
class SQLiteDatabase(BaseDatabase):
|
||||
@@ -670,6 +674,242 @@ class SQLiteDatabase(BaseDatabase):
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# ====
|
||||
# Command Configuration & Conflict Tracking
|
||||
# ====
|
||||
|
||||
async def _run_in_tx(
|
||||
self,
|
||||
fn: Callable[[AsyncSession], Awaitable[TxResult]],
|
||||
) -> TxResult:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
return await fn(session)
|
||||
|
||||
@staticmethod
|
||||
def _apply_updates(model, **updates) -> None:
|
||||
for field, value in updates.items():
|
||||
if value is not None:
|
||||
setattr(model, field, value)
|
||||
|
||||
@staticmethod
|
||||
def _new_command_config(
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
return CommandConfig(
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=plugin_name,
|
||||
module_path=module_path,
|
||||
original_command=original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=True if enabled is None else enabled,
|
||||
keep_original_alias=False
|
||||
if keep_original_alias is None
|
||||
else keep_original_alias,
|
||||
conflict_key=conflict_key or original_command,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=bool(auto_managed),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _new_command_conflict(
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
return CommandConflict(
|
||||
conflict_key=conflict_key,
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=plugin_name,
|
||||
status=status or "pending",
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=bool(auto_generated),
|
||||
)
|
||||
|
||||
async def get_command_configs(self) -> list[CommandConfig]:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(select(CommandConfig))
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
) -> CommandConfig | None:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
return await session.get(CommandConfig, handler_full_name)
|
||||
|
||||
async def upsert_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
async def _op(session: AsyncSession) -> CommandConfig:
|
||||
config = await session.get(CommandConfig, handler_full_name)
|
||||
if not config:
|
||||
config = self._new_command_config(
|
||||
handler_full_name,
|
||||
plugin_name,
|
||||
module_path,
|
||||
original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=enabled,
|
||||
keep_original_alias=keep_original_alias,
|
||||
conflict_key=conflict_key,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=auto_managed,
|
||||
)
|
||||
session.add(config)
|
||||
else:
|
||||
self._apply_updates(
|
||||
config,
|
||||
plugin_name=plugin_name,
|
||||
module_path=module_path,
|
||||
original_command=original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=enabled,
|
||||
keep_original_alias=keep_original_alias,
|
||||
conflict_key=conflict_key,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=auto_managed,
|
||||
)
|
||||
await session.flush()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
return await self._run_in_tx(_op)
|
||||
|
||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||
await self.delete_command_configs([handler_full_name])
|
||||
|
||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||
if not handler_full_names:
|
||||
return
|
||||
|
||||
async def _op(session: AsyncSession) -> None:
|
||||
await session.execute(
|
||||
delete(CommandConfig).where(
|
||||
col(CommandConfig.handler_full_name).in_(handler_full_names),
|
||||
),
|
||||
)
|
||||
|
||||
await self._run_in_tx(_op)
|
||||
|
||||
async def list_command_conflicts(
|
||||
self,
|
||||
status: str | None = None,
|
||||
) -> list[CommandConflict]:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(CommandConflict)
|
||||
if status:
|
||||
query = query.where(CommandConflict.status == status)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def upsert_command_conflict(
|
||||
self,
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
async def _op(session: AsyncSession) -> CommandConflict:
|
||||
result = await session.execute(
|
||||
select(CommandConflict).where(
|
||||
CommandConflict.conflict_key == conflict_key,
|
||||
CommandConflict.handler_full_name == handler_full_name,
|
||||
),
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
record = self._new_command_conflict(
|
||||
conflict_key,
|
||||
handler_full_name,
|
||||
plugin_name,
|
||||
status=status,
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=auto_generated,
|
||||
)
|
||||
session.add(record)
|
||||
else:
|
||||
self._apply_updates(
|
||||
record,
|
||||
plugin_name=plugin_name,
|
||||
status=status,
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=auto_generated,
|
||||
)
|
||||
await session.flush()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
return await self._run_in_tx(_op)
|
||||
|
||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
|
||||
async def _op(session: AsyncSession) -> None:
|
||||
await session.execute(
|
||||
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
|
||||
)
|
||||
|
||||
await self._run_in_tx(_op)
|
||||
|
||||
# ====
|
||||
# Deprecated Methods
|
||||
# ====
|
||||
|
||||
+1
-1
@@ -58,7 +58,7 @@ def is_plugin_path(pathname):
|
||||
return False
|
||||
|
||||
norm_path = os.path.normpath(pathname)
|
||||
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
|
||||
return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path)
|
||||
|
||||
|
||||
def get_short_level_name(level_name):
|
||||
|
||||
@@ -629,12 +629,11 @@ class Nodes(BaseMessageComponent):
|
||||
|
||||
class Json(BaseMessageComponent):
|
||||
type = ComponentType.Json
|
||||
data: str | dict
|
||||
resid: int | None = 0
|
||||
data: dict
|
||||
|
||||
def __init__(self, data, **_):
|
||||
if isinstance(data, dict):
|
||||
data = json.dumps(data)
|
||||
def __init__(self, data: str | dict, **_):
|
||||
if isinstance(data, str):
|
||||
data = json.loads(data)
|
||||
super().__init__(data=data, **_)
|
||||
|
||||
|
||||
|
||||
@@ -321,7 +321,12 @@ class InternalAgentSubStage(Stage):
|
||||
elif isinstance(req.tool_calls_result, list):
|
||||
for tcr in req.tool_calls_result:
|
||||
messages.extend(tcr.to_openai_messages())
|
||||
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": llm_response.completion_text or "*No response*",
|
||||
}
|
||||
)
|
||||
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
@@ -385,7 +390,7 @@ class InternalAgentSubStage(Stage):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
|
||||
@@ -119,7 +119,7 @@ class RespondStage(Stage):
|
||||
|
||||
if (result := event.get_result()) is None:
|
||||
return False
|
||||
if self.only_llm_result and result.is_llm_result():
|
||||
if self.only_llm_result and not result.is_llm_result():
|
||||
return False
|
||||
|
||||
if event.get_platform_name() in [
|
||||
@@ -158,7 +158,11 @@ class RespondStage(Stage):
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
if event.get_extra("_streaming_finished", False):
|
||||
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
|
||||
return
|
||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||
event.set_extra("_streaming_finished", True)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
@@ -42,6 +43,18 @@ class ResultDecorateStage(Stage):
|
||||
"forward_threshold"
|
||||
]
|
||||
|
||||
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
|
||||
"trigger_probability",
|
||||
1,
|
||||
)
|
||||
try:
|
||||
self.tts_trigger_probability = max(
|
||||
0.0,
|
||||
min(float(trigger_probability), 1.0),
|
||||
)
|
||||
except (TypeError, ValueError):
|
||||
self.tts_trigger_probability = 1.0
|
||||
|
||||
# 分段回复
|
||||
self.words_count_threshold = int(
|
||||
ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
||||
@@ -246,7 +259,14 @@ class ResultDecorateStage(Stage):
|
||||
and result.is_llm_result()
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
if not tts_provider:
|
||||
should_tts = self.tts_trigger_probability >= 1.0 or (
|
||||
self.tts_trigger_probability > 0.0
|
||||
and random.random() <= self.tts_trigger_probability
|
||||
)
|
||||
|
||||
if not should_tts:
|
||||
logger.debug("跳过 TTS:触发概率未命中。")
|
||||
elif not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
|
||||
@@ -136,7 +136,8 @@ class WakingCheckStage(Stage):
|
||||
):
|
||||
if (
|
||||
self.disable_builtin_commands
|
||||
and handler.handler_module_path == "packages.builtin_commands.main"
|
||||
and handler.handler_module_path
|
||||
== "astrbot.builtin_stars.builtin_commands.main"
|
||||
):
|
||||
logger.debug("skipping builtin command")
|
||||
continue
|
||||
|
||||
@@ -385,10 +385,25 @@ class AiocqhttpAdapter(Platform):
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
|
||||
message_str += "".join(at_parts)
|
||||
elif t == "markdown":
|
||||
text = m["data"].get("markdown") or m["data"].get("content", "")
|
||||
abm.message.append(Plain(text=text))
|
||||
message_str += text
|
||||
else:
|
||||
for m in m_group:
|
||||
a = ComponentTypes[t](**m["data"])
|
||||
abm.message.append(a)
|
||||
try:
|
||||
if t not in ComponentTypes:
|
||||
logger.warning(
|
||||
f"不支持的消息段类型,已忽略: {t}, data={m['data']}"
|
||||
)
|
||||
continue
|
||||
a = ComponentTypes[t](**m["data"])
|
||||
abm.message.append(a)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"消息段解析失败: type={t}, data={m['data']}. {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
|
||||
@@ -81,7 +81,12 @@ class LarkPlatformAdapter(Platform):
|
||||
)
|
||||
|
||||
self.lark_api = (
|
||||
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
|
||||
lark.Client.builder()
|
||||
.app_id(self.appid)
|
||||
.app_secret(self.appsecret)
|
||||
.log_level(lark.LogLevel.ERROR)
|
||||
.domain(self.domain)
|
||||
.build()
|
||||
)
|
||||
|
||||
self.webhook_server = None
|
||||
|
||||
@@ -200,6 +200,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
if isinstance(chain, MessageChain):
|
||||
if chain.type == "break":
|
||||
# 分割符
|
||||
if message_id:
|
||||
try:
|
||||
await self.client.edit_message_text(
|
||||
text=delta,
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败(streaming-break): {e!s}")
|
||||
message_id = None # 重置消息 ID
|
||||
delta = "" # 重置 delta
|
||||
continue
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import File, Image, Plain, Record
|
||||
from astrbot.api.message_components import File, Image, Json, Plain, Record
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from .webchat_queue_mgr import webchat_queue_mgr
|
||||
@@ -41,12 +42,20 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
"chain_type": message.type,
|
||||
},
|
||||
)
|
||||
elif isinstance(comp, Json):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"data": json.dumps(comp.data, ensure_ascii=False),
|
||||
"streaming": streaming,
|
||||
"chain_type": message.type,
|
||||
},
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = f"{str(uuid.uuid4())}.jpg"
|
||||
@@ -58,7 +67,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "image",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
},
|
||||
@@ -74,7 +82,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "record",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
},
|
||||
@@ -91,7 +98,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "file",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
},
|
||||
@@ -111,18 +117,17 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
cid = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
async for chain in generator:
|
||||
if chain.type == "break" and final_data:
|
||||
# 分割符
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "break", # break means a segment end
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
},
|
||||
)
|
||||
final_data = ""
|
||||
continue
|
||||
# if chain.type == "break" and final_data:
|
||||
# # 分割符
|
||||
# await web_chat_back_queue.put(
|
||||
# {
|
||||
# "type": "break", # break means a segment end
|
||||
# "data": final_data,
|
||||
# "streaming": True,
|
||||
# },
|
||||
# )
|
||||
# final_data = ""
|
||||
# continue
|
||||
|
||||
r = await WebChatMessageEvent._send(
|
||||
chain,
|
||||
@@ -142,7 +147,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
"data": final_data,
|
||||
"reasoning": reasoning_content,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
},
|
||||
)
|
||||
await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import enum
|
||||
import json
|
||||
@@ -12,6 +14,7 @@ import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import (
|
||||
AssistantMessageSegment,
|
||||
ContentPart,
|
||||
ToolCall,
|
||||
ToolCallMessageSegment,
|
||||
)
|
||||
@@ -90,6 +93,8 @@ class ProviderRequest:
|
||||
"""会话 ID"""
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
|
||||
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。支持 dict 或 ContentPart 对象"""
|
||||
func_tool: ToolSet | None = None
|
||||
"""可用的函数工具"""
|
||||
contexts: list[dict] = field(default_factory=list)
|
||||
@@ -164,13 +169,23 @@ class ProviderRequest:
|
||||
|
||||
async def assemble_context(self) -> dict:
|
||||
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if self.prompt and self.prompt.strip():
|
||||
content_blocks.append({"type": "text", "text": self.prompt})
|
||||
elif self.image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if self.extra_user_content_parts:
|
||||
for part in self.extra_user_content_parts:
|
||||
content_blocks.append(part.model_dump())
|
||||
|
||||
# 3. 图片内容
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt if self.prompt else "[图片]"},
|
||||
],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -183,11 +198,21 @@ class ProviderRequest:
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
content_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": image_data}},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": self.prompt}
|
||||
|
||||
# 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容
|
||||
if (
|
||||
len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
and not self.extra_user_content_parts
|
||||
and not self.image_urls
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
@@ -199,6 +224,38 @@ class ProviderRequest:
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenUsage:
|
||||
input_other: int = 0
|
||||
"""The number of input tokens, excluding cached tokens."""
|
||||
input_cached: int = 0
|
||||
"""The number of input cached tokens."""
|
||||
output: int = 0
|
||||
"""The number of output tokens."""
|
||||
|
||||
@property
|
||||
def total(self) -> int:
|
||||
return self.input_other + self.input_cached + self.output
|
||||
|
||||
@property
|
||||
def input(self) -> int:
|
||||
return self.input_other + self.input_cached
|
||||
|
||||
def __add__(self, other: TokenUsage) -> TokenUsage:
|
||||
return TokenUsage(
|
||||
input_other=self.input_other + other.input_other,
|
||||
input_cached=self.input_cached + other.input_cached,
|
||||
output=self.output + other.output,
|
||||
)
|
||||
|
||||
def __sub__(self, other: TokenUsage) -> TokenUsage:
|
||||
return TokenUsage(
|
||||
input_other=self.input_other - other.input_other,
|
||||
input_cached=self.input_cached - other.input_cached,
|
||||
output=self.output - other.output,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
@@ -227,6 +284,11 @@ class LLMResponse:
|
||||
is_chunk: bool = False
|
||||
"""Indicates if the response is a chunked response."""
|
||||
|
||||
id: str | None = None
|
||||
"""The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response."""
|
||||
usage: TokenUsage | None = None
|
||||
"""The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: str,
|
||||
@@ -241,6 +303,8 @@ class LLMResponse:
|
||||
| AnthropicMessage
|
||||
| None = None,
|
||||
is_chunk: bool = False,
|
||||
id: str | None = None,
|
||||
usage: TokenUsage | None = None,
|
||||
):
|
||||
"""初始化 LLMResponse
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import traceback
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
@@ -32,10 +33,12 @@ class ProviderManager:
|
||||
persona_mgr: PersonaManager,
|
||||
):
|
||||
self.reload_lock = asyncio.Lock()
|
||||
self.resource_lock = asyncio.Lock()
|
||||
self.persona_mgr = persona_mgr
|
||||
self.acm = acm
|
||||
config = acm.confs["default"]
|
||||
self.providers_config: list = config["provider"]
|
||||
self.provider_sources_config: list = config.get("provider_sources", [])
|
||||
self.provider_settings: dict = config["provider_settings"]
|
||||
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||||
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
||||
@@ -148,6 +151,7 @@ class ProviderManager:
|
||||
|
||||
"""
|
||||
provider = None
|
||||
provider_id = None
|
||||
if umo:
|
||||
provider_id = sp.get(
|
||||
f"provider_perf_{provider_type.value}",
|
||||
@@ -185,6 +189,12 @@ class ProviderManager:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
if not provider and provider_id:
|
||||
logger.warning(
|
||||
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
|
||||
)
|
||||
|
||||
return provider
|
||||
|
||||
async def initialize(self):
|
||||
@@ -251,7 +261,136 @@ class ProviderManager:
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||
|
||||
def dynamic_import_provider(self, type: str):
|
||||
"""动态导入提供商适配器模块
|
||||
|
||||
Args:
|
||||
type (str): 提供商请求类型。
|
||||
|
||||
Raises:
|
||||
ImportError: 如果提供商类型未知或无法导入对应模块,则抛出异常。
|
||||
"""
|
||||
match type:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import (
|
||||
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
|
||||
)
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||
case "groq_chat_completion":
|
||||
from .sources.groq_source import ProviderGroq as ProviderGroq
|
||||
case "anthropic_chat_completion":
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
)
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import (
|
||||
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
||||
)
|
||||
case "sensevoice_stt_selfhost":
|
||||
from .sources.sensevoice_selfhosted_source import (
|
||||
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
|
||||
)
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import (
|
||||
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
|
||||
)
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import (
|
||||
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
||||
)
|
||||
case "xinference_stt":
|
||||
from .sources.xinference_stt_provider import (
|
||||
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
||||
)
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import (
|
||||
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
||||
)
|
||||
case "edge_tts":
|
||||
from .sources.edge_tts_source import (
|
||||
ProviderEdgeTTS as ProviderEdgeTTS,
|
||||
)
|
||||
case "gsv_tts_selfhost":
|
||||
from .sources.gsv_selfhosted_source import (
|
||||
ProviderGSVTTS as ProviderGSVTTS,
|
||||
)
|
||||
case "gsvi_tts_api":
|
||||
from .sources.gsvi_tts_source import (
|
||||
ProviderGSVITTS as ProviderGSVITTS,
|
||||
)
|
||||
case "fishaudio_tts_api":
|
||||
from .sources.fishaudio_tts_api_source import (
|
||||
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
||||
)
|
||||
case "dashscope_tts":
|
||||
from .sources.dashscope_tts import (
|
||||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||
)
|
||||
case "azure_tts":
|
||||
from .sources.azure_tts_source import (
|
||||
AzureTTSProvider as AzureTTSProvider,
|
||||
)
|
||||
case "minimax_tts_api":
|
||||
from .sources.minimax_tts_api_source import (
|
||||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "gemini_tts":
|
||||
from .sources.gemini_tts_source import (
|
||||
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
)
|
||||
case "gemini_embedding":
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
case "vllm_rerank":
|
||||
from .sources.vllm_rerank_source import (
|
||||
VLLMRerankProvider as VLLMRerankProvider,
|
||||
)
|
||||
case "xinference_rerank":
|
||||
from .sources.xinference_rerank_source import (
|
||||
XinferenceRerankProvider as XinferenceRerankProvider,
|
||||
)
|
||||
case "bailian_rerank":
|
||||
from .sources.bailian_rerank_source import (
|
||||
BailianRerankProvider as BailianRerankProvider,
|
||||
)
|
||||
|
||||
def get_merged_provider_config(self, provider_config: dict) -> dict:
|
||||
"""获取 provider 配置和 provider_source 配置合并后的结果
|
||||
|
||||
Returns:
|
||||
dict: 合并后的 provider 配置,key 为 provider id,value 为合并后的配置字典
|
||||
"""
|
||||
pc = copy.deepcopy(provider_config)
|
||||
provider_source_id = pc.get("provider_source_id", "")
|
||||
if provider_source_id:
|
||||
provider_source = None
|
||||
for ps in self.provider_sources_config:
|
||||
if ps.get("id") == provider_source_id:
|
||||
provider_source = ps
|
||||
break
|
||||
|
||||
if provider_source:
|
||||
# 合并配置,provider 的配置优先级更高
|
||||
merged_config = {**provider_source, **pc}
|
||||
# 保持 id 为 provider 的 id,而不是 source 的 id
|
||||
merged_config["id"] = pc["id"]
|
||||
pc = merged_config
|
||||
return pc
|
||||
|
||||
async def load_provider(self, provider_config: dict):
|
||||
# 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并
|
||||
provider_config = self.get_merged_provider_config(provider_config)
|
||||
|
||||
if not provider_config["enable"]:
|
||||
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
|
||||
return
|
||||
@@ -264,99 +403,7 @@ class ProviderManager:
|
||||
|
||||
# 动态导入
|
||||
try:
|
||||
match provider_config["type"]:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import (
|
||||
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
|
||||
)
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||
case "groq_chat_completion":
|
||||
from .sources.groq_source import ProviderGroq as ProviderGroq
|
||||
case "anthropic_chat_completion":
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
)
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import (
|
||||
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
||||
)
|
||||
case "sensevoice_stt_selfhost":
|
||||
from .sources.sensevoice_selfhosted_source import (
|
||||
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
|
||||
)
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import (
|
||||
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
|
||||
)
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import (
|
||||
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
||||
)
|
||||
case "xinference_stt":
|
||||
from .sources.xinference_stt_provider import (
|
||||
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
||||
)
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import (
|
||||
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
||||
)
|
||||
case "edge_tts":
|
||||
from .sources.edge_tts_source import (
|
||||
ProviderEdgeTTS as ProviderEdgeTTS,
|
||||
)
|
||||
case "gsv_tts_selfhost":
|
||||
from .sources.gsv_selfhosted_source import (
|
||||
ProviderGSVTTS as ProviderGSVTTS,
|
||||
)
|
||||
case "gsvi_tts_api":
|
||||
from .sources.gsvi_tts_source import (
|
||||
ProviderGSVITTS as ProviderGSVITTS,
|
||||
)
|
||||
case "fishaudio_tts_api":
|
||||
from .sources.fishaudio_tts_api_source import (
|
||||
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
||||
)
|
||||
case "dashscope_tts":
|
||||
from .sources.dashscope_tts import (
|
||||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||
)
|
||||
case "azure_tts":
|
||||
from .sources.azure_tts_source import (
|
||||
AzureTTSProvider as AzureTTSProvider,
|
||||
)
|
||||
case "minimax_tts_api":
|
||||
from .sources.minimax_tts_api_source import (
|
||||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "gemini_tts":
|
||||
from .sources.gemini_tts_source import (
|
||||
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
)
|
||||
case "gemini_embedding":
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
case "vllm_rerank":
|
||||
from .sources.vllm_rerank_source import (
|
||||
VLLMRerankProvider as VLLMRerankProvider,
|
||||
)
|
||||
case "xinference_rerank":
|
||||
from .sources.xinference_rerank_source import (
|
||||
XinferenceRerankProvider as XinferenceRerankProvider,
|
||||
)
|
||||
case "bailian_rerank":
|
||||
from .sources.bailian_rerank_source import (
|
||||
BailianRerankProvider as BailianRerankProvider,
|
||||
)
|
||||
self.dynamic_import_provider(provider_config["type"])
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
||||
@@ -499,6 +546,7 @@ class ProviderManager:
|
||||
|
||||
# 和配置文件保持同步
|
||||
self.providers_config = astrbot_config["provider"]
|
||||
self.provider_sources_config = astrbot_config.get("provider_sources", [])
|
||||
config_ids = [provider["id"] for provider in self.providers_config]
|
||||
logger.info(f"providers in user's config: {config_ids}")
|
||||
for key in list(self.inst_map.keys()):
|
||||
@@ -570,6 +618,68 @@ class ProviderManager:
|
||||
)
|
||||
del self.inst_map[provider_id]
|
||||
|
||||
async def delete_provider(
|
||||
self, provider_id: str | None = None, provider_source_id: str | None = None
|
||||
):
|
||||
"""Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion."""
|
||||
async with self.resource_lock:
|
||||
# delete from config
|
||||
target_prov_ids = []
|
||||
if provider_id:
|
||||
target_prov_ids.append(provider_id)
|
||||
else:
|
||||
for prov in self.providers_config:
|
||||
if prov.get("provider_source_id") == provider_source_id:
|
||||
target_prov_ids.append(prov.get("id"))
|
||||
config = self.acm.default_conf
|
||||
for tpid in target_prov_ids:
|
||||
await self.terminate_provider(tpid)
|
||||
config["provider"] = [
|
||||
prov for prov in config["provider"] if prov.get("id") != tpid
|
||||
]
|
||||
config.save_config()
|
||||
logger.info(f"Provider {target_prov_ids} 已从配置中删除。")
|
||||
|
||||
async def update_provider(self, origin_provider_id: str, new_config: dict):
|
||||
"""Update provider config and reload the instance. Config will be saved after update."""
|
||||
async with self.resource_lock:
|
||||
npid = new_config.get("id", None)
|
||||
if not npid:
|
||||
raise ValueError("New provider config must have an 'id' field")
|
||||
config = self.acm.default_conf
|
||||
for provider in config["provider"]:
|
||||
if (
|
||||
provider.get("id", None) == npid
|
||||
and provider.get("id", None) != origin_provider_id
|
||||
):
|
||||
raise ValueError(f"Provider ID {npid} already exists")
|
||||
# update config
|
||||
for idx, provider in enumerate(config["provider"]):
|
||||
if provider.get("id", None) == origin_provider_id:
|
||||
config["provider"][idx] = new_config
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Provider ID {origin_provider_id} not found")
|
||||
config.save_config()
|
||||
# reload instance
|
||||
await self.reload(new_config)
|
||||
|
||||
async def create_provider(self, new_config: dict):
|
||||
"""Add new provider config and load the instance. Config will be saved after addition."""
|
||||
async with self.resource_lock:
|
||||
npid = new_config.get("id", None)
|
||||
if not npid:
|
||||
raise ValueError("New provider config must have an 'id' field")
|
||||
config = self.acm.default_conf
|
||||
for provider in config["provider"]:
|
||||
if provider.get("id", None) == npid:
|
||||
raise ValueError(f"Provider ID {npid} already exists")
|
||||
# add to config
|
||||
config["provider"].append(new_config)
|
||||
config.save_config()
|
||||
# load instance
|
||||
await self.load_provider(new_config)
|
||||
|
||||
async def terminate(self):
|
||||
for provider_inst in self.provider_insts:
|
||||
if hasattr(provider_inst, "terminate"):
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TypeAlias, Union
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import ContentPart, Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
@@ -103,6 +103,7 @@ class Provider(AbstractProvider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||
@@ -114,6 +115,7 @@ class Provider(AbstractProvider):
|
||||
tools: tool set
|
||||
contexts: 上下文,和 prompt 二选一使用
|
||||
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||
extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
|
||||
@@ -6,10 +6,13 @@ from mimetypes import guess_type
|
||||
import anthropic
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message
|
||||
from anthropic.types.message_delta_usage import MessageDeltaUsage
|
||||
from anthropic.types.usage import Usage
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
@@ -45,7 +48,7 @@ class ProviderAnthropic(Provider):
|
||||
base_url=self.base_url,
|
||||
)
|
||||
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
self.set_model(provider_config.get("model", "unknown"))
|
||||
|
||||
def _prepare_payload(self, messages: list[dict]):
|
||||
"""准备 Anthropic API 的请求 payload
|
||||
@@ -66,7 +69,7 @@ class ProviderAnthropic(Provider):
|
||||
blocks = []
|
||||
if isinstance(message["content"], str):
|
||||
blocks.append({"type": "text", "text": message["content"]})
|
||||
if "tool_calls" in message:
|
||||
if "tool_calls" in message and isinstance(message["tool_calls"], list):
|
||||
for tool_call in message["tool_calls"]:
|
||||
blocks.append( # noqa: PERF401
|
||||
{
|
||||
@@ -107,12 +110,35 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
def _extract_usage(self, usage: Usage) -> TokenUsage:
|
||||
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
|
||||
return TokenUsage(
|
||||
input_other=usage.input_tokens or 0,
|
||||
input_cached=usage.cache_read_input_tokens or 0,
|
||||
output=usage.output_tokens,
|
||||
)
|
||||
|
||||
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
|
||||
if usage.input_tokens is not None:
|
||||
token_usage.input_other = usage.input_tokens
|
||||
if usage.cache_read_input_tokens is not None:
|
||||
token_usage.input_cached = usage.cache_read_input_tokens
|
||||
if usage.output_tokens is not None:
|
||||
token_usage.output = usage.output_tokens
|
||||
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
completion = await self.client.messages.create(**payloads, stream=False)
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
|
||||
if "max_tokens" not in payloads:
|
||||
payloads["max_tokens"] = 1024
|
||||
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
|
||||
assert isinstance(completion, Message)
|
||||
logger.debug(f"completion: {completion}")
|
||||
@@ -131,6 +157,10 @@ class ProviderAnthropic(Provider):
|
||||
llm_response.tools_call_args.append(content_block.input)
|
||||
llm_response.tools_call_name.append(content_block.name)
|
||||
llm_response.tools_call_ids.append(content_block.id)
|
||||
|
||||
llm_response.id = completion.id
|
||||
llm_response.usage = self._extract_usage(completion.usage)
|
||||
|
||||
# TODO(Soulter): 处理 end_turn 情况
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
||||
@@ -151,10 +181,22 @@ class ProviderAnthropic(Provider):
|
||||
# 用于累积最终结果
|
||||
final_text = ""
|
||||
final_tool_calls = []
|
||||
id = None
|
||||
usage = TokenUsage()
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
|
||||
async with self.client.messages.stream(**payloads) as stream:
|
||||
if "max_tokens" not in payloads:
|
||||
payloads["max_tokens"] = 1024
|
||||
|
||||
async with self.client.messages.stream(
|
||||
**payloads, extra_body=extra_body
|
||||
) as stream:
|
||||
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||
async for event in stream:
|
||||
if event.type == "message_start":
|
||||
# the usage contains input token usage
|
||||
id = event.message.id
|
||||
usage = self._extract_usage(event.message.usage)
|
||||
if event.type == "content_block_start":
|
||||
if event.content_block.type == "text":
|
||||
# 文本块开始
|
||||
@@ -162,6 +204,8 @@ class ProviderAnthropic(Provider):
|
||||
role="assistant",
|
||||
completion_text="",
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
elif event.content_block.type == "tool_use":
|
||||
# 工具使用块开始,初始化缓冲区
|
||||
@@ -179,6 +223,8 @@ class ProviderAnthropic(Provider):
|
||||
role="assistant",
|
||||
completion_text=event.delta.text,
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
elif event.delta.type == "input_json_delta":
|
||||
# 工具调用参数增量
|
||||
@@ -215,6 +261,8 @@ class ProviderAnthropic(Provider):
|
||||
tools_call_name=[tool_info["name"]],
|
||||
tools_call_ids=[tool_info["id"]],
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# JSON 解析失败,跳过这个工具调用
|
||||
@@ -223,11 +271,17 @@ class ProviderAnthropic(Provider):
|
||||
# 清理缓冲区
|
||||
del tool_use_buffer[event.index]
|
||||
|
||||
elif event.type == "message_delta":
|
||||
if event.usage:
|
||||
self._update_usage(usage, event.usage)
|
||||
|
||||
# 返回最终的完整结果
|
||||
final_response = LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=final_text,
|
||||
is_chunk=False,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
|
||||
if final_tool_calls:
|
||||
@@ -249,13 +303,16 @@ class ProviderAnthropic(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -277,10 +334,9 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": new_messages, **model_config}
|
||||
payloads = {"messages": new_messages, "model": model}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
@@ -290,28 +346,30 @@ class ProviderAnthropic(Provider):
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
# logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
prompt=None,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
):
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -332,10 +390,9 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": new_messages, **model_config}
|
||||
payloads = {"messages": new_messages, "model": model}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
@@ -344,15 +401,15 @@ class ProviderAnthropic(Provider):
|
||||
async for llm_response in self._query_stream(payloads, func_tool):
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
async def assemble_context(
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
if not image_urls:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
|
||||
for image_url in image_urls:
|
||||
async def resolve_image_url(image_url: str) -> dict | None:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
@@ -364,28 +421,68 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
return None
|
||||
|
||||
# Get mime type for the image
|
||||
mime_type, _ = guess_type(image_url)
|
||||
if not mime_type:
|
||||
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
},
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
content = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for block in extra_user_content_parts:
|
||||
if isinstance(block, TextPart):
|
||||
content.append({"type": "text", "text": block.text})
|
||||
elif isinstance(block, ImageURLPart):
|
||||
image_dict = await resolve_image_url(block.image_url.url)
|
||||
if image_dict:
|
||||
content.append(image_dict)
|
||||
else:
|
||||
raise ValueError(f"不支持的额外内容块类型: {type(block)}")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
image_dict = await resolve_image_url(image_url)
|
||||
if image_dict:
|
||||
content.append(image_dict)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content) == 1
|
||||
and content[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
|
||||
@@ -56,10 +56,14 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
"api_base",
|
||||
"https://api.fish-audio.cn/v1",
|
||||
)
|
||||
try:
|
||||
self.timeout: int = int(provider_config.get("timeout", 20))
|
||||
except ValueError:
|
||||
self.timeout = 20
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
self.set_model(provider_config["model"])
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
||||
"""获取角色的reference_id
|
||||
@@ -135,17 +139,21 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base).stream(
|
||||
async with AsyncClient(base_url=self.api_base, timeout=self.timeout).stream(
|
||||
"POST",
|
||||
"/tts",
|
||||
headers=self.headers,
|
||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||
) as response:
|
||||
if response.headers["content-type"] == "audio/wav":
|
||||
if response.status_code == 200 and response.headers.get(
|
||||
"content-type", ""
|
||||
).startswith("audio/"):
|
||||
with open(path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
return path
|
||||
body = await response.aread()
|
||||
text = body.decode("utf-8", errors="replace")
|
||||
raise Exception(f"Fish Audio API请求失败: {text}")
|
||||
error_bytes = await response.aread()
|
||||
error_text = error_bytes.decode("utf-8", errors="replace")[:1024]
|
||||
raise Exception(
|
||||
f"Fish Audio API请求失败: 状态码 {response.status_code}, 响应内容: {error_text}"
|
||||
)
|
||||
|
||||
@@ -13,8 +13,9 @@ from google.genai.errors import APIError
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
@@ -68,7 +69,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
self.api_base = self.api_base[:-1]
|
||||
|
||||
self._init_client()
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
self.set_model(provider_config.get("model", "unknown"))
|
||||
self._init_safety_settings()
|
||||
|
||||
def _init_client(self) -> None:
|
||||
@@ -138,7 +139,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
modalities = ["TEXT"]
|
||||
|
||||
tool_list: list[types.Tool] | None = []
|
||||
model_name = self.get_model()
|
||||
model_name = cast(str, payloads.get("model", self.get_model()))
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
url_context = self.provider_config.get("gm_url_context", False)
|
||||
@@ -197,6 +198,53 @@ class ProviderGoogleGenAI(Provider):
|
||||
types.Tool(function_declarations=func_desc["function_declarations"]),
|
||||
]
|
||||
|
||||
# oper thinking config
|
||||
thinking_config = None
|
||||
if model_name in [
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-pro-preview",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-preview",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-lite-preview",
|
||||
"gemini-robotics-er-1.5-preview",
|
||||
"gemini-live-2.5-flash-preview-native-audio-09-2025",
|
||||
]:
|
||||
# The thinkingBudget parameter, introduced with the Gemini 2.5 series
|
||||
thinking_budget = self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget", 0
|
||||
)
|
||||
if thinking_budget is not None:
|
||||
thinking_config = types.ThinkingConfig(
|
||||
thinking_budget=thinking_budget,
|
||||
)
|
||||
elif model_name in [
|
||||
"gemini-3-pro",
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-3-flash",
|
||||
"gemini-3-flash-preview",
|
||||
"gemini-3-flash-lite",
|
||||
"gemini-3-flash-lite-preview",
|
||||
]:
|
||||
# The thinkingLevel parameter, recommended for Gemini 3 models and onwards
|
||||
# Gemini 2.5 series models don't support thinkingLevel; use thinkingBudget instead.
|
||||
thinking_level = self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"level", "HIGH"
|
||||
)
|
||||
if thinking_level and isinstance(thinking_level, str):
|
||||
thinking_level = thinking_level.upper()
|
||||
if thinking_level not in ["MINIMAL", "LOW", "MEDIUM", "HIGH"]:
|
||||
logger.warning(
|
||||
f"Invalid thinking level: {thinking_level}, using HIGH"
|
||||
)
|
||||
thinking_level = "HIGH"
|
||||
level = types.ThinkingLevel(thinking_level)
|
||||
thinking_config = types.ThinkingConfig()
|
||||
if not hasattr(types.ThinkingConfig, "thinking_level"):
|
||||
setattr(types.ThinkingConfig, "thinking_level", level)
|
||||
else:
|
||||
thinking_config.thinking_level = level
|
||||
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
@@ -216,22 +264,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
response_modalities=modalities,
|
||||
tools=cast(types.ToolListUnion | None, tool_list),
|
||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||
thinking_config=(
|
||||
types.ThinkingConfig(
|
||||
thinking_budget=min(
|
||||
int(
|
||||
self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget",
|
||||
0,
|
||||
),
|
||||
),
|
||||
24576,
|
||||
),
|
||||
)
|
||||
if "gemini-2.5-flash" in self.get_model()
|
||||
and hasattr(types.ThinkingConfig, "thinking_budget")
|
||||
else None
|
||||
),
|
||||
thinking_config=thinking_config,
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True,
|
||||
),
|
||||
@@ -347,6 +380,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
]
|
||||
return "".join(thought_buf).strip()
|
||||
|
||||
def _extract_usage(
|
||||
self, usage_metadata: types.GenerateContentResponseUsageMetadata
|
||||
) -> TokenUsage:
|
||||
"""Extract usage from candidate"""
|
||||
return TokenUsage(
|
||||
input_other=usage_metadata.prompt_token_count or 0,
|
||||
input_cached=usage_metadata.cached_content_token_count or 0,
|
||||
output=usage_metadata.candidates_token_count or 0,
|
||||
)
|
||||
|
||||
def _process_content_parts(
|
||||
self,
|
||||
candidate: types.Candidate,
|
||||
@@ -431,6 +474,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
None,
|
||||
)
|
||||
|
||||
model = payloads.get("model", self.get_model())
|
||||
|
||||
modalities = ["TEXT"]
|
||||
if self.provider_config.get("gm_resp_image_modal", False):
|
||||
modalities.append("IMAGE")
|
||||
@@ -449,7 +494,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
temperature,
|
||||
)
|
||||
result = await self.client.models.generate_content(
|
||||
model=self.get_model(),
|
||||
model=model,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
)
|
||||
@@ -475,11 +520,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
e.message = ""
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
)
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
logger.warning(f"{model} 不支持函数调用,已自动去除")
|
||||
tools = None
|
||||
elif (
|
||||
"Multi-modal output is not supported" in e.message
|
||||
@@ -488,7 +533,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
or "only supports text output" in e.message
|
||||
):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态",
|
||||
f"{model} 不支持多模态输出,降级为文本模态",
|
||||
)
|
||||
modalities = ["TEXT"]
|
||||
else:
|
||||
@@ -501,6 +546,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
result.candidates[0],
|
||||
llm_response,
|
||||
)
|
||||
llm_response.id = result.response_id
|
||||
if result.usage_metadata:
|
||||
llm_response.usage = self._extract_usage(result.usage_metadata)
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
@@ -513,7 +561,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
None,
|
||||
)
|
||||
|
||||
model = payloads.get("model", self.get_model())
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
|
||||
result = None
|
||||
@@ -525,7 +573,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_instruction,
|
||||
)
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=self.get_model(),
|
||||
model=model,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
)
|
||||
@@ -535,11 +583,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
e.message = ""
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
)
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
logger.warning(f"{model} 不支持函数调用,已自动去除")
|
||||
tools = None
|
||||
else:
|
||||
raise
|
||||
@@ -569,6 +617,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
chunk.candidates[0],
|
||||
llm_response,
|
||||
)
|
||||
llm_response.id = chunk.response_id
|
||||
if chunk.usage_metadata:
|
||||
llm_response.usage = self._extract_usage(chunk.usage_metadata)
|
||||
yield llm_response
|
||||
return
|
||||
|
||||
@@ -596,6 +647,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
chunk.candidates[0],
|
||||
final_response,
|
||||
)
|
||||
final_response.id = chunk.response_id
|
||||
if chunk.usage_metadata:
|
||||
final_response.usage = self._extract_usage(chunk.usage_metadata)
|
||||
break
|
||||
|
||||
# Yield final complete response with accumulated text
|
||||
@@ -627,13 +681,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -652,10 +709,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
@@ -680,13 +736,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -705,10 +764,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
@@ -746,33 +804,75 @@ class ProviderGoogleGenAI(Provider):
|
||||
self.chosen_api_key = key
|
||||
self._init_client()
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
async def assemble_context(
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
):
|
||||
"""组装上下文。"""
|
||||
if image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
|
||||
async def resolve_image_part(image_url: str) -> dict | None:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
return None
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content_blocks.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for part in extra_user_content_parts:
|
||||
if isinstance(part, TextPart):
|
||||
content_blocks.append({"type": "text", "text": part.text})
|
||||
elif isinstance(part, ImageURLPart):
|
||||
image_part = await resolve_image_part(part.image_url.url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": text}
|
||||
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
image_part = await resolve_image_part(image_url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -12,14 +12,15 @@ from openai._exceptions import NotFoundError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import ContentPart, ImageURLPart, Message, TextPart
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
@@ -68,8 +69,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.client.chat.completions.create,
|
||||
).parameters.keys()
|
||||
|
||||
model_config = provider_config.get("model_config", {})
|
||||
model = model_config.get("model", "unknown")
|
||||
model = provider_config.get("model", "unknown")
|
||||
self.set_model(model)
|
||||
|
||||
self.reasoning_key = "reasoning_content"
|
||||
@@ -208,6 +208,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# handle the content delta
|
||||
reasoning = self._extract_reasoning_content(chunk)
|
||||
_y = False
|
||||
llm_response.id = chunk.id
|
||||
if reasoning:
|
||||
llm_response.reasoning_content = reasoning
|
||||
_y = True
|
||||
@@ -217,6 +218,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
chain=[Comp.Plain(completion_text)],
|
||||
)
|
||||
_y = True
|
||||
if chunk.usage:
|
||||
llm_response.usage = self._extract_usage(chunk.usage)
|
||||
if _y:
|
||||
yield llm_response
|
||||
|
||||
@@ -245,6 +248,15 @@ class ProviderOpenAIOfficial(Provider):
|
||||
reasoning_text = str(reasoning_attr)
|
||||
return reasoning_text
|
||||
|
||||
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
|
||||
ptd = usage.prompt_tokens_details
|
||||
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
|
||||
return TokenUsage(
|
||||
input_other=usage.prompt_tokens - cached,
|
||||
input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0,
|
||||
output=usage.completion_tokens,
|
||||
)
|
||||
|
||||
async def _parse_openai_completion(
|
||||
self, completion: ChatCompletion, tools: ToolSet | None
|
||||
) -> LLMResponse:
|
||||
@@ -321,6 +333,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
llm_response.raw_completion = completion
|
||||
llm_response.id = completion.id
|
||||
|
||||
if completion.usage:
|
||||
llm_response.usage = self._extract_usage(completion.usage)
|
||||
|
||||
return llm_response
|
||||
|
||||
@@ -332,6 +348,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
@@ -339,7 +356,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -358,10 +377,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
# xAI origin search tool inject
|
||||
self._maybe_inject_xai_search(payloads, **kwargs)
|
||||
@@ -461,6 +479,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
@@ -470,6 +489,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
model=model,
|
||||
extra_user_content_parts=extra_user_content_parts,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -609,33 +629,71 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
if image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
|
||||
async def resolve_image_part(image_url: str) -> dict | None:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
return None
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content_blocks.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for part in extra_user_content_parts:
|
||||
if isinstance(part, TextPart):
|
||||
content_blocks.append({"type": "text", "text": part.text})
|
||||
elif isinstance(part, ImageURLPart):
|
||||
image_part = await resolve_image_part(part.image_url.url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": text}
|
||||
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
image_part = await resolve_image_part(image_url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -2,15 +2,19 @@ from astrbot.core import html_renderer
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.star.star_tools import StarTools
|
||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
|
||||
|
||||
from .context import Context
|
||||
from .star import StarMetadata, star_map, star_registry
|
||||
from .star_manager import PluginManager
|
||||
|
||||
|
||||
class Star(CommandParserMixin):
|
||||
class Star(CommandParserMixin, PluginKVStoreMixin):
|
||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||
|
||||
author: str
|
||||
name: str
|
||||
|
||||
def __init__(self, context: Context, config: dict | None = None):
|
||||
StarTools.initialize(context)
|
||||
self.context = context
|
||||
|
||||
@@ -0,0 +1,496 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core import db_helper, logger
|
||||
from astrbot.core.db.po import CommandConfig
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandDescriptor:
|
||||
handler: StarHandlerMetadata = field(repr=False)
|
||||
filter_ref: CommandFilter | CommandGroupFilter | None = field(
|
||||
default=None,
|
||||
repr=False,
|
||||
)
|
||||
handler_full_name: str = ""
|
||||
handler_name: str = ""
|
||||
plugin_name: str = ""
|
||||
plugin_display_name: str | None = None
|
||||
module_path: str = ""
|
||||
description: str = ""
|
||||
command_type: str = "command" # "command" | "group" | "sub_command"
|
||||
raw_command_name: str | None = None
|
||||
current_fragment: str | None = None
|
||||
parent_signature: str = ""
|
||||
parent_group_handler: str = ""
|
||||
original_command: str | None = None
|
||||
effective_command: str | None = None
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
permission: str = "everyone"
|
||||
enabled: bool = True
|
||||
is_group: bool = False
|
||||
is_sub_command: bool = False
|
||||
reserved: bool = False
|
||||
config: CommandConfig | None = None
|
||||
has_conflict: bool = False
|
||||
sub_commands: list[CommandDescriptor] = field(default_factory=list)
|
||||
|
||||
|
||||
async def sync_command_configs() -> None:
|
||||
"""同步指令配置,清理过期配置。"""
|
||||
descriptors = _collect_descriptors(include_sub_commands=False)
|
||||
config_records = await db_helper.get_command_configs()
|
||||
config_map = _bind_configs_to_descriptors(descriptors, config_records)
|
||||
live_handlers = {desc.handler_full_name for desc in descriptors}
|
||||
|
||||
stale_configs = [key for key in config_map if key not in live_handlers]
|
||||
if stale_configs:
|
||||
await db_helper.delete_command_configs(stale_configs)
|
||||
|
||||
|
||||
async def toggle_command(handler_full_name: str, enabled: bool) -> CommandDescriptor:
|
||||
descriptor = _build_descriptor_by_full_name(handler_full_name)
|
||||
if not descriptor:
|
||||
raise ValueError("指定的处理函数不存在或不是指令。")
|
||||
|
||||
existing_cfg = await db_helper.get_command_config(handler_full_name)
|
||||
config = await db_helper.upsert_command_config(
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=descriptor.plugin_name or "",
|
||||
module_path=descriptor.module_path,
|
||||
original_command=descriptor.original_command or descriptor.handler_name,
|
||||
resolved_command=(
|
||||
existing_cfg.resolved_command
|
||||
if existing_cfg
|
||||
else descriptor.current_fragment
|
||||
),
|
||||
enabled=enabled,
|
||||
keep_original_alias=False,
|
||||
conflict_key=existing_cfg.conflict_key
|
||||
if existing_cfg and existing_cfg.conflict_key
|
||||
else descriptor.original_command,
|
||||
resolution_strategy=existing_cfg.resolution_strategy if existing_cfg else None,
|
||||
note=existing_cfg.note if existing_cfg else None,
|
||||
extra_data=existing_cfg.extra_data if existing_cfg else None,
|
||||
auto_managed=False,
|
||||
)
|
||||
_bind_descriptor_with_config(descriptor, config)
|
||||
await sync_command_configs()
|
||||
return descriptor
|
||||
|
||||
|
||||
async def rename_command(
|
||||
handler_full_name: str,
|
||||
new_fragment: str,
|
||||
aliases: list[str] | None = None,
|
||||
) -> CommandDescriptor:
|
||||
descriptor = _build_descriptor_by_full_name(handler_full_name)
|
||||
if not descriptor:
|
||||
raise ValueError("指定的处理函数不存在或不是指令。")
|
||||
|
||||
new_fragment = new_fragment.strip()
|
||||
if not new_fragment:
|
||||
raise ValueError("指令名不能为空。")
|
||||
|
||||
# 校验主指令名
|
||||
candidate_full = _compose_command(descriptor.parent_signature, new_fragment)
|
||||
if _is_command_in_use(handler_full_name, candidate_full):
|
||||
raise ValueError(f"指令名 '{candidate_full}' 已被其他指令占用。")
|
||||
|
||||
# 校验别名
|
||||
if aliases:
|
||||
for alias in aliases:
|
||||
alias = alias.strip()
|
||||
if not alias:
|
||||
continue
|
||||
alias_full = _compose_command(descriptor.parent_signature, alias)
|
||||
if _is_command_in_use(handler_full_name, alias_full):
|
||||
raise ValueError(f"别名 '{alias_full}' 已被其他指令占用。")
|
||||
|
||||
existing_cfg = await db_helper.get_command_config(handler_full_name)
|
||||
merged_extra = dict(existing_cfg.extra_data or {}) if existing_cfg else {}
|
||||
merged_extra["resolved_aliases"] = aliases or []
|
||||
|
||||
config = await db_helper.upsert_command_config(
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=descriptor.plugin_name or "",
|
||||
module_path=descriptor.module_path,
|
||||
original_command=descriptor.original_command or descriptor.handler_name,
|
||||
resolved_command=new_fragment,
|
||||
enabled=True if descriptor.enabled else False,
|
||||
keep_original_alias=False,
|
||||
conflict_key=descriptor.original_command,
|
||||
resolution_strategy="manual_rename",
|
||||
note=None,
|
||||
extra_data=merged_extra,
|
||||
auto_managed=False,
|
||||
)
|
||||
_bind_descriptor_with_config(descriptor, config)
|
||||
|
||||
await sync_command_configs()
|
||||
return descriptor
|
||||
|
||||
|
||||
async def list_commands() -> list[dict[str, Any]]:
|
||||
descriptors = _collect_descriptors(include_sub_commands=True)
|
||||
config_records = await db_helper.get_command_configs()
|
||||
_bind_configs_to_descriptors(descriptors, config_records)
|
||||
|
||||
conflict_groups = _group_conflicts(descriptors)
|
||||
conflict_handler_names: set[str] = {
|
||||
d.handler_full_name for group in conflict_groups.values() for d in group
|
||||
}
|
||||
|
||||
# 分类,设置冲突标志,将子指令挂载到父指令组
|
||||
group_map: dict[str, CommandDescriptor] = {}
|
||||
sub_commands: list[CommandDescriptor] = []
|
||||
root_commands: list[CommandDescriptor] = []
|
||||
|
||||
for desc in descriptors:
|
||||
desc.has_conflict = desc.handler_full_name in conflict_handler_names
|
||||
if desc.is_group:
|
||||
group_map[desc.handler_full_name] = desc
|
||||
elif desc.is_sub_command:
|
||||
sub_commands.append(desc)
|
||||
else:
|
||||
root_commands.append(desc)
|
||||
|
||||
for sub in sub_commands:
|
||||
if sub.parent_group_handler and sub.parent_group_handler in group_map:
|
||||
group_map[sub.parent_group_handler].sub_commands.append(sub)
|
||||
else:
|
||||
root_commands.append(sub)
|
||||
|
||||
# 指令组 + 普通指令,按 effective_command 字母排序
|
||||
all_commands = list(group_map.values()) + root_commands
|
||||
all_commands.sort(key=lambda d: (d.effective_command or "").lower())
|
||||
|
||||
result = [_descriptor_to_dict(desc) for desc in all_commands]
|
||||
return result
|
||||
|
||||
|
||||
async def list_command_conflicts() -> list[dict[str, Any]]:
|
||||
"""列出所有冲突的指令组。"""
|
||||
descriptors = _collect_descriptors(include_sub_commands=False)
|
||||
config_records = await db_helper.get_command_configs()
|
||||
_bind_configs_to_descriptors(descriptors, config_records)
|
||||
|
||||
conflict_groups = _group_conflicts(descriptors)
|
||||
details = [
|
||||
{
|
||||
"conflict_key": key,
|
||||
"handlers": [
|
||||
{
|
||||
"handler_full_name": item.handler_full_name,
|
||||
"plugin": item.plugin_name,
|
||||
"current_name": item.effective_command,
|
||||
}
|
||||
for item in group
|
||||
],
|
||||
}
|
||||
for key, group in conflict_groups.items()
|
||||
]
|
||||
return details
|
||||
|
||||
|
||||
# Internal helpers ----------------------------------------------------------
|
||||
|
||||
|
||||
def _collect_descriptors(include_sub_commands: bool) -> list[CommandDescriptor]:
|
||||
"""收集指令,按需包含子指令。"""
|
||||
descriptors: list[CommandDescriptor] = []
|
||||
for handler in star_handlers_registry:
|
||||
try:
|
||||
desc = _build_descriptor(handler)
|
||||
if not desc:
|
||||
continue
|
||||
if not include_sub_commands and desc.is_sub_command:
|
||||
continue
|
||||
descriptors.append(desc)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"解析指令处理函数 {handler.handler_full_name} 失败,跳过该指令。原因: {e!s}"
|
||||
)
|
||||
continue
|
||||
return descriptors
|
||||
|
||||
|
||||
def _build_descriptor(handler: StarHandlerMetadata) -> CommandDescriptor | None:
|
||||
filter_ref = _locate_primary_filter(handler)
|
||||
if filter_ref is None:
|
||||
return None
|
||||
|
||||
plugin_meta = star_map.get(handler.handler_module_path)
|
||||
plugin_name = (
|
||||
plugin_meta.name if plugin_meta else None
|
||||
) or handler.handler_module_path
|
||||
plugin_display = plugin_meta.display_name if plugin_meta else None
|
||||
|
||||
is_sub_command = bool(handler.extras_configs.get("sub_command"))
|
||||
parent_group_handler = ""
|
||||
|
||||
if isinstance(filter_ref, CommandFilter):
|
||||
raw_fragment = getattr(
|
||||
filter_ref, "_original_command_name", filter_ref.command_name
|
||||
)
|
||||
current_fragment = filter_ref.command_name
|
||||
parent_signature = (filter_ref.parent_command_names or [""])[0].strip()
|
||||
# 如果是子指令,尝试找到父指令组的 handler_full_name
|
||||
if is_sub_command and parent_signature:
|
||||
parent_group_handler = _find_parent_group_handler(
|
||||
handler.handler_module_path, parent_signature
|
||||
)
|
||||
else:
|
||||
raw_fragment = getattr(
|
||||
filter_ref, "_original_group_name", filter_ref.group_name
|
||||
)
|
||||
current_fragment = filter_ref.group_name
|
||||
parent_signature = _resolve_group_parent_signature(filter_ref)
|
||||
|
||||
original_command = _compose_command(parent_signature, raw_fragment)
|
||||
effective_command = _compose_command(parent_signature, current_fragment)
|
||||
|
||||
# 确定 command_type
|
||||
if isinstance(filter_ref, CommandGroupFilter):
|
||||
command_type = "group"
|
||||
elif is_sub_command:
|
||||
command_type = "sub_command"
|
||||
else:
|
||||
command_type = "command"
|
||||
|
||||
descriptor = CommandDescriptor(
|
||||
handler=handler,
|
||||
filter_ref=filter_ref,
|
||||
handler_full_name=handler.handler_full_name,
|
||||
handler_name=handler.handler_name,
|
||||
plugin_name=plugin_name,
|
||||
plugin_display_name=plugin_display,
|
||||
module_path=handler.handler_module_path,
|
||||
description=handler.desc or "",
|
||||
command_type=command_type,
|
||||
raw_command_name=raw_fragment,
|
||||
current_fragment=current_fragment,
|
||||
parent_signature=parent_signature,
|
||||
parent_group_handler=parent_group_handler,
|
||||
original_command=original_command,
|
||||
effective_command=effective_command,
|
||||
aliases=sorted(getattr(filter_ref, "alias", set())),
|
||||
permission=_determine_permission(handler),
|
||||
enabled=handler.enabled,
|
||||
is_group=isinstance(filter_ref, CommandGroupFilter),
|
||||
is_sub_command=is_sub_command,
|
||||
reserved=plugin_meta.reserved if plugin_meta else False,
|
||||
)
|
||||
return descriptor
|
||||
|
||||
|
||||
def _build_descriptor_by_full_name(full_name: str) -> CommandDescriptor | None:
|
||||
handler = star_handlers_registry.get_handler_by_full_name(full_name)
|
||||
if not handler:
|
||||
return None
|
||||
return _build_descriptor(handler)
|
||||
|
||||
|
||||
def _locate_primary_filter(
|
||||
handler: StarHandlerMetadata,
|
||||
) -> CommandFilter | CommandGroupFilter | None:
|
||||
for filter_ref in handler.event_filters:
|
||||
if isinstance(filter_ref, (CommandFilter, CommandGroupFilter)):
|
||||
return filter_ref
|
||||
return None
|
||||
|
||||
|
||||
def _determine_permission(handler: StarHandlerMetadata) -> str:
|
||||
for filter_ref in handler.event_filters:
|
||||
if isinstance(filter_ref, PermissionTypeFilter):
|
||||
return (
|
||||
"admin"
|
||||
if filter_ref.permission_type == PermissionType.ADMIN
|
||||
else "member"
|
||||
)
|
||||
return "everyone"
|
||||
|
||||
|
||||
def _resolve_group_parent_signature(group_filter: CommandGroupFilter) -> str:
|
||||
signatures: list[str] = []
|
||||
parent = group_filter.parent_group
|
||||
while parent:
|
||||
signatures.append(getattr(parent, "_original_group_name", parent.group_name))
|
||||
parent = parent.parent_group
|
||||
return " ".join(reversed(signatures)).strip()
|
||||
|
||||
|
||||
def _find_parent_group_handler(module_path: str, parent_signature: str) -> str:
|
||||
"""根据模块路径和父级签名,找到对应的指令组 handler_full_name。"""
|
||||
parent_sig_normalized = parent_signature.strip()
|
||||
for handler in star_handlers_registry:
|
||||
if handler.handler_module_path != module_path:
|
||||
continue
|
||||
filter_ref = _locate_primary_filter(handler)
|
||||
if not isinstance(filter_ref, CommandGroupFilter):
|
||||
continue
|
||||
# 检查该指令组的完整指令名是否匹配 parent_signature
|
||||
group_names = filter_ref.get_complete_command_names()
|
||||
if parent_sig_normalized in group_names:
|
||||
return handler.handler_full_name
|
||||
return ""
|
||||
|
||||
|
||||
def _compose_command(parent_signature: str, fragment: str | None) -> str:
|
||||
fragment = (fragment or "").strip()
|
||||
parent_signature = parent_signature.strip()
|
||||
if not parent_signature:
|
||||
return fragment
|
||||
if not fragment:
|
||||
return parent_signature
|
||||
return f"{parent_signature} {fragment}"
|
||||
|
||||
|
||||
def _bind_descriptor_with_config(
|
||||
descriptor: CommandDescriptor,
|
||||
config: CommandConfig,
|
||||
) -> None:
|
||||
_apply_config_to_descriptor(descriptor, config)
|
||||
_apply_config_to_runtime(descriptor, config)
|
||||
|
||||
|
||||
def _apply_config_to_descriptor(
|
||||
descriptor: CommandDescriptor,
|
||||
config: CommandConfig,
|
||||
) -> None:
|
||||
descriptor.config = config
|
||||
descriptor.enabled = config.enabled
|
||||
|
||||
if config.original_command:
|
||||
descriptor.original_command = config.original_command
|
||||
|
||||
new_fragment = config.resolved_command or descriptor.current_fragment
|
||||
descriptor.current_fragment = new_fragment
|
||||
descriptor.effective_command = _compose_command(
|
||||
descriptor.parent_signature,
|
||||
new_fragment,
|
||||
)
|
||||
|
||||
extra = config.extra_data or {}
|
||||
resolved_aliases = extra.get("resolved_aliases")
|
||||
if isinstance(resolved_aliases, list):
|
||||
descriptor.aliases = [str(x) for x in resolved_aliases if str(x).strip()]
|
||||
|
||||
|
||||
def _apply_config_to_runtime(
|
||||
descriptor: CommandDescriptor,
|
||||
config: CommandConfig,
|
||||
) -> None:
|
||||
descriptor.handler.enabled = config.enabled
|
||||
if descriptor.filter_ref:
|
||||
if descriptor.current_fragment:
|
||||
_set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment)
|
||||
extra = config.extra_data or {}
|
||||
resolved_aliases = extra.get("resolved_aliases")
|
||||
if isinstance(resolved_aliases, list):
|
||||
_set_filter_aliases(
|
||||
descriptor.filter_ref,
|
||||
[str(x) for x in resolved_aliases if str(x).strip()],
|
||||
)
|
||||
|
||||
|
||||
def _bind_configs_to_descriptors(
|
||||
descriptors: list[CommandDescriptor],
|
||||
config_records: list[CommandConfig],
|
||||
) -> dict[str, CommandConfig]:
|
||||
config_map = {cfg.handler_full_name: cfg for cfg in config_records}
|
||||
for desc in descriptors:
|
||||
if cfg := config_map.get(desc.handler_full_name):
|
||||
_bind_descriptor_with_config(desc, cfg)
|
||||
return config_map
|
||||
|
||||
|
||||
def _group_conflicts(
|
||||
descriptors: list[CommandDescriptor],
|
||||
) -> dict[str, list[CommandDescriptor]]:
|
||||
conflicts: dict[str, list[CommandDescriptor]] = defaultdict(list)
|
||||
for desc in descriptors:
|
||||
if desc.effective_command and desc.enabled:
|
||||
conflicts[desc.effective_command].append(desc)
|
||||
return {k: v for k, v in conflicts.items() if len(v) > 1}
|
||||
|
||||
|
||||
def _set_filter_fragment(
|
||||
filter_ref: CommandFilter | CommandGroupFilter,
|
||||
fragment: str,
|
||||
) -> None:
|
||||
attr = (
|
||||
"group_name" if isinstance(filter_ref, CommandGroupFilter) else "command_name"
|
||||
)
|
||||
current_value = getattr(filter_ref, attr)
|
||||
if fragment == current_value:
|
||||
return
|
||||
setattr(filter_ref, attr, fragment)
|
||||
if hasattr(filter_ref, "_cmpl_cmd_names"):
|
||||
filter_ref._cmpl_cmd_names = None
|
||||
|
||||
|
||||
def _set_filter_aliases(
|
||||
filter_ref: CommandFilter | CommandGroupFilter,
|
||||
aliases: list[str],
|
||||
) -> None:
|
||||
current_aliases = getattr(filter_ref, "alias", set())
|
||||
if set(aliases) == current_aliases:
|
||||
return
|
||||
setattr(filter_ref, "alias", set(aliases))
|
||||
if hasattr(filter_ref, "_cmpl_cmd_names"):
|
||||
filter_ref._cmpl_cmd_names = None
|
||||
|
||||
|
||||
def _is_command_in_use(
|
||||
target_handler_full_name: str,
|
||||
candidate_full_command: str,
|
||||
) -> bool:
|
||||
candidate = candidate_full_command.strip()
|
||||
for handler in star_handlers_registry:
|
||||
if handler.handler_full_name == target_handler_full_name:
|
||||
continue
|
||||
filter_ref = _locate_primary_filter(handler)
|
||||
if not filter_ref:
|
||||
continue
|
||||
names = {name.strip() for name in filter_ref.get_complete_command_names()}
|
||||
if candidate in names:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _descriptor_to_dict(desc: CommandDescriptor) -> dict[str, Any]:
|
||||
result = {
|
||||
"handler_full_name": desc.handler_full_name,
|
||||
"handler_name": desc.handler_name,
|
||||
"plugin": desc.plugin_name,
|
||||
"plugin_display_name": desc.plugin_display_name,
|
||||
"module_path": desc.module_path,
|
||||
"description": desc.description,
|
||||
"type": desc.command_type,
|
||||
"parent_signature": desc.parent_signature,
|
||||
"parent_group_handler": desc.parent_group_handler,
|
||||
"original_command": desc.original_command,
|
||||
"current_fragment": desc.current_fragment,
|
||||
"effective_command": desc.effective_command,
|
||||
"aliases": desc.aliases,
|
||||
"permission": desc.permission,
|
||||
"enabled": desc.enabled,
|
||||
"is_group": desc.is_group,
|
||||
"has_conflict": desc.has_conflict,
|
||||
"reserved": desc.reserved,
|
||||
}
|
||||
# 如果是指令组,包含子指令列表
|
||||
if desc.is_group and desc.sub_commands:
|
||||
result["sub_commands"] = [_descriptor_to_dict(sub) for sub in desc.sub_commands]
|
||||
else:
|
||||
result["sub_commands"] = []
|
||||
return result
|
||||
@@ -267,6 +267,10 @@ class Context:
|
||||
):
|
||||
"""通过 ID 获取对应的 LLM Provider。"""
|
||||
prov = self.provider_manager.inst_map.get(provider_id)
|
||||
if provider_id and not prov:
|
||||
logger.warning(
|
||||
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
|
||||
)
|
||||
return prov
|
||||
|
||||
def get_all_providers(self) -> list[Provider]:
|
||||
@@ -373,7 +377,7 @@ class Context:
|
||||
if not module_path:
|
||||
_parts = []
|
||||
module_part = tool.__module__.split(".")
|
||||
flags = ["packages", "plugins"]
|
||||
flags = ["builtin_stars", "plugins"]
|
||||
for i, part in enumerate(module_part):
|
||||
_parts.append(part)
|
||||
if part in flags and i + 1 < len(module_part):
|
||||
|
||||
@@ -40,6 +40,7 @@ class CommandFilter(HandlerFilter):
|
||||
):
|
||||
self.command_name = command_name
|
||||
self.alias = alias if alias else set()
|
||||
self._original_command_name = command_name
|
||||
self.parent_command_names = (
|
||||
parent_command_names if parent_command_names is not None else [""]
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
):
|
||||
self.group_name = group_name
|
||||
self.alias = alias if alias else set()
|
||||
self._original_group_name = group_name
|
||||
self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = []
|
||||
self.custom_filter_list: list[CustomFilter] = []
|
||||
self.parent_group = parent_group
|
||||
|
||||
@@ -118,6 +118,8 @@ class StarHandlerRegistry(Generic[T]):
|
||||
# 过滤事件类型
|
||||
if handler.event_type != event_type:
|
||||
continue
|
||||
if not handler.enabled:
|
||||
continue
|
||||
# 过滤启用状态
|
||||
if only_activated:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
@@ -220,6 +222,8 @@ class StarHandlerMetadata(Generic[H]):
|
||||
extras_configs: dict = field(default_factory=dict)
|
||||
"""插件注册的一些其他的信息, 如 priority 等"""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
def __lt__(self, other: StarHandlerMetadata):
|
||||
"""定义小于运算符以支持优先队列"""
|
||||
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
||||
|
||||
@@ -18,11 +18,13 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_path,
|
||||
get_astrbot_plugin_path,
|
||||
)
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
|
||||
from . import StarMetadata
|
||||
from .command_management import sync_command_configs
|
||||
from .context import Context
|
||||
from .filter.permission import PermissionType, PermissionTypeFilter
|
||||
from .star import star_map, star_registry
|
||||
@@ -48,13 +50,10 @@ class PluginManager:
|
||||
"""存储插件的路径。即 data/plugins"""
|
||||
self.plugin_config_path = get_astrbot_config_path()
|
||||
"""存储插件配置的路径。data/config"""
|
||||
self.reserved_plugin_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../../../packages",
|
||||
),
|
||||
self.reserved_plugin_path = os.path.join(
|
||||
get_astrbot_path(), "astrbot", "builtin_stars"
|
||||
)
|
||||
"""保留插件的路径。在 packages 目录下"""
|
||||
"""保留插件的路径。在 astrbot/builtin_stars 目录下"""
|
||||
self.conf_schema_fname = "_conf_schema.json"
|
||||
self.logo_fname = "logo.png"
|
||||
"""插件配置 Schema 文件名"""
|
||||
@@ -251,7 +250,7 @@ class PluginManager:
|
||||
list[str]: 与该插件相关的模块名列表
|
||||
|
||||
"""
|
||||
prefix = "packages." if is_reserved else "data.plugins."
|
||||
prefix = "astrbot.builtin_stars." if is_reserved else "data.plugins."
|
||||
return [
|
||||
key
|
||||
for key in list(sys.modules.keys())
|
||||
@@ -269,7 +268,7 @@ class PluginManager:
|
||||
可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存
|
||||
|
||||
Args:
|
||||
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "packages"])
|
||||
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "astrbot.builtin_stars"])
|
||||
root_dir_name: 插件根目录名,用于移除与该插件相关的所有模块
|
||||
is_reserved: 插件是否为保留插件(影响模块路径前缀)
|
||||
|
||||
@@ -381,9 +380,9 @@ class PluginManager:
|
||||
reserved = plugin_module.get(
|
||||
"reserved",
|
||||
False,
|
||||
) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
|
||||
) # 是否是保留插件。目前在 astrbot/builtin_stars 目录下的都是保留插件。保留插件不可以卸载。
|
||||
|
||||
path = "data.plugins." if not reserved else "packages."
|
||||
path = "data.plugins." if not reserved else "astrbot.builtin_stars."
|
||||
path += root_dir_name + "." + module_str
|
||||
|
||||
# 检查是否需要载入指定的插件
|
||||
@@ -467,6 +466,18 @@ class PluginManager:
|
||||
metadata.star_cls = metadata.star_cls_type(
|
||||
context=self.context,
|
||||
)
|
||||
|
||||
p_name = (metadata.name or "unknown").lower().replace("/", "_")
|
||||
p_author = (
|
||||
(metadata.author or "unknown").lower().replace("/", "_")
|
||||
)
|
||||
setattr(metadata.star_cls, "name", p_name)
|
||||
setattr(metadata.star_cls, "author", p_author)
|
||||
setattr(
|
||||
metadata.star_cls,
|
||||
"plugin_id",
|
||||
f"{p_author}/{p_name}",
|
||||
)
|
||||
else:
|
||||
logger.info(f"插件 {metadata.name} 已被禁用。")
|
||||
|
||||
@@ -618,6 +629,11 @@ class PluginManager:
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
try:
|
||||
await sync_command_configs()
|
||||
except Exception as e:
|
||||
logger.error(f"同步指令配置失败: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if not fail_rec:
|
||||
return True, None
|
||||
@@ -811,7 +827,7 @@ class PluginManager:
|
||||
if (
|
||||
mp
|
||||
and mp.startswith(plugin_module_path)
|
||||
and not mp.endswith(("packages", "data.plugins"))
|
||||
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||
):
|
||||
to_remove.append(func_tool)
|
||||
for func_tool in to_remove:
|
||||
@@ -866,7 +882,7 @@ class PluginManager:
|
||||
plugin.module_path
|
||||
and mp
|
||||
and plugin.module_path.startswith(mp)
|
||||
and not mp.endswith(("packages", "data.plugins"))
|
||||
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||
):
|
||||
func_tool.active = False
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
@@ -915,7 +931,7 @@ class PluginManager:
|
||||
plugin.module_path
|
||||
and mp
|
||||
and plugin.module_path.startswith(mp)
|
||||
and not mp.endswith(("packages", "data.plugins"))
|
||||
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||
and func_tool.name in inactivated_llm_tools
|
||||
):
|
||||
inactivated_llm_tools.remove(func_tool.name)
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
数据目录路径:固定为根目录下的 data 目录
|
||||
配置文件路径:固定为数据目录下的 config 目录
|
||||
插件目录路径:固定为数据目录下的 plugins 目录
|
||||
插件数据目录路径:固定为数据目录下的 plugin_data 目录
|
||||
T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录
|
||||
WebChat 数据目录路径:固定为数据目录下的 webchat 目录
|
||||
临时文件目录路径:固定为数据目录下的 temp 目录
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -37,3 +41,33 @@ def get_astrbot_config_path() -> str:
|
||||
def get_astrbot_plugin_path() -> str:
|
||||
"""获取Astrbot插件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
||||
|
||||
|
||||
def get_astrbot_plugin_data_path() -> str:
|
||||
"""获取Astrbot插件数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data"))
|
||||
|
||||
|
||||
def get_astrbot_t2i_templates_path() -> str:
|
||||
"""获取Astrbot T2I 模板目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates"))
|
||||
|
||||
|
||||
def get_astrbot_webchat_path() -> str:
|
||||
"""获取Astrbot WebChat 数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat"))
|
||||
|
||||
|
||||
def get_astrbot_temp_path() -> str:
|
||||
"""获取Astrbot临时文件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp"))
|
||||
|
||||
|
||||
def get_astrbot_knowledge_base_path() -> str:
|
||||
"""获取Astrbot知识库根目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base"))
|
||||
|
||||
|
||||
def get_astrbot_backups_path() -> str:
|
||||
"""获取Astrbot备份目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "backups"))
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
import aiohttp
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class LLMModalities(TypedDict):
|
||||
input: list[Literal["text", "image", "audio", "video"]]
|
||||
output: list[Literal["text", "image", "audio", "video"]]
|
||||
|
||||
|
||||
class LLMLimit(TypedDict):
|
||||
context: int
|
||||
output: int
|
||||
|
||||
|
||||
class LLMMetadata(TypedDict):
|
||||
id: str
|
||||
reasoning: bool
|
||||
tool_call: bool
|
||||
knowledge: str
|
||||
release_date: str
|
||||
modalities: LLMModalities
|
||||
open_weights: bool
|
||||
limit: LLMLimit
|
||||
|
||||
|
||||
LLM_METADATAS: dict[str, LLMMetadata] = {}
|
||||
|
||||
|
||||
async def update_llm_metadata():
|
||||
url = "https://models.dev/api.json"
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
data = await response.json()
|
||||
global LLM_METADATAS
|
||||
models = {}
|
||||
for info in data.values():
|
||||
for model in info.get("models", {}).values():
|
||||
model_id = model.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
models[model_id] = LLMMetadata(
|
||||
id=model_id,
|
||||
reasoning=model.get("reasoning", False),
|
||||
tool_call=model.get("tool_call", False),
|
||||
knowledge=model.get("knowledge", "none"),
|
||||
release_date=model.get("release_date", ""),
|
||||
modalities=model.get(
|
||||
"modalities", {"input": [], "output": []}
|
||||
),
|
||||
open_weights=model.get("open_weights", False),
|
||||
limit=model.get("limit", {"context": 0, "output": 0}),
|
||||
)
|
||||
# Replace the global cache in-place so references remain valid
|
||||
LLM_METADATAS.clear()
|
||||
LLM_METADATAS.update(models)
|
||||
logger.info(f"Successfully fetched metadata for {len(models)} LLMs.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch LLM metadata: {e}")
|
||||
return
|
||||
@@ -32,6 +32,92 @@ def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def _migra_provider_to_source_structure(conf: AstrBotConfig) -> None:
|
||||
"""
|
||||
Migrate old provider structure to new provider-source separation.
|
||||
Provider only keeps: id, provider_source_id, model, modalities, custom_extra_body
|
||||
All other fields move to provider_sources.
|
||||
"""
|
||||
providers = conf.get("provider", [])
|
||||
provider_sources = conf.get("provider_sources", [])
|
||||
|
||||
# Track if any migration happened
|
||||
migrated = False
|
||||
|
||||
# Provider-only fields that should stay in provider
|
||||
provider_only_fields = {
|
||||
"id",
|
||||
"provider_source_id",
|
||||
"model",
|
||||
"modalities",
|
||||
"custom_extra_body",
|
||||
"enable",
|
||||
}
|
||||
|
||||
# Fields that should not go to source
|
||||
source_exclude_fields = provider_only_fields | {"model_config"}
|
||||
|
||||
for provider in providers:
|
||||
# Skip if already has provider_source_id
|
||||
if provider.get("provider_source_id"):
|
||||
continue
|
||||
|
||||
# Skip non-chat-completion types (they don't need source separation)
|
||||
provider_type = provider.get("provider_type", "")
|
||||
if provider_type != "chat_completion":
|
||||
# For old types without provider_type, check type field
|
||||
old_type = provider.get("type", "")
|
||||
if "chat_completion" not in old_type:
|
||||
continue
|
||||
|
||||
migrated = True
|
||||
logger.info(f"Migrating provider {provider.get('id')} to new structure")
|
||||
|
||||
# Extract source fields from provider
|
||||
source_fields = {}
|
||||
for key, value in list(provider.items()):
|
||||
if key not in source_exclude_fields:
|
||||
source_fields[key] = value
|
||||
|
||||
# Create new provider_source
|
||||
source_id = provider.get("id", "") + "_source"
|
||||
new_source = {"id": source_id, **source_fields}
|
||||
|
||||
# Update provider to only keep necessary fields
|
||||
provider["provider_source_id"] = source_id
|
||||
|
||||
# Extract model from model_config if exists
|
||||
if "model_config" in provider and isinstance(provider["model_config"], dict):
|
||||
model_config = provider["model_config"]
|
||||
provider["model"] = model_config.get("model", "")
|
||||
|
||||
# Put other model_config fields into custom_extra_body
|
||||
extra_body_fields = {k: v for k, v in model_config.items() if k != "model"}
|
||||
if extra_body_fields:
|
||||
if "custom_extra_body" not in provider:
|
||||
provider["custom_extra_body"] = {}
|
||||
provider["custom_extra_body"].update(extra_body_fields)
|
||||
|
||||
# Initialize new fields if not present
|
||||
if "modalities" not in provider:
|
||||
provider["modalities"] = []
|
||||
if "custom_extra_body" not in provider:
|
||||
provider["custom_extra_body"] = {}
|
||||
|
||||
# Remove fields that should be in source
|
||||
keys_to_remove = [k for k in provider.keys() if k not in provider_only_fields]
|
||||
for key in keys_to_remove:
|
||||
del provider[key]
|
||||
|
||||
# Add source to provider_sources
|
||||
provider_sources.append(new_source)
|
||||
|
||||
if migrated:
|
||||
conf["provider_sources"] = provider_sources
|
||||
conf.save_config()
|
||||
logger.info("Provider-source structure migration completed")
|
||||
|
||||
|
||||
async def migra(
|
||||
db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager
|
||||
) -> None:
|
||||
@@ -71,3 +157,10 @@ async def migra(
|
||||
|
||||
for conf in acm.confs.values():
|
||||
_migra_agent_runner_configs(conf, ids_map)
|
||||
|
||||
# Migrate providers to new structure: extract source fields to provider_sources
|
||||
try:
|
||||
_migra_provider_to_source_structure(astrbot_config)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for provider-source structure failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
from typing import TypeVar
|
||||
|
||||
from astrbot.core import sp
|
||||
|
||||
SUPPORTED_VALUE_TYPES = int | float | str | bytes | bool | dict | list | None
|
||||
_VT = TypeVar("_VT")
|
||||
|
||||
|
||||
class PluginKVStoreMixin:
|
||||
"""为插件提供键值存储功能的 Mixin 类"""
|
||||
|
||||
plugin_id: str
|
||||
|
||||
async def put_kv_data(
|
||||
self,
|
||||
key: str,
|
||||
value: SUPPORTED_VALUE_TYPES,
|
||||
) -> None:
|
||||
"""为指定插件存储一个键值对"""
|
||||
await sp.put_async("plugin", self.plugin_id, key, value)
|
||||
|
||||
async def get_kv_data(self, key: str, default: _VT) -> _VT | None:
|
||||
"""获取指定插件存储的键值对"""
|
||||
return await sp.get_async("plugin", self.plugin_id, key, default)
|
||||
|
||||
async def delete_kv_data(self, key: str) -> None:
|
||||
"""删除指定插件存储的键值对"""
|
||||
await sp.remove_async("plugin", self.plugin_id, key)
|
||||
@@ -1,5 +1,7 @@
|
||||
from .auth import AuthRoute
|
||||
from .backup import BackupRoute
|
||||
from .chat import ChatRoute
|
||||
from .command import CommandRoute
|
||||
from .config import ConfigRoute
|
||||
from .conversation import ConversationRoute
|
||||
from .file import FileRoute
|
||||
@@ -16,7 +18,9 @@ from .update import UpdateRoute
|
||||
|
||||
__all__ = [
|
||||
"AuthRoute",
|
||||
"BackupRoute",
|
||||
"ChatRoute",
|
||||
"CommandRoute",
|
||||
"ConfigRoute",
|
||||
"ConversationRoute",
|
||||
"FileRoute",
|
||||
|
||||
@@ -0,0 +1,589 @@
|
||||
"""备份管理 API 路由"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from quart import request, send_file
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.backup.exporter import AstrBotExporter
|
||||
from astrbot.core.backup.importer import AstrBotImporter
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_backups_path,
|
||||
get_astrbot_data_path,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
def secure_filename(filename: str) -> str:
|
||||
"""清洗文件名,移除路径遍历字符和危险字符
|
||||
|
||||
Args:
|
||||
filename: 原始文件名
|
||||
|
||||
Returns:
|
||||
安全的文件名
|
||||
"""
|
||||
# 跨平台处理:先将反斜杠替换为正斜杠,再取文件名
|
||||
filename = filename.replace("\\", "/")
|
||||
# 仅保留文件名部分,移除路径
|
||||
filename = os.path.basename(filename)
|
||||
|
||||
# 替换路径遍历字符
|
||||
filename = filename.replace("..", "_")
|
||||
|
||||
# 仅保留字母、数字、下划线、连字符、点
|
||||
filename = re.sub(r"[^\w\-.]", "_", filename)
|
||||
|
||||
# 移除前导点(隐藏文件)和尾部点
|
||||
filename = filename.strip(".")
|
||||
|
||||
# 如果文件名为空或只包含下划线,生成一个默认名称
|
||||
if not filename or filename.replace("_", "") == "":
|
||||
filename = "backup"
|
||||
|
||||
return filename
|
||||
|
||||
|
||||
def generate_unique_filename(original_filename: str) -> str:
|
||||
"""生成唯一的文件名,添加时间戳前缀
|
||||
|
||||
Args:
|
||||
original_filename: 原始文件名(已清洗)
|
||||
|
||||
Returns:
|
||||
唯一的文件名
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
name, ext = os.path.splitext(original_filename)
|
||||
return f"uploaded_{timestamp}_{name}{ext}"
|
||||
|
||||
|
||||
class BackupRoute(Route):
|
||||
"""备份管理路由
|
||||
|
||||
提供备份导出、导入、列表等 API 接口
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
db: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.db = db
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.backup_dir = get_astrbot_backups_path()
|
||||
self.data_dir = get_astrbot_data_path()
|
||||
|
||||
# 任务状态跟踪
|
||||
self.backup_tasks: dict[str, dict] = {}
|
||||
self.backup_progress: dict[str, dict] = {}
|
||||
|
||||
# 注册路由
|
||||
self.routes = {
|
||||
"/backup/list": ("GET", self.list_backups),
|
||||
"/backup/export": ("POST", self.export_backup),
|
||||
"/backup/upload": ("POST", self.upload_backup), # 上传文件
|
||||
"/backup/check": ("POST", self.check_backup), # 预检查
|
||||
"/backup/import": ("POST", self.import_backup), # 确认导入
|
||||
"/backup/progress": ("GET", self.get_progress),
|
||||
"/backup/download": ("GET", self.download_backup),
|
||||
"/backup/delete": ("POST", self.delete_backup),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None:
|
||||
"""初始化任务状态"""
|
||||
self.backup_tasks[task_id] = {
|
||||
"type": task_type,
|
||||
"status": status,
|
||||
"result": None,
|
||||
"error": None,
|
||||
}
|
||||
self.backup_progress[task_id] = {
|
||||
"status": status,
|
||||
"stage": "waiting",
|
||||
"current": 0,
|
||||
"total": 100,
|
||||
"message": "",
|
||||
}
|
||||
|
||||
def _set_task_result(
|
||||
self,
|
||||
task_id: str,
|
||||
status: str,
|
||||
result: dict | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""设置任务结果"""
|
||||
if task_id in self.backup_tasks:
|
||||
self.backup_tasks[task_id]["status"] = status
|
||||
self.backup_tasks[task_id]["result"] = result
|
||||
self.backup_tasks[task_id]["error"] = error
|
||||
if task_id in self.backup_progress:
|
||||
self.backup_progress[task_id]["status"] = status
|
||||
|
||||
def _update_progress(
|
||||
self,
|
||||
task_id: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
stage: str | None = None,
|
||||
current: int | None = None,
|
||||
total: int | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""更新任务进度"""
|
||||
if task_id not in self.backup_progress:
|
||||
return
|
||||
p = self.backup_progress[task_id]
|
||||
if status is not None:
|
||||
p["status"] = status
|
||||
if stage is not None:
|
||||
p["stage"] = stage
|
||||
if current is not None:
|
||||
p["current"] = current
|
||||
if total is not None:
|
||||
p["total"] = total
|
||||
if message is not None:
|
||||
p["message"] = message
|
||||
|
||||
def _make_progress_callback(self, task_id: str):
|
||||
"""创建进度回调函数"""
|
||||
|
||||
async def _callback(stage: str, current: int, total: int, message: str = ""):
|
||||
self._update_progress(
|
||||
task_id,
|
||||
status="processing",
|
||||
stage=stage,
|
||||
current=current,
|
||||
total=total,
|
||||
message=message,
|
||||
)
|
||||
|
||||
return _callback
|
||||
|
||||
async def list_backups(self):
|
||||
"""获取备份列表
|
||||
|
||||
Query 参数:
|
||||
- page: 页码 (默认 1)
|
||||
- page_size: 每页数量 (默认 20)
|
||||
"""
|
||||
try:
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 20, type=int)
|
||||
|
||||
# 确保备份目录存在
|
||||
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 获取所有备份文件
|
||||
backup_files = []
|
||||
for filename in os.listdir(self.backup_dir):
|
||||
if filename.endswith(".zip") and filename.startswith("astrbot_backup_"):
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
stat = os.stat(file_path)
|
||||
backup_files.append(
|
||||
{
|
||||
"filename": filename,
|
||||
"size": stat.st_size,
|
||||
"created_at": stat.st_mtime,
|
||||
}
|
||||
)
|
||||
|
||||
# 按创建时间倒序排序
|
||||
backup_files.sort(key=lambda x: x["created_at"], reverse=True)
|
||||
|
||||
# 分页
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
items = backup_files[start:end]
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"items": items,
|
||||
"total": len(backup_files),
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取备份列表失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取备份列表失败: {e!s}").__dict__
|
||||
|
||||
async def export_backup(self):
|
||||
"""创建备份
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID,用于查询导出进度
|
||||
"""
|
||||
try:
|
||||
# 生成任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 初始化任务状态
|
||||
self._init_task(task_id, "export", "pending")
|
||||
|
||||
# 启动后台导出任务
|
||||
asyncio.create_task(self._background_export_task(task_id))
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"message": "export task created, processing in background",
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"创建备份失败: {e!s}").__dict__
|
||||
|
||||
async def _background_export_task(self, task_id: str):
|
||||
"""后台导出任务"""
|
||||
try:
|
||||
self._update_progress(task_id, status="processing", message="正在初始化...")
|
||||
|
||||
# 获取知识库管理器
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
|
||||
exporter = AstrBotExporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
|
||||
# 创建进度回调
|
||||
progress_callback = self._make_progress_callback(task_id)
|
||||
|
||||
# 执行导出
|
||||
zip_path = await exporter.export_all(
|
||||
output_dir=self.backup_dir,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# 设置成功结果
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"completed",
|
||||
result={
|
||||
"filename": os.path.basename(zip_path),
|
||||
"path": zip_path,
|
||||
"size": os.path.getsize(zip_path),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"后台导出任务 {task_id} 失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
self._set_task_result(task_id, "failed", error=str(e))
|
||||
|
||||
async def upload_backup(self):
|
||||
"""上传备份文件
|
||||
|
||||
将备份文件上传到服务器,返回保存的文件名。
|
||||
上传后应调用 check_backup 进行预检查。
|
||||
|
||||
Form Data:
|
||||
- file: 备份文件 (.zip)
|
||||
|
||||
返回:
|
||||
- filename: 保存的文件名
|
||||
"""
|
||||
try:
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return Response().error("缺少备份文件").__dict__
|
||||
|
||||
file = files["file"]
|
||||
if not file.filename or not file.filename.endswith(".zip"):
|
||||
return Response().error("请上传 ZIP 格式的备份文件").__dict__
|
||||
|
||||
# 清洗文件名并生成唯一名称,防止路径遍历和覆盖
|
||||
safe_filename = secure_filename(file.filename)
|
||||
unique_filename = generate_unique_filename(safe_filename)
|
||||
|
||||
# 保存上传的文件
|
||||
Path(self.backup_dir).mkdir(parents=True, exist_ok=True)
|
||||
zip_path = os.path.join(self.backup_dir, unique_filename)
|
||||
await file.save(zip_path)
|
||||
|
||||
logger.info(
|
||||
f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})"
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"filename": unique_filename,
|
||||
"original_filename": file.filename,
|
||||
"size": os.path.getsize(zip_path),
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"上传备份文件失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"上传备份文件失败: {e!s}").__dict__
|
||||
|
||||
async def check_backup(self):
|
||||
"""预检查备份文件
|
||||
|
||||
检查备份文件的版本兼容性,返回确认信息。
|
||||
用户确认后调用 import_backup 执行导入。
|
||||
|
||||
JSON Body:
|
||||
- filename: 已上传的备份文件名
|
||||
|
||||
返回:
|
||||
- ImportPreCheckResult: 预检查结果
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
if not filename:
|
||||
return Response().error("缺少 filename 参数").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
zip_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(zip_path):
|
||||
return Response().error(f"备份文件不存在: {filename}").__dict__
|
||||
|
||||
# 获取知识库管理器(用于构造 importer)
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
|
||||
importer = AstrBotImporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
|
||||
# 执行预检查
|
||||
check_result = importer.pre_check(zip_path)
|
||||
|
||||
return Response().ok(check_result.to_dict()).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"预检查备份文件失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"预检查备份文件失败: {e!s}").__dict__
|
||||
|
||||
async def import_backup(self):
|
||||
"""执行备份导入
|
||||
|
||||
在用户确认后执行实际的导入操作。
|
||||
需要先调用 upload_backup 上传文件,再调用 check_backup 预检查。
|
||||
|
||||
JSON Body:
|
||||
- filename: 已上传的备份文件名(必填)
|
||||
- confirmed: 用户已确认(必填,必须为 true)
|
||||
|
||||
返回:
|
||||
- task_id: 任务ID,用于查询导入进度
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
confirmed = data.get("confirmed", False)
|
||||
|
||||
if not filename:
|
||||
return Response().error("缺少 filename 参数").__dict__
|
||||
|
||||
if not confirmed:
|
||||
return (
|
||||
Response()
|
||||
.error("请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
zip_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(zip_path):
|
||||
return Response().error(f"备份文件不存在: {filename}").__dict__
|
||||
|
||||
# 生成任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
# 初始化任务状态
|
||||
self._init_task(task_id, "import", "pending")
|
||||
|
||||
# 启动后台导入任务
|
||||
asyncio.create_task(self._background_import_task(task_id, zip_path))
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"task_id": task_id,
|
||||
"message": "import task created, processing in background",
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"导入备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"导入备份失败: {e!s}").__dict__
|
||||
|
||||
async def _background_import_task(self, task_id: str, zip_path: str):
|
||||
"""后台导入任务"""
|
||||
try:
|
||||
self._update_progress(task_id, status="processing", message="正在初始化...")
|
||||
|
||||
# 获取知识库管理器
|
||||
kb_manager = getattr(self.core_lifecycle, "kb_manager", None)
|
||||
|
||||
importer = AstrBotImporter(
|
||||
main_db=self.db,
|
||||
kb_manager=kb_manager,
|
||||
config_path=os.path.join(self.data_dir, "cmd_config.json"),
|
||||
)
|
||||
|
||||
# 创建进度回调
|
||||
progress_callback = self._make_progress_callback(task_id)
|
||||
|
||||
# 执行导入
|
||||
result = await importer.import_all(
|
||||
zip_path=zip_path,
|
||||
mode="replace",
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# 设置结果
|
||||
if result.success:
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"completed",
|
||||
result=result.to_dict(),
|
||||
)
|
||||
else:
|
||||
self._set_task_result(
|
||||
task_id,
|
||||
"failed",
|
||||
error="; ".join(result.errors),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"后台导入任务 {task_id} 失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
self._set_task_result(task_id, "failed", error=str(e))
|
||||
|
||||
async def get_progress(self):
|
||||
"""获取任务进度
|
||||
|
||||
Query 参数:
|
||||
- task_id: 任务 ID (必填)
|
||||
"""
|
||||
try:
|
||||
task_id = request.args.get("task_id")
|
||||
if not task_id:
|
||||
return Response().error("缺少参数 task_id").__dict__
|
||||
|
||||
if task_id not in self.backup_tasks:
|
||||
return Response().error("找不到该任务").__dict__
|
||||
|
||||
task_info = self.backup_tasks[task_id]
|
||||
status = task_info["status"]
|
||||
|
||||
response_data = {
|
||||
"task_id": task_id,
|
||||
"type": task_info["type"],
|
||||
"status": status,
|
||||
}
|
||||
|
||||
# 如果任务正在处理,返回进度信息
|
||||
if status == "processing" and task_id in self.backup_progress:
|
||||
response_data["progress"] = self.backup_progress[task_id]
|
||||
|
||||
# 如果任务完成,返回结果
|
||||
if status == "completed":
|
||||
response_data["result"] = task_info["result"]
|
||||
|
||||
# 如果任务失败,返回错误信息
|
||||
if status == "failed":
|
||||
response_data["error"] = task_info["error"]
|
||||
|
||||
return Response().ok(response_data).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"获取任务进度失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取任务进度失败: {e!s}").__dict__
|
||||
|
||||
async def download_backup(self):
|
||||
"""下载备份文件
|
||||
|
||||
Query 参数:
|
||||
- filename: 备份文件名 (必填)
|
||||
"""
|
||||
try:
|
||||
filename = request.args.get("filename")
|
||||
if not filename:
|
||||
return Response().error("缺少参数 filename").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(file_path):
|
||||
return Response().error("备份文件不存在").__dict__
|
||||
|
||||
return await send_file(
|
||||
file_path,
|
||||
as_attachment=True,
|
||||
attachment_filename=filename,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"下载备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"下载备份失败: {e!s}").__dict__
|
||||
|
||||
async def delete_backup(self):
|
||||
"""删除备份文件
|
||||
|
||||
Body:
|
||||
- filename: 备份文件名 (必填)
|
||||
"""
|
||||
try:
|
||||
data = await request.json
|
||||
filename = data.get("filename")
|
||||
if not filename:
|
||||
return Response().error("缺少参数 filename").__dict__
|
||||
|
||||
# 安全检查 - 防止路径遍历
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
return Response().error("无效的文件名").__dict__
|
||||
|
||||
file_path = os.path.join(self.backup_dir, filename)
|
||||
if not os.path.exists(file_path):
|
||||
return Response().error("备份文件不存在").__dict__
|
||||
|
||||
os.remove(file_path)
|
||||
return Response().ok(message="删除备份成功").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"删除备份失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除备份失败: {e!s}").__dict__
|
||||
@@ -227,16 +227,19 @@ class ChatRoute(Route):
|
||||
text: str,
|
||||
media_parts: list,
|
||||
reasoning: str,
|
||||
agent_stats: dict,
|
||||
):
|
||||
"""保存 bot 消息到历史记录,返回保存的记录"""
|
||||
bot_message_parts = []
|
||||
bot_message_parts.extend(media_parts)
|
||||
if text:
|
||||
bot_message_parts.append({"type": "plain", "text": text})
|
||||
bot_message_parts.extend(media_parts)
|
||||
|
||||
new_his = {"type": "bot", "message": bot_message_parts}
|
||||
if reasoning:
|
||||
new_his["reasoning"] = reasoning
|
||||
if agent_stats:
|
||||
new_his["agent_stats"] = agent_stats
|
||||
|
||||
record = await self.platform_history_mgr.insert(
|
||||
platform_id="webchat",
|
||||
@@ -294,7 +297,8 @@ class ChatRoute(Route):
|
||||
accumulated_parts = []
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
|
||||
tool_calls = {}
|
||||
agent_stats = {}
|
||||
try:
|
||||
async with track_conversation(self.running_convs, webchat_conv_id):
|
||||
while True:
|
||||
@@ -314,6 +318,16 @@ class ChatRoute(Route):
|
||||
result_text = result["data"]
|
||||
msg_type = result.get("type")
|
||||
streaming = result.get("streaming", False)
|
||||
chain_type = result.get("chain_type")
|
||||
|
||||
if chain_type == "agent_stats":
|
||||
stats_info = {
|
||||
"type": "agent_stats",
|
||||
"data": json.loads(result_text),
|
||||
}
|
||||
yield f"data: {json.dumps(stats_info, ensure_ascii=False)}\n\n"
|
||||
agent_stats = stats_info["data"]
|
||||
continue
|
||||
|
||||
# 发送 SSE 数据
|
||||
try:
|
||||
@@ -335,11 +349,35 @@ class ChatRoute(Route):
|
||||
|
||||
# 累积消息部分
|
||||
if msg_type == "plain":
|
||||
chain_type = result.get("chain_type", "normal")
|
||||
if chain_type == "reasoning":
|
||||
chain_type = result.get("chain_type")
|
||||
if chain_type == "tool_call":
|
||||
tool_call = json.loads(result_text)
|
||||
tool_calls[tool_call.get("id")] = tool_call
|
||||
if accumulated_text:
|
||||
# 如果累积了文本,则先保存文本
|
||||
accumulated_parts.append(
|
||||
{"type": "plain", "text": accumulated_text}
|
||||
)
|
||||
accumulated_text = ""
|
||||
elif chain_type == "tool_call_result":
|
||||
tcr = json.loads(result_text)
|
||||
tc_id = tcr.get("id")
|
||||
if tc_id in tool_calls:
|
||||
tool_calls[tc_id]["result"] = tcr.get("result")
|
||||
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
|
||||
accumulated_parts.append(
|
||||
{
|
||||
"type": "tool_call",
|
||||
"tool_calls": [tool_calls[tc_id]],
|
||||
}
|
||||
)
|
||||
tool_calls.pop(tc_id, None)
|
||||
elif chain_type == "reasoning":
|
||||
accumulated_reasoning += result_text
|
||||
else:
|
||||
elif streaming:
|
||||
accumulated_text += result_text
|
||||
else:
|
||||
accumulated_text = result_text
|
||||
elif msg_type == "image":
|
||||
filename = result_text.replace("[IMAGE]", "")
|
||||
part = await self._create_attachment_from_file(
|
||||
@@ -367,15 +405,20 @@ class ChatRoute(Route):
|
||||
if msg_type == "end":
|
||||
break
|
||||
elif (
|
||||
(streaming and msg_type == "complete")
|
||||
or not streaming
|
||||
or msg_type == "break"
|
||||
(streaming and msg_type == "complete") or not streaming
|
||||
# or msg_type == "break"
|
||||
):
|
||||
if (
|
||||
chain_type == "tool_call"
|
||||
or chain_type == "tool_call_result"
|
||||
):
|
||||
continue
|
||||
saved_record = await self._save_bot_message(
|
||||
webchat_conv_id,
|
||||
accumulated_text,
|
||||
accumulated_parts,
|
||||
accumulated_reasoning,
|
||||
agent_stats,
|
||||
)
|
||||
# 发送保存的消息信息给前端
|
||||
if saved_record and not client_disconnected:
|
||||
@@ -390,11 +433,11 @@ class ChatRoute(Route):
|
||||
yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n"
|
||||
except Exception:
|
||||
pass
|
||||
# 重置累积变量 (对于 break 后的下一段消息)
|
||||
if msg_type == "break":
|
||||
accumulated_parts = []
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
accumulated_parts = []
|
||||
accumulated_text = ""
|
||||
accumulated_reasoning = ""
|
||||
# tool_calls = {}
|
||||
agent_stats = {}
|
||||
except BaseException as e:
|
||||
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
from quart import request
|
||||
|
||||
from astrbot.core.star.command_management import (
|
||||
list_command_conflicts,
|
||||
list_commands,
|
||||
)
|
||||
from astrbot.core.star.command_management import (
|
||||
rename_command as rename_command_service,
|
||||
)
|
||||
from astrbot.core.star.command_management import (
|
||||
toggle_command as toggle_command_service,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
class CommandRoute(Route):
|
||||
def __init__(self, context: RouteContext) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
"/commands": ("GET", self.get_commands),
|
||||
"/commands/conflicts": ("GET", self.get_conflicts),
|
||||
"/commands/toggle": ("POST", self.toggle_command),
|
||||
"/commands/rename": ("POST", self.rename_command),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
async def get_commands(self):
|
||||
commands = await list_commands()
|
||||
summary = {
|
||||
"total": len(commands),
|
||||
"disabled": len([cmd for cmd in commands if not cmd["enabled"]]),
|
||||
"conflicts": len([cmd for cmd in commands if cmd.get("has_conflict")]),
|
||||
}
|
||||
return Response().ok({"items": commands, "summary": summary}).__dict__
|
||||
|
||||
async def get_conflicts(self):
|
||||
conflicts = await list_command_conflicts()
|
||||
return Response().ok(conflicts).__dict__
|
||||
|
||||
async def toggle_command(self):
|
||||
data = await request.get_json()
|
||||
handler_full_name = data.get("handler_full_name")
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if handler_full_name is None or enabled is None:
|
||||
return Response().error("handler_full_name 与 enabled 均为必填。").__dict__
|
||||
|
||||
if isinstance(enabled, str):
|
||||
enabled = enabled.lower() in ("1", "true", "yes", "on")
|
||||
|
||||
try:
|
||||
await toggle_command_service(handler_full_name, bool(enabled))
|
||||
except ValueError as exc:
|
||||
return Response().error(str(exc)).__dict__
|
||||
|
||||
payload = await _get_command_payload(handler_full_name)
|
||||
return Response().ok(payload).__dict__
|
||||
|
||||
async def rename_command(self):
|
||||
data = await request.get_json()
|
||||
handler_full_name = data.get("handler_full_name")
|
||||
new_name = data.get("new_name")
|
||||
aliases = data.get("aliases")
|
||||
|
||||
if not handler_full_name or not new_name:
|
||||
return Response().error("handler_full_name 与 new_name 均为必填。").__dict__
|
||||
|
||||
try:
|
||||
await rename_command_service(handler_full_name, new_name, aliases=aliases)
|
||||
except ValueError as exc:
|
||||
return Response().error(str(exc)).__dict__
|
||||
|
||||
payload = await _get_command_payload(handler_full_name)
|
||||
return Response().ok(payload).__dict__
|
||||
|
||||
|
||||
async def _get_command_payload(handler_full_name: str):
|
||||
commands = await list_commands()
|
||||
for cmd in commands:
|
||||
if cmd["handler_full_name"] == handler_full_name:
|
||||
return cmd
|
||||
return {}
|
||||
@@ -6,7 +6,7 @@ from typing import Any
|
||||
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import file_token_service, logger
|
||||
from astrbot.core import astrbot_config, file_token_service, logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.config.default import (
|
||||
CONFIG_METADATA_2,
|
||||
@@ -21,6 +21,7 @@ from astrbot.core.platform.register import platform_cls_map, platform_registry
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.utils.llm_metadata import LLM_METADATAS
|
||||
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
@@ -179,13 +180,157 @@ class ConfigRoute(Route):
|
||||
"/config/provider/new": ("POST", self.post_new_provider),
|
||||
"/config/provider/update": ("POST", self.post_update_provider),
|
||||
"/config/provider/delete": ("POST", self.post_delete_provider),
|
||||
"/config/provider/template": ("GET", self.get_provider_template),
|
||||
"/config/provider/check_one": ("GET", self.check_one_provider_status),
|
||||
"/config/provider/list": ("GET", self.get_provider_config_list),
|
||||
"/config/provider/model_list": ("GET", self.get_provider_model_list),
|
||||
"/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim),
|
||||
"/config/provider_sources/models": (
|
||||
"GET",
|
||||
self.get_provider_source_models,
|
||||
),
|
||||
"/config/provider_sources/update": (
|
||||
"POST",
|
||||
self.update_provider_source,
|
||||
),
|
||||
"/config/provider_sources/delete": (
|
||||
"POST",
|
||||
self.delete_provider_source,
|
||||
),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
async def delete_provider_source(self):
|
||||
"""删除 provider_source,并更新关联的 providers"""
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
|
||||
provider_source_id = post_data.get("id")
|
||||
if not provider_source_id:
|
||||
return Response().error("缺少 provider_source_id").__dict__
|
||||
|
||||
provider_sources = self.config.get("provider_sources", [])
|
||||
target_idx = next(
|
||||
(
|
||||
i
|
||||
for i, ps in enumerate(provider_sources)
|
||||
if ps.get("id") == provider_source_id
|
||||
),
|
||||
-1,
|
||||
)
|
||||
|
||||
if target_idx == -1:
|
||||
return Response().error("未找到对应的 provider source").__dict__
|
||||
|
||||
# 删除 provider_source
|
||||
del provider_sources[target_idx]
|
||||
|
||||
# 写回配置
|
||||
self.config["provider_sources"] = provider_sources
|
||||
|
||||
# 删除引用了该 provider_source 的 providers
|
||||
await self.core_lifecycle.provider_manager.delete_provider(
|
||||
provider_source_id=provider_source_id
|
||||
)
|
||||
|
||||
try:
|
||||
save_config(self.config, self.config, is_core=True)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
return Response().ok(message="删除 provider source 成功").__dict__
|
||||
|
||||
async def update_provider_source(self):
|
||||
"""更新或新增 provider_source,并重载关联的 providers"""
|
||||
post_data = await request.json
|
||||
if not post_data:
|
||||
return Response().error("缺少配置数据").__dict__
|
||||
|
||||
new_source_config = post_data.get("config") or post_data
|
||||
original_id = post_data.get("original_id")
|
||||
if not original_id:
|
||||
return Response().error("缺少 original_id").__dict__
|
||||
|
||||
if not isinstance(new_source_config, dict):
|
||||
return Response().error("缺少或错误的配置数据").__dict__
|
||||
|
||||
# 确保配置中有 id 字段
|
||||
if not new_source_config.get("id"):
|
||||
new_source_config["id"] = original_id
|
||||
|
||||
provider_sources = self.config.get("provider_sources", [])
|
||||
|
||||
for ps in provider_sources:
|
||||
if ps.get("id") == new_source_config["id"] and ps.get("id") != original_id:
|
||||
return (
|
||||
Response()
|
||||
.error(
|
||||
f"Provider source ID '{new_source_config['id']}' exists already, please try another ID.",
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 查找旧的 provider_source,若不存在则追加为新配置
|
||||
target_idx = next(
|
||||
(i for i, ps in enumerate(provider_sources) if ps.get("id") == original_id),
|
||||
-1,
|
||||
)
|
||||
|
||||
old_id = original_id
|
||||
if target_idx == -1:
|
||||
provider_sources.append(new_source_config)
|
||||
else:
|
||||
old_id = provider_sources[target_idx].get("id")
|
||||
provider_sources[target_idx] = new_source_config
|
||||
|
||||
# 更新引用了该 provider_source 的 providers
|
||||
affected_providers = []
|
||||
for provider in self.config.get("provider", []):
|
||||
if provider.get("provider_source_id") == old_id:
|
||||
provider["provider_source_id"] = new_source_config["id"]
|
||||
affected_providers.append(provider)
|
||||
|
||||
# 写回配置
|
||||
self.config["provider_sources"] = provider_sources
|
||||
|
||||
try:
|
||||
save_config(self.config, self.config, is_core=True)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
# 重载受影响的 providers,使新的 source 配置生效
|
||||
reload_errors = []
|
||||
prov_mgr = self.core_lifecycle.provider_manager
|
||||
for provider in affected_providers:
|
||||
try:
|
||||
await prov_mgr.reload(provider)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
reload_errors.append(f"{provider.get('id')}: {e}")
|
||||
|
||||
if reload_errors:
|
||||
return (
|
||||
Response()
|
||||
.error("更新成功,但部分提供商重载失败: " + ", ".join(reload_errors))
|
||||
.__dict__
|
||||
)
|
||||
|
||||
return Response().ok(message="更新 provider source 成功").__dict__
|
||||
|
||||
async def get_provider_template(self):
|
||||
config_schema = {
|
||||
"provider": CONFIG_METADATA_2["provider_group"]["metadata"]["provider"]
|
||||
}
|
||||
data = {
|
||||
"config_schema": config_schema,
|
||||
"providers": astrbot_config["provider"],
|
||||
"provider_sources": astrbot_config["provider_sources"],
|
||||
}
|
||||
return Response().ok(data=data).__dict__
|
||||
|
||||
async def get_uc_table(self):
|
||||
"""获取 UMOP 配置路由表"""
|
||||
return Response().ok({"routing": self.ucr.umop_to_conf_id}).__dict__
|
||||
@@ -433,9 +578,25 @@ class ConfigRoute(Route):
|
||||
return Response().error("缺少参数 provider_type").__dict__
|
||||
provider_type_ls = provider_type.split(",")
|
||||
provider_list = []
|
||||
astrbot_config = self.core_lifecycle.astrbot_config
|
||||
for provider in astrbot_config["provider"]:
|
||||
if provider.get("provider_type", None) in provider_type_ls:
|
||||
ps = self.core_lifecycle.provider_manager.providers_config
|
||||
p_source_pt = {
|
||||
psrc["id"]: psrc["provider_type"]
|
||||
for psrc in self.core_lifecycle.provider_manager.provider_sources_config
|
||||
}
|
||||
for provider in ps:
|
||||
ps_id = provider.get("provider_source_id", None)
|
||||
if (
|
||||
ps_id
|
||||
and ps_id in p_source_pt
|
||||
and p_source_pt[ps_id] in provider_type_ls
|
||||
):
|
||||
# chat
|
||||
prov = self.core_lifecycle.provider_manager.get_merged_provider_config(
|
||||
provider
|
||||
)
|
||||
provider_list.append(prov)
|
||||
elif not ps_id and provider.get("provider_type", None) in provider_type_ls:
|
||||
# agent runner, embedding, etc
|
||||
provider_list.append(provider)
|
||||
return Response().ok(provider_list).__dict__
|
||||
|
||||
@@ -458,9 +619,18 @@ class ConfigRoute(Route):
|
||||
|
||||
try:
|
||||
models = await provider.get_models()
|
||||
models = models or []
|
||||
|
||||
metadata_map = {}
|
||||
for model_id in models:
|
||||
meta = LLM_METADATAS.get(model_id)
|
||||
if meta:
|
||||
metadata_map[model_id] = meta
|
||||
|
||||
ret = {
|
||||
"models": models,
|
||||
"provider_id": provider_id,
|
||||
"model_metadata": metadata_map,
|
||||
}
|
||||
return Response().ok(ret).__dict__
|
||||
except Exception as e:
|
||||
@@ -522,6 +692,104 @@ class ConfigRoute(Route):
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取嵌入维度失败: {e!s}").__dict__
|
||||
|
||||
async def get_provider_source_models(self):
|
||||
"""获取指定 provider_source 支持的模型列表
|
||||
|
||||
本质上会临时初始化一个 Provider 实例,调用 get_models() 获取模型列表,然后销毁实例
|
||||
"""
|
||||
provider_source_id = request.args.get("source_id")
|
||||
if not provider_source_id:
|
||||
return Response().error("缺少参数 source_id").__dict__
|
||||
|
||||
try:
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
|
||||
# 从配置中查找对应的 provider_source
|
||||
provider_sources = self.config.get("provider_sources", [])
|
||||
provider_source = None
|
||||
for ps in provider_sources:
|
||||
if ps.get("id") == provider_source_id:
|
||||
provider_source = ps
|
||||
break
|
||||
|
||||
if not provider_source:
|
||||
return (
|
||||
Response()
|
||||
.error(f"未找到 ID 为 {provider_source_id} 的 provider_source")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 获取 provider 类型
|
||||
provider_type = provider_source.get("type", None)
|
||||
if not provider_type:
|
||||
return Response().error("provider_source 缺少 type 字段").__dict__
|
||||
|
||||
try:
|
||||
self.core_lifecycle.provider_manager.dynamic_import_provider(
|
||||
provider_type
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"动态导入提供商适配器失败: {e!s}").__dict__
|
||||
|
||||
# 获取对应的 provider 类
|
||||
if provider_type not in provider_cls_map:
|
||||
return (
|
||||
Response()
|
||||
.error(f"未找到适用于 {provider_type} 的提供商适配器")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
provider_metadata = provider_cls_map[provider_type]
|
||||
cls_type = provider_metadata.cls_type
|
||||
|
||||
if not cls_type:
|
||||
return Response().error(f"无法找到 {provider_type} 的类").__dict__
|
||||
|
||||
# 检查是否是 Provider 类型
|
||||
if not issubclass(cls_type, Provider):
|
||||
return (
|
||||
Response()
|
||||
.error(f"提供商 {provider_type} 不支持获取模型列表")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
# 临时实例化 provider
|
||||
inst = cls_type(provider_source, {})
|
||||
|
||||
# 如果有 initialize 方法,调用它
|
||||
init_fn = getattr(inst, "initialize", None)
|
||||
if inspect.iscoroutinefunction(init_fn):
|
||||
await init_fn()
|
||||
|
||||
# 获取模型列表
|
||||
models = await inst.get_models()
|
||||
models = models or []
|
||||
|
||||
metadata_map = {}
|
||||
for model_id in models:
|
||||
meta = LLM_METADATAS.get(model_id)
|
||||
if meta:
|
||||
metadata_map[model_id] = meta
|
||||
|
||||
# 销毁实例(如果有 terminate 方法)
|
||||
terminate_fn = getattr(inst, "terminate", None)
|
||||
if inspect.iscoroutinefunction(terminate_fn):
|
||||
await terminate_fn()
|
||||
|
||||
logger.info(
|
||||
f"获取到 provider_source {provider_source_id} 的模型列表: {models}",
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"models": models, "model_metadata": metadata_map})
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取模型列表失败: {e!s}").__dict__
|
||||
|
||||
async def get_platform_list(self):
|
||||
"""获取所有平台的列表"""
|
||||
platform_list = []
|
||||
@@ -533,7 +801,15 @@ class ConfigRoute(Route):
|
||||
data = await request.json
|
||||
config = data.get("config", None)
|
||||
conf_id = data.get("conf_id", None)
|
||||
|
||||
try:
|
||||
# 不更新 provider_sources, provider, platform
|
||||
# 这些配置有单独的接口进行更新
|
||||
if conf_id == "default":
|
||||
no_update_keys = ["provider_sources", "provider", "platform"]
|
||||
for key in no_update_keys:
|
||||
config[key] = self.acm.default_conf[key]
|
||||
|
||||
await self._save_astrbot_configs(config, conf_id)
|
||||
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
|
||||
return Response().ok(None, "保存成功~").__dict__
|
||||
@@ -573,28 +849,30 @@ class ConfigRoute(Route):
|
||||
|
||||
async def post_new_provider(self):
|
||||
new_provider_config = await request.json
|
||||
self.config["provider"].append(new_provider_config)
|
||||
|
||||
try:
|
||||
save_config(self.config, self.config, is_core=True)
|
||||
await self.core_lifecycle.provider_manager.load_provider(
|
||||
new_provider_config,
|
||||
await self.core_lifecycle.provider_manager.create_provider(
|
||||
new_provider_config
|
||||
)
|
||||
except Exception as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
return Response().ok(None, "新增服务提供商配置成功~").__dict__
|
||||
return Response().ok(None, "新增服务提供商配置成功").__dict__
|
||||
|
||||
async def post_update_platform(self):
|
||||
update_platform_config = await request.json
|
||||
platform_id = update_platform_config.get("id", None)
|
||||
origin_platform_id = update_platform_config.get("id", None)
|
||||
new_config = update_platform_config.get("config", None)
|
||||
if not platform_id or not new_config:
|
||||
if not origin_platform_id or not new_config:
|
||||
return Response().error("参数错误").__dict__
|
||||
|
||||
if origin_platform_id != new_config.get("id", None):
|
||||
return Response().error("机器人名称不允许修改").__dict__
|
||||
|
||||
# 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid
|
||||
ensure_platform_webhook_config(new_config)
|
||||
|
||||
for i, platform in enumerate(self.config["platform"]):
|
||||
if platform["id"] == platform_id:
|
||||
if platform["id"] == origin_platform_id:
|
||||
self.config["platform"][i] = new_config
|
||||
break
|
||||
else:
|
||||
@@ -609,21 +887,15 @@ class ConfigRoute(Route):
|
||||
|
||||
async def post_update_provider(self):
|
||||
update_provider_config = await request.json
|
||||
provider_id = update_provider_config.get("id", None)
|
||||
origin_provider_id = update_provider_config.get("id", None)
|
||||
new_config = update_provider_config.get("config", None)
|
||||
if not provider_id or not new_config:
|
||||
if not origin_provider_id or not new_config:
|
||||
return Response().error("参数错误").__dict__
|
||||
|
||||
for i, provider in enumerate(self.config["provider"]):
|
||||
if provider["id"] == provider_id:
|
||||
self.config["provider"][i] = new_config
|
||||
break
|
||||
else:
|
||||
return Response().error("未找到对应服务提供商").__dict__
|
||||
|
||||
try:
|
||||
save_config(self.config, self.config, is_core=True)
|
||||
await self.core_lifecycle.provider_manager.reload(new_config)
|
||||
await self.core_lifecycle.provider_manager.update_provider(
|
||||
origin_provider_id, new_config
|
||||
)
|
||||
except Exception as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
return Response().ok(None, "更新成功,已经实时生效~").__dict__
|
||||
@@ -646,19 +918,17 @@ class ConfigRoute(Route):
|
||||
|
||||
async def post_delete_provider(self):
|
||||
provider_id = await request.json
|
||||
provider_id = provider_id.get("id")
|
||||
for i, provider in enumerate(self.config["provider"]):
|
||||
if provider["id"] == provider_id:
|
||||
del self.config["provider"][i]
|
||||
break
|
||||
else:
|
||||
return Response().error("未找到对应服务提供商").__dict__
|
||||
provider_id = provider_id.get("id", "")
|
||||
if not provider_id:
|
||||
return Response().error("缺少参数 id").__dict__
|
||||
|
||||
try:
|
||||
save_config(self.config, self.config, is_core=True)
|
||||
await self.core_lifecycle.provider_manager.terminate_provider(provider_id)
|
||||
await self.core_lifecycle.provider_manager.delete_provider(
|
||||
provider_id=provider_id
|
||||
)
|
||||
except Exception as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
return Response().ok(None, "删除成功,已经实时生效~").__dict__
|
||||
return Response().ok(None, "删除成功,已经实时生效。").__dict__
|
||||
|
||||
async def get_llm_tools(self):
|
||||
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import json
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
|
||||
from quart import request
|
||||
from quart import request, send_file
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
@@ -30,6 +32,7 @@ class ConversationRoute(Route):
|
||||
"POST",
|
||||
self.update_history,
|
||||
),
|
||||
"/conversation/export": ("POST", self.export_conversations),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.conv_mgr = core_lifecycle.conversation_manager
|
||||
@@ -283,3 +286,90 @@ class ConversationRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"更新对话历史失败: {e!s}").__dict__
|
||||
|
||||
async def export_conversations(self):
|
||||
"""批量导出对话为 JSONL 格式"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
conversations_to_export = data.get("conversations", [])
|
||||
|
||||
if not conversations_to_export:
|
||||
return Response().error("导出列表不能为空").__dict__
|
||||
|
||||
# 收集所有对话的内容
|
||||
jsonl_lines = []
|
||||
exported_count = 0
|
||||
failed_items = []
|
||||
|
||||
for conv_info in conversations_to_export:
|
||||
user_id = conv_info.get("user_id")
|
||||
cid = conv_info.get("cid")
|
||||
|
||||
if not user_id or not cid:
|
||||
failed_items.append(
|
||||
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
conversation = await self.conv_mgr.get_conversation(
|
||||
unified_msg_origin=user_id,
|
||||
conversation_id=cid,
|
||||
)
|
||||
|
||||
if not conversation:
|
||||
failed_items.append(
|
||||
f"user_id:{user_id}, cid:{cid} - 对话不存在"
|
||||
)
|
||||
continue
|
||||
|
||||
# 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
|
||||
content = json.loads(conversation.history)
|
||||
|
||||
# 创建导出记录
|
||||
export_record = {
|
||||
"cid": cid,
|
||||
"user_id": user_id,
|
||||
"platform_id": conversation.platform_id,
|
||||
"title": conversation.title,
|
||||
"persona_id": conversation.persona_id,
|
||||
"created_at": conversation.created_at,
|
||||
"updated_at": conversation.updated_at,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# 将记录转换为 JSON 字符串并添加到 JSONL
|
||||
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
|
||||
exported_count += 1
|
||||
|
||||
except Exception as e:
|
||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
|
||||
logger.error(
|
||||
f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
|
||||
)
|
||||
|
||||
if exported_count == 0:
|
||||
return Response().error("没有成功导出任何对话").__dict__
|
||||
|
||||
# 创建 JSONL 内容
|
||||
jsonl_content = "\n".join(jsonl_lines)
|
||||
|
||||
# 创建一个内存文件对象
|
||||
file_obj = BytesIO(jsonl_content.encode("utf-8"))
|
||||
file_obj.seek(0)
|
||||
|
||||
# 生成文件名
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"astrbot_conversations_export_{timestamp}.jsonl"
|
||||
|
||||
# 返回文件流
|
||||
return await send_file(
|
||||
file_obj,
|
||||
mimetype="application/jsonl",
|
||||
as_attachment=True,
|
||||
attachment_filename=filename,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
|
||||
return Response().error(f"批量导出对话失败: {e!s}").__dict__
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import cast
|
||||
|
||||
from quart import Response as QuartResponse
|
||||
from quart import make_response
|
||||
from quart import make_response, request
|
||||
|
||||
from astrbot.core import LogBroker, logger
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
|
||||
def _format_log_sse(log: dict, ts: float) -> str:
|
||||
"""辅助函数:格式化 SSE 消息"""
|
||||
payload = {
|
||||
"type": "log",
|
||||
**log,
|
||||
}
|
||||
return f"id: {ts}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
class LogRoute(Route):
|
||||
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
|
||||
super().__init__(context)
|
||||
@@ -21,21 +32,44 @@ class LogRoute(Route):
|
||||
methods=["GET"],
|
||||
)
|
||||
|
||||
async def log(self):
|
||||
async def _replay_cached_logs(
|
||||
self, last_event_id: str
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""辅助生成器:重放缓存的日志"""
|
||||
try:
|
||||
last_ts = float(last_event_id)
|
||||
cached_logs = list(self.log_broker.log_cache)
|
||||
|
||||
for log_item in cached_logs:
|
||||
log_ts = float(log_item.get("time", 0))
|
||||
|
||||
if log_ts > last_ts:
|
||||
yield _format_log_sse(log_item, log_ts)
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"Log SSE 补发历史错误: {e}")
|
||||
|
||||
async def log(self) -> QuartResponse:
|
||||
last_event_id = request.headers.get("Last-Event-ID")
|
||||
|
||||
async def stream():
|
||||
queue = None
|
||||
try:
|
||||
if last_event_id:
|
||||
async for event in self._replay_cached_logs(last_event_id):
|
||||
yield event
|
||||
|
||||
queue = self.log_broker.register()
|
||||
while True:
|
||||
message = await queue.get()
|
||||
payload = {
|
||||
"type": "log",
|
||||
**message, # see astrbot/core/log.py
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
current_ts = message.get("time", time.time())
|
||||
yield _format_log_sse(message, current_ts)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except BaseException as e:
|
||||
except Exception as e:
|
||||
logger.error(f"Log SSE 连接错误: {e}")
|
||||
finally:
|
||||
if queue:
|
||||
@@ -53,7 +87,7 @@ class LogRoute(Route):
|
||||
},
|
||||
),
|
||||
)
|
||||
response.timeout = None
|
||||
response.timeout = None # type: ignore
|
||||
return response
|
||||
|
||||
async def log_history(self):
|
||||
@@ -69,6 +103,6 @@ class LogRoute(Route):
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except BaseException as e:
|
||||
except Exception as e:
|
||||
logger.error(f"获取日志历史失败: {e}")
|
||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||
|
||||
@@ -124,7 +124,11 @@ class PluginRoute(Route):
|
||||
session.get(url) as response,
|
||||
):
|
||||
if response.status == 200:
|
||||
remote_data = await response.json()
|
||||
try:
|
||||
remote_data = await response.json()
|
||||
except aiohttp.ContentTypeError:
|
||||
remote_text = await response.text()
|
||||
remote_data = json.loads(remote_text)
|
||||
|
||||
# 检查远程数据是否为空
|
||||
if not remote_data or (
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from functools import cmp_to_key
|
||||
|
||||
import aiohttp
|
||||
import psutil
|
||||
@@ -11,7 +14,9 @@ from astrbot.core.config import VERSION
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.migration.helper import check_migration_needed_v4
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -30,6 +35,8 @@ class StatRoute(Route):
|
||||
"/stat/start-time": ("GET", self.get_start_time),
|
||||
"/stat/restart-core": ("POST", self.restart_core),
|
||||
"/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection),
|
||||
"/stat/changelog": ("GET", self.get_changelog),
|
||||
"/stat/changelog/list": ("GET", self.list_changelog_versions),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.register_routes()
|
||||
@@ -183,3 +190,92 @@ class StatRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {e!s}").__dict__
|
||||
|
||||
async def get_changelog(self):
|
||||
"""获取指定版本的更新日志"""
|
||||
try:
|
||||
version = request.args.get("version")
|
||||
if not version:
|
||||
return Response().error("version parameter is required").__dict__
|
||||
|
||||
version = version.lstrip("v")
|
||||
|
||||
# 防止路径遍历攻击
|
||||
if not re.match(r"^[a-zA-Z0-9._-]+$", version):
|
||||
return Response().error("Invalid version format").__dict__
|
||||
if ".." in version or "/" in version or "\\" in version:
|
||||
return Response().error("Invalid version format").__dict__
|
||||
|
||||
filename = f"v{version}.md"
|
||||
project_path = get_astrbot_path()
|
||||
changelogs_dir = os.path.join(project_path, "changelogs")
|
||||
changelog_path = os.path.join(changelogs_dir, filename)
|
||||
|
||||
# 规范化路径,防止符号链接攻击
|
||||
changelog_path = os.path.realpath(changelog_path)
|
||||
changelogs_dir = os.path.realpath(changelogs_dir)
|
||||
|
||||
# 验证最终路径在预期的 changelogs 目录内(防止路径遍历)
|
||||
# 确保规范化后的路径以 changelogs_dir 开头,且是目录内的文件
|
||||
changelog_path_normalized = os.path.normpath(changelog_path)
|
||||
changelogs_dir_normalized = os.path.normpath(changelogs_dir)
|
||||
|
||||
# 检查路径是否在预期目录内(必须是目录的子文件,不能是目录本身)
|
||||
expected_prefix = changelogs_dir_normalized + os.sep
|
||||
if not changelog_path_normalized.startswith(expected_prefix):
|
||||
logger.warning(
|
||||
f"Path traversal attempt detected: {version} -> {changelog_path}",
|
||||
)
|
||||
return Response().error("Invalid version format").__dict__
|
||||
|
||||
if not os.path.exists(changelog_path):
|
||||
return (
|
||||
Response()
|
||||
.error(f"Changelog for version {version} not found")
|
||||
.__dict__
|
||||
)
|
||||
if not os.path.isfile(changelog_path):
|
||||
return (
|
||||
Response()
|
||||
.error(f"Changelog for version {version} not found")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
with open(changelog_path, encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
return Response().ok({"content": content, "version": version}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {e!s}").__dict__
|
||||
|
||||
async def list_changelog_versions(self):
|
||||
"""获取所有可用的更新日志版本列表"""
|
||||
try:
|
||||
project_path = get_astrbot_path()
|
||||
changelogs_dir = os.path.join(project_path, "changelogs")
|
||||
|
||||
if not os.path.exists(changelogs_dir):
|
||||
return Response().ok({"versions": []}).__dict__
|
||||
|
||||
versions = []
|
||||
for filename in os.listdir(changelogs_dir):
|
||||
if filename.endswith(".md") and filename.startswith("v"):
|
||||
# 提取版本号(去除 v 前缀和 .md 后缀)
|
||||
version = filename[1:-3] # 去掉 "v" 和 ".md"
|
||||
# 验证版本号格式
|
||||
if re.match(r"^[a-zA-Z0-9._-]+$", version):
|
||||
versions.append(version)
|
||||
|
||||
# 按版本号排序(降序,最新的在前)
|
||||
# 使用项目中的 VersionComparator 进行语义化版本号排序
|
||||
versions.sort(
|
||||
key=cmp_to_key(
|
||||
lambda v1, v2: VersionComparator.compare_version(v2, v1),
|
||||
),
|
||||
)
|
||||
|
||||
return Response().ok({"versions": versions}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"Error: {e!s}").__dict__
|
||||
|
||||
@@ -3,6 +3,7 @@ import traceback
|
||||
from quart import request
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.mcp_client import MCPTool
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.star import star_map
|
||||
|
||||
@@ -296,15 +297,30 @@ class ToolsRoute(Route):
|
||||
"""获取所有注册的工具列表"""
|
||||
try:
|
||||
tools = self.tool_mgr.func_list
|
||||
tools_dict = [
|
||||
{
|
||||
tools_dict = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, MCPTool):
|
||||
origin = "mcp"
|
||||
origin_name = tool.mcp_server_name
|
||||
elif tool.handler_module_path and star_map.get(
|
||||
tool.handler_module_path
|
||||
):
|
||||
star = star_map[tool.handler_module_path]
|
||||
origin = "plugin"
|
||||
origin_name = star.name
|
||||
else:
|
||||
origin = "unknown"
|
||||
origin_name = "unknown"
|
||||
|
||||
tool_info = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
"active": tool.active,
|
||||
"origin": origin,
|
||||
"origin_name": origin_name,
|
||||
}
|
||||
for tool in tools
|
||||
]
|
||||
tools_dict.append(tool_info)
|
||||
return Response().ok(data=tools_dict).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -19,6 +19,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
|
||||
from .routes import *
|
||||
from .routes.backup import BackupRoute
|
||||
from .routes.platform import PlatformRoute
|
||||
from .routes.route import Response, RouteContext
|
||||
from .routes.session_management import SessionManagementRoute
|
||||
@@ -67,6 +68,7 @@ class AstrBotDashboard:
|
||||
core_lifecycle,
|
||||
core_lifecycle.plugin_manager,
|
||||
)
|
||||
self.command_route = CommandRoute(self.context)
|
||||
self.cr = ConfigRoute(self.context, core_lifecycle)
|
||||
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
|
||||
self.sfr = StaticFileRoute(self.context)
|
||||
@@ -84,6 +86,7 @@ class AstrBotDashboard:
|
||||
self.t2i_route = T2iRoute(self.context, core_lifecycle)
|
||||
self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle)
|
||||
self.platform_route = PlatformRoute(self.context, core_lifecycle)
|
||||
self.backup_route = BackupRoute(self.context, db, core_lifecycle)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
@@ -107,7 +110,12 @@ class AstrBotDashboard:
|
||||
async def auth_middleware(self):
|
||||
if not request.path.startswith("/api"):
|
||||
return None
|
||||
allowed_endpoints = ["/api/auth/login", "/api/file", "/api/platform/webhook"]
|
||||
allowed_endpoints = [
|
||||
"/api/auth/login",
|
||||
"/api/file",
|
||||
"/api/platform/webhook",
|
||||
"/api/stat/start-time",
|
||||
]
|
||||
if any(request.path.startswith(prefix) for prefix in allowed_endpoints):
|
||||
return None
|
||||
# 声明 JWT
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
## What's Changed
|
||||
|
||||
> 📢 在升级前,请**完整阅读**本次更新日志。
|
||||
>
|
||||
> **特别提醒:**
|
||||
> 1. 该版本为 alpha.1 预览版本。
|
||||
> 2. 本次升级**如果再降级**,会由于提供商配置的变更,导致提供商配置错乱,需要手动删除后重新添加。
|
||||
> 3. 此版本 WebUI 包体相较上一个版本增加约 **193%**,共约 **9.8 MB**,升级可能会需要一些时间。
|
||||
|
||||
### 重构与优化
|
||||
|
||||
- 重构 Provider 页面和提供商的配置结构,将 Chat Provider 配置拆分为 Provider Source(提供商源)和 Provider(代表提供商源的各个模型),引入了提供商模型自动发现、模型元数据自动发现的功能,**提供更加便捷的模型添加体验**。
|
||||
- ⚠️ 将 “MCP” 页面移动到了 “插件” 页面中
|
||||
- ⚠️ 将 “MCP” 页面中的工具管理移动到了 “插件” -> “管理行为” 中。
|
||||
- ⚠️ 将 “QQ 个人号(OneBot v11)” 机器人适配器类型更名为 “OneBot v11”,并将其 Logo 更改为 OneBot 的 Logo。
|
||||
- ⚠️ AstrBot WebChat 升级为 **AstrBot ChatUI**,入口从边栏修改为顶部(右上角)切换按钮。
|
||||
- 优化引用消息的逻辑,减少对模型输入缓存的破坏。
|
||||
|
||||
### 修复
|
||||
|
||||
- ‼️ 修复部分情况下,分段回复无法正常分段的问题。
|
||||
- 修复处理工具返回结果的过程中,导致一些直接发送图片的工具(如生图工具)无法正确发送到用户的问题。
|
||||
- 修复 WebChat 部分情况下,上一条消息文字内容增量到下一条消息的问题。
|
||||
|
||||
### 新增
|
||||
|
||||
- 支持**指令管理**,设置指令别名、解决指令冲突、查看指令详情等。入口:“插件” -> “管理行为”。
|
||||
- 支持 Google Gemini 3 系列引入的 [Thinking Level](https://ai.google.dev/gemini-api/docs/thinking#thinking-levels) 配置。
|
||||
- 支持记录每条 LLM 消息的耗时、Token 使用量、TTFT 数据,以及每次 Agent Loop 的各种统计数据。
|
||||
- AstrBot ChatUI 支持查看每条消息的 TTFT、Token 使用量数据。
|
||||
- AstrBot ChatUI 支持显示每次工具调用的耗时、参数和响应。
|
||||
- AstrBot ChatUI 支持渲染 Mermaid、LateX 内容,优化了 Code Block 的显示效果(使用 Monaco Editor),并减少 DOM 更新于内存占用。(Powered by [Simon-He95/markstream-vue](https://github.com/Simon-He95/markstream-vue))
|
||||
- 支持查看 Changelog 历史版本更新日志。
|
||||
- 🎄
|
||||
@@ -0,0 +1,44 @@
|
||||
## What's Changed
|
||||
|
||||
> 📢 在升级前,请**完整阅读**本次更新日志。
|
||||
>
|
||||
> **特别提醒:**
|
||||
> 1. 该版本为 alpha.2 预览版本。
|
||||
> 2. 本次升级**如果再降级**,会由于提供商配置的变更,导致提供商配置错乱,需要手动删除后重新添加。
|
||||
> 3. 此版本 WebUI 包体相较上一个版本增加约 **193%**,共约 **9.8 MB**,升级可能会需要一些时间。
|
||||
|
||||
## alpha.1 -> alpha.2
|
||||
|
||||
- 修复:“对话数据”页对话轨迹详情显示异常的问题
|
||||
- 优化:当 Agent 达到最大步数时的处理。在达到最大步数后,会移除所有请求中的 tools 并告知模型根据上下文进行最终总结。
|
||||
- 优化:LLM tools 执行的错误处理,减少工具调用无限循环的问题。
|
||||
- 优化:ChatUI 打开模型选择菜单时,会重新获取提供商配置。
|
||||
- 优化:ChatUI 新建对话并发送消息后,对话列表页自动选中该对话。
|
||||
|
||||
## 4.10.0 变化
|
||||
|
||||
### 重构与优化
|
||||
|
||||
- 重构 Provider 页面和提供商的配置结构,将 Chat Provider 配置拆分为 Provider Source(提供商源)和 Provider(代表提供商源的各个模型),引入了提供商模型自动发现、模型元数据自动发现的功能,**提供更加便捷的模型添加体验**。
|
||||
- ⚠️ 将 “MCP” 页面移动到了 “插件” 页面中
|
||||
- ⚠️ 将 “MCP” 页面中的工具管理移动到了 “插件” -> “管理行为” 中。
|
||||
- ⚠️ 将 “QQ 个人号(OneBot v11)” 机器人适配器类型更名为 “OneBot v11”,并将其 Logo 更改为 OneBot 的 Logo。
|
||||
- ⚠️ AstrBot WebChat 升级为 **AstrBot ChatUI**,入口从边栏修改为顶部(右上角)切换按钮。
|
||||
- 优化引用消息的逻辑,减少对模型输入缓存的破坏。
|
||||
|
||||
### 修复
|
||||
|
||||
- ‼️ 修复部分情况下,分段回复无法正常分段的问题。
|
||||
- 修复处理工具返回结果的过程中,导致一些直接发送图片的工具(如生图工具)无法正确发送到用户的问题。
|
||||
- 修复 WebChat 部分情况下,上一条消息文字内容增量到下一条消息的问题。
|
||||
|
||||
### 新增
|
||||
|
||||
- 支持**指令管理**,设置指令别名、解决指令冲突、查看指令详情等。入口:“插件” -> “管理行为”。
|
||||
- 支持 Google Gemini 3 系列引入的 [Thinking Level](https://ai.google.dev/gemini-api/docs/thinking#thinking-levels) 配置。
|
||||
- 支持记录每条 LLM 消息的耗时、Token 使用量、TTFT 数据,以及每次 Agent Loop 的各种统计数据。
|
||||
- AstrBot ChatUI 支持查看每条消息的 TTFT、Token 使用量数据。
|
||||
- AstrBot ChatUI 支持显示每次工具调用的耗时、参数和响应。
|
||||
- AstrBot ChatUI 支持渲染 Mermaid、LateX 内容,优化了 Code Block 的显示效果(使用 Monaco Editor),并减少 DOM 更新于内存占用。(Powered by [Simon-He95/markstream-vue](https://github.com/Simon-He95/markstream-vue))
|
||||
- 支持查看 Changelog 历史版本更新日志。
|
||||
- 🎄
|
||||
@@ -0,0 +1,40 @@
|
||||
## What's Changed
|
||||
|
||||
> 📢 在升级前,请**完整阅读**本次更新日志。
|
||||
>
|
||||
> **特别提醒:**
|
||||
> 1. 本次升级**如果再降级**,会由于提供商配置的变更,导致提供商配置错乱,需要手动删除后重新添加。
|
||||
> 2. 此版本 WebUI 包体相较上一个版本增加约 **193%**,共约 **9.8 MB**,升级可能会需要一些时间。
|
||||
> 3. **升级后请务必确保 WebUI 和 AstrBot Core 版本一致**,否则会产生预期之外的情况。(判断方法:日志中出现 `WebUI 版本已是最新。` 即为一致的版本,`检测到 WebUI 版本 (xxx) 与当前 AstrBot 版本 (xxx) 不符。` 即为不一致的版本。此版本的判断方法也可通查看 WebUI 右上角是否出现 Bot / Chat 的切换按钮控件来判断是否是新版本的 WebUI)。
|
||||
> 4. 如果有任何问题请提交 [Issue](https://github.com/AstrBotDevs/AstrBot/issues) 并附带 `v4.10.0` tag。
|
||||
|
||||
### 重构与优化
|
||||
|
||||
- 重构 Provider 页面和提供商的配置结构,将 Chat Provider 配置拆分为 Provider Source(提供商源)和 Provider(代表提供商源的各个模型),引入了提供商模型自动发现、模型元数据自动发现的功能,**提供更加便捷的模型添加体验**。
|
||||
- ⚠️ 将 “MCP” 页面移动到了 “插件” 页面中
|
||||
- ⚠️ 将 “MCP” 页面中的工具管理移动到了 “插件” -> “管理行为” 中。
|
||||
- ⚠️ 将 “QQ 个人号(OneBot v11)” 机器人适配器类型更名为 “OneBot v11”,并将其 Logo 更改为 OneBot 的 Logo。
|
||||
- ⚠️ AstrBot WebChat 升级为 **AstrBot ChatUI**,入口从边栏修改为顶部(右上角)切换按钮。
|
||||
- 优化引用消息的逻辑,减少对模型输入缓存的破坏。
|
||||
- 优化当 Agent 达到最大步数时的处理。在达到最大步数后,会移除所有请求中的 tools 并告知模型根据上下文进行最终总结。
|
||||
- 优化 LLM tools 执行的错误处理,减少工具调用无限循环的问题。
|
||||
|
||||
|
||||
### 修复
|
||||
|
||||
- ‼️ 修复部分情况下,分段回复无法正常分段的问题。
|
||||
- 修复处理工具返回结果的过程中,导致一些直接发送图片的工具(如生图工具)无法正确发送到用户的问题。
|
||||
- 修复 WebChat 部分情况下,上一条消息文字内容增量到下一条消息的问题。
|
||||
|
||||
### 新增
|
||||
|
||||
- 支持**指令管理**,设置指令别名、解决指令冲突、查看指令详情等。入口:“插件” -> “管理行为”。
|
||||
- 支持 Google Gemini 3 系列引入的 [Thinking Level](https://ai.google.dev/gemini-api/docs/thinking#thinking-levels) 配置。
|
||||
- 支持记录每条 LLM 消息的耗时、Token 使用量、TTFT 数据,以及每次 Agent Loop 的各种统计数据。
|
||||
- AstrBot ChatUI 支持查看每条消息的 TTFT、Token 使用量数据。
|
||||
- AstrBot ChatUI 支持显示每次工具调用的耗时、参数和响应。
|
||||
- AstrBot ChatUI 支持渲染 Mermaid、LateX 内容,优化了 Code Block 的显示效果(使用 Monaco Editor),并减少 DOM 更新于内存占用。(Powered by [Simon-He95/markstream-vue](https://github.com/Simon-He95/markstream-vue))
|
||||
- 支持查看 Changelog 历史版本更新日志。
|
||||
- 🎄
|
||||
|
||||
Merry Christmas!
|
||||
@@ -0,0 +1,46 @@
|
||||
## What's Changed
|
||||
|
||||
> 📢 在升级前,请**完整阅读**本次更新日志。
|
||||
>
|
||||
> **特别提醒:**
|
||||
> 1. 本次升级**如果再降级**,会由于提供商配置的变更,导致提供商配置错乱,需要手动删除后重新添加。
|
||||
> 2. 此版本 WebUI 包体相较上一个版本增加约 **193%**,共约 **9.8 MB**,升级可能会需要一些时间。
|
||||
> 3. **升级后请务必确保 WebUI 和 AstrBot Core 版本一致**,否则会产生预期之外的情况。(判断方法:日志中出现 `WebUI 版本已是最新。` 即为一致的版本,`检测到 WebUI 版本 (xxx) 与当前 AstrBot 版本 (xxx) 不符。` 即为不一致的版本。此版本的判断方法也可通查看 WebUI 右上角是否出现 Bot / Chat 的切换按钮控件来判断是否是新版本的 WebUI)。
|
||||
> 4. 如果有任何问题请提交 [Issue](https://github.com/AstrBotDevs/AstrBot/issues) 并附带 `v4.10.0` tag。
|
||||
|
||||
## 4.10.0 -> 4.10.1
|
||||
|
||||
- fix(core): 修复极少数情况下由于指令管理导致的 AstrBot 启动失败的问题
|
||||
- fix(core): 修复当提供商源带有斜杠(“/”)时,无法删除 / 更新提供商源的问题(报错 405)
|
||||
- perf(core): 优化 OneBot 适配器的消息段解析逻辑,修复部分情况下无法正确解析消息段的问题
|
||||
|
||||
### 重构与优化
|
||||
|
||||
- 重构 Provider 页面和提供商的配置结构,将 Chat Provider 配置拆分为 Provider Source(提供商源)和 Provider(代表提供商源的各个模型),引入了提供商模型自动发现、模型元数据自动发现的功能,**提供更加便捷的模型添加体验**。
|
||||
- ⚠️ 将 “MCP” 页面移动到了 “插件” 页面中
|
||||
- ⚠️ 将 “MCP” 页面中的工具管理移动到了 “插件” -> “管理行为” 中。
|
||||
- ⚠️ 将 “QQ 个人号(OneBot v11)” 机器人适配器类型更名为 “OneBot v11”,并将其 Logo 更改为 OneBot 的 Logo。
|
||||
- ⚠️ AstrBot WebChat 升级为 **AstrBot ChatUI**,入口从边栏修改为顶部(右上角)切换按钮。
|
||||
- 优化引用消息的逻辑,减少对模型输入缓存的破坏。
|
||||
- 优化当 Agent 达到最大步数时的处理。在达到最大步数后,会移除所有请求中的 tools 并告知模型根据上下文进行最终总结。
|
||||
- 优化 LLM tools 执行的错误处理,减少工具调用无限循环的问题。
|
||||
|
||||
|
||||
### 修复
|
||||
|
||||
- ‼️ 修复部分情况下,分段回复无法正常分段的问题。
|
||||
- 修复处理工具返回结果的过程中,导致一些直接发送图片的工具(如生图工具)无法正确发送到用户的问题。
|
||||
- 修复 WebChat 部分情况下,上一条消息文字内容增量到下一条消息的问题。
|
||||
|
||||
### 新增
|
||||
|
||||
- 支持**指令管理**,设置指令别名、解决指令冲突、查看指令详情等。入口:“插件” -> “管理行为”。
|
||||
- 支持 Google Gemini 3 系列引入的 [Thinking Level](https://ai.google.dev/gemini-api/docs/thinking#thinking-levels) 配置。
|
||||
- 支持记录每条 LLM 消息的耗时、Token 使用量、TTFT 数据,以及每次 Agent Loop 的各种统计数据。
|
||||
- AstrBot ChatUI 支持查看每条消息的 TTFT、Token 使用量数据。
|
||||
- AstrBot ChatUI 支持显示每次工具调用的耗时、参数和响应。
|
||||
- AstrBot ChatUI 支持渲染 Mermaid、LateX 内容,优化了 Code Block 的显示效果(使用 Monaco Editor),并减少 DOM 更新于内存占用。(Powered by [Simon-He95/markstream-vue](https://github.com/Simon-He95/markstream-vue))
|
||||
- 支持查看 Changelog 历史版本更新日志。
|
||||
- 🎄
|
||||
|
||||
Merry Christmas!
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user