Compare commits
98 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 11f45d3fc2 | |||
| 300a73ace0 | |||
| a5b9de3695 | |||
| 90142bcafe | |||
| 79d0487c03 | |||
| 4f15102e79 | |||
| ef1feb639c | |||
| 1039a4f864 | |||
| 66e2f49c11 | |||
| c5773fe63e | |||
| 4e9ef48af2 | |||
| 9eafd7b44a | |||
| fc61f7ad32 | |||
| f51810997a | |||
| fb4baf676f | |||
| 71ad974c3c | |||
| f0fff68947 | |||
| 3e3599835e | |||
| 5255388e2d | |||
| fbdd60b64c | |||
| bd1b0a2836 | |||
| 19541d9d07 | |||
| 2a5d574394 | |||
| f2924fbd1b | |||
| 703e208947 | |||
| 9a5cc977c2 | |||
| aa38fe776a | |||
| 701399c00c | |||
| eaee98d4b8 | |||
| 76c66000a7 | |||
| 4b365143c0 | |||
| 6e4e5011e2 | |||
| d853bfde84 | |||
| a0e856f80f | |||
| 8c94a0010c | |||
| a44fdaaec0 | |||
| 60105c76f5 | |||
| bcf87d3ce4 | |||
| 4d7c8c8453 | |||
| a064a9115f | |||
| 6ef99e1553 | |||
| c0dbe5cf65 | |||
| 3598c51eff | |||
| b5cdb8f650 | |||
| fc5b520f9b | |||
| 904f56b32f | |||
| 2f15fd019c | |||
| 82330b8d10 | |||
| 3ee6af7027 | |||
| 6e20ebe901 | |||
| 4d6150fd6d | |||
| 544e52191b | |||
| f2c2a6da4a | |||
| dd3df425ee | |||
| 40b4a27a3d | |||
| 9d991c7468 | |||
| ad6a8b5c94 | |||
| 1b4bfcbd72 | |||
| 9d3cc593a1 | |||
| f0dee35ba9 | |||
| 4135bd84d5 | |||
| f6da614e5d | |||
| 5f531c9be5 | |||
| 94591d965b | |||
| 8a0f865af1 | |||
| 4aced976a8 | |||
| 0299aa6e4c | |||
| e8b54a019e | |||
| 98ce796275 | |||
| b87dcf2275 | |||
| 591a228431 | |||
| f52f375154 | |||
| 975c685a17 | |||
| 6db80d36a8 | |||
| 4651bd2807 | |||
| 94ada3793e | |||
| fd05b0bf09 | |||
| 4d046f8490 | |||
| 58e32b7b70 | |||
| 903dd0f9f7 | |||
| 1acac0cac2 | |||
| 80b89fd2ea | |||
| 26f863ba81 | |||
| f78a90218e | |||
| a3ecebd2aa | |||
| 67c33b842d | |||
| 5431c9f46e | |||
| 764b91a5f7 | |||
| c20c1b84bf | |||
| fd66a0ac00 | |||
| aaee283367 | |||
| 4a5b7d1976 | |||
| 08244548ab | |||
| b486de6a98 | |||
| b2e9dab233 | |||
| 45110200ea | |||
| a70088b799 | |||
| bb45d9cb54 |
@@ -15,7 +15,6 @@ Always reference these instructions first and fallback to search or bash command
|
||||
### Running the Application
|
||||
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
||||
|
||||
### Dashboard Build (Vue.js/Node.js)
|
||||
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||
@@ -35,7 +34,7 @@ Always reference these instructions first and fallback to search or bash command
|
||||
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||
|
||||
### Plugin Development
|
||||
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugins load from `astrbot/builtin_stars/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugin system supports function tools and message handlers
|
||||
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ jobs:
|
||||
zip -r dist.zip dist
|
||||
|
||||
- name: Archive production artifacts
|
||||
uses: actions/upload-artifact@v5
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: dist-without-markdown
|
||||
path: |
|
||||
|
||||
+51
-15
@@ -1,27 +1,63 @@
|
||||
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
|
||||
# 本工作流用于标记并关闭长期不活跃的 Issue。
|
||||
# 目前仅针对带 `bug` 标签的 Issue 生效,不会处理 PR。
|
||||
#
|
||||
# You can adjust the behavior by modifying this file.
|
||||
# For more information, see:
|
||||
# https://github.com/actions/stale
|
||||
name: Mark stale issues and pull requests
|
||||
# 文档: https://github.com/actions/stale
|
||||
name: Mark stale bug issues
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '21 23 * * *'
|
||||
# 每天 UTC 08:30 执行 (北京时间 16:30)
|
||||
- cron: '30 8 * * *'
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
dry-run:
|
||||
description: '仅预览, 不实际执行 (Dry run mode)'
|
||||
required: false
|
||||
default: true
|
||||
type: boolean
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: 'Stale issue message'
|
||||
stale-pr-message: 'Stale pull request message'
|
||||
stale-issue-label: 'no-issue-activity'
|
||||
stale-pr-label: 'no-pr-activity'
|
||||
- uses: actions/stale@v10
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
# 只处理带 bug 标签的 Issue
|
||||
any-of-labels: 'bug'
|
||||
|
||||
# 不处理 PR
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
|
||||
# 不活跃判定与关闭策略: 先标记 stale, 再延迟关闭
|
||||
days-before-issue-stale: 60
|
||||
days-before-issue-close: 30
|
||||
|
||||
stale-issue-label: 'stale'
|
||||
stale-issue-message: |
|
||||
This issue has been automatically marked as **stale** because it has not had any activity.
|
||||
It will be closed in a certain period of time if no further activity occurs.
|
||||
If this issue is still relevant, please leave a comment.
|
||||
|
||||
---
|
||||
|
||||
该 Issue 已较长时间无活动, 已被标记为 `stale`。
|
||||
如无后续活动, 将在一段时间后自动关闭。
|
||||
如仍需跟进, 请回复评论。
|
||||
close-issue-message: |
|
||||
This issue has been automatically closed due to inactivity.
|
||||
If the problem still exists, feel free to reopen or create a new issue with updated information.
|
||||
|
||||
---
|
||||
|
||||
该 Issue 因长期无活动已自动关闭。
|
||||
如问题仍存在, 欢迎补充复现信息并重新打开或新建 Issue。
|
||||
|
||||
remove-stale-when-updated: true
|
||||
|
||||
debug-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run }}
|
||||
|
||||
+2
-2
@@ -24,9 +24,9 @@ configs/session
|
||||
configs/config.yaml
|
||||
cmd_config.json
|
||||
|
||||
# Plugins and packages
|
||||
# Plugins
|
||||
addons/plugins
|
||||
packages/python_interpreter/workplace
|
||||
astrbot/builtin_stars/python_interpreter/workplace
|
||||
tests/astrbot_plugin_openai
|
||||
|
||||
# Dashboard
|
||||
|
||||
+26
-1
@@ -33,6 +33,20 @@
|
||||
- 请使用英文描述您的 PR。
|
||||
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
||||
|
||||
#### 代码规范
|
||||
|
||||
##### Core
|
||||
|
||||
我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
|
||||
如果您使用 VSCode,可以安装 `Ruff` 插件。
|
||||
|
||||
|
||||
## Contributing Guide
|
||||
|
||||
First off, thanks for taking the time to contribute! ❤️
|
||||
@@ -62,4 +76,15 @@ We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features.
|
||||
|
||||
#### PR Description
|
||||
- Please use English to describe your PR.
|
||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||
|
||||
#### Code Style
|
||||
|
||||
##### Core
|
||||
|
||||
We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
|
||||
|
||||
```bash
|
||||
ruff format .
|
||||
ruff check .
|
||||
```
|
||||
|
||||
@@ -100,16 +100,8 @@ class Main(star.Star):
|
||||
logger.error(f"ltm: {e}")
|
||||
|
||||
@filter.on_llm_response()
|
||||
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
|
||||
umo = event.unified_msg_origin
|
||||
cfg = self.context.get_config(umo).get("provider_settings", {})
|
||||
show_reasoning = cfg.get("display_reasoning_text", False)
|
||||
if show_reasoning and resp.reasoning_content:
|
||||
resp.completion_text = (
|
||||
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
|
||||
)
|
||||
|
||||
async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse):
|
||||
"""在 LLM 响应后记录对话"""
|
||||
if self.ltm and self.ltm_enabled(event):
|
||||
try:
|
||||
await self.ltm.after_req_llm(event, resp)
|
||||
+44
-16
@@ -7,6 +7,7 @@ from astrbot.api import logger, sp, star
|
||||
from astrbot.api.event import AstrMessageEvent
|
||||
from astrbot.api.message_components import Image, Reply
|
||||
from astrbot.api.provider import Provider, ProviderRequest
|
||||
from astrbot.core.agent.message import TextPart
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
|
||||
|
||||
@@ -85,7 +86,9 @@ class ProcessLLMRequest:
|
||||
req.image_urls,
|
||||
)
|
||||
if caption:
|
||||
req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}"
|
||||
req.extra_user_content_parts.append(
|
||||
TextPart(text=f"<image_caption>{caption}</image_caption>")
|
||||
)
|
||||
req.image_urls = []
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片描述失败: {e}")
|
||||
@@ -129,13 +132,14 @@ class ProcessLLMRequest:
|
||||
else:
|
||||
req.prompt = prefix + req.prompt
|
||||
|
||||
# 收集系统提醒信息
|
||||
system_parts = []
|
||||
|
||||
# user identifier
|
||||
if cfg.get("identifier"):
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
req.prompt = (
|
||||
f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}"
|
||||
)
|
||||
system_parts.append(f"User ID: {user_id}, Nickname: {user_nickname}")
|
||||
|
||||
# group name identifier
|
||||
if cfg.get("group_name_display") and event.message_obj.group_id:
|
||||
@@ -146,7 +150,7 @@ class ProcessLLMRequest:
|
||||
return
|
||||
group_name = event.message_obj.group.group_name
|
||||
if group_name:
|
||||
req.system_prompt += f"\nGroup name: {group_name}\n"
|
||||
system_parts.append(f"Group name: {group_name}")
|
||||
|
||||
# time info
|
||||
if cfg.get("datetime_system_prompt"):
|
||||
@@ -162,7 +166,7 @@ class ProcessLLMRequest:
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
||||
system_parts.append(f"Current datetime: {current_time}")
|
||||
|
||||
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
|
||||
if req.conversation:
|
||||
@@ -181,37 +185,61 @@ class ProcessLLMRequest:
|
||||
quote = comp
|
||||
break
|
||||
if quote:
|
||||
sender_info = ""
|
||||
if quote.sender_nickname:
|
||||
sender_info = f"(Sent by {quote.sender_nickname})"
|
||||
message_str = quote.message_str or "[Empty Text]"
|
||||
req.system_prompt += (
|
||||
f"\nUser is quoting a message{sender_info}.\n"
|
||||
f"Here are the information of the quoted message: Text Content: {message_str}.\n"
|
||||
content_parts = []
|
||||
|
||||
# 1. 处理引用的文本
|
||||
sender_info = (
|
||||
f"({quote.sender_nickname}): " if quote.sender_nickname else ""
|
||||
)
|
||||
message_str = quote.message_str or "[Empty Text]"
|
||||
content_parts.append(f"{sender_info}{message_str}")
|
||||
|
||||
# 2. 处理引用的图片 (保留原有逻辑,但改变输出目标)
|
||||
image_seg = None
|
||||
if quote.chain:
|
||||
for comp in quote.chain:
|
||||
if isinstance(comp, Image):
|
||||
image_seg = comp
|
||||
break
|
||||
|
||||
if image_seg:
|
||||
try:
|
||||
# 找到可以生成图片描述的 provider
|
||||
prov = None
|
||||
if img_cap_prov_id:
|
||||
prov = self.ctx.get_provider_by_id(img_cap_prov_id)
|
||||
if prov is None:
|
||||
prov = self.ctx.get_using_provider(event.unified_msg_origin)
|
||||
|
||||
# 调用 provider 生成图片描述
|
||||
if prov and isinstance(prov, Provider):
|
||||
llm_resp = await prov.text_chat(
|
||||
prompt="Please describe the image content.",
|
||||
image_urls=[await image_seg.convert_to_file_path()],
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
req.system_prompt += (
|
||||
f"Image Caption: {llm_resp.completion_text}\n"
|
||||
# 将图片描述作为文本添加到 content_parts
|
||||
content_parts.append(
|
||||
f"[Image Caption in quoted message]: {llm_resp.completion_text}"
|
||||
)
|
||||
else:
|
||||
logger.warning("No provider found for image captioning.")
|
||||
logger.warning(
|
||||
"No provider found for image captioning in quote."
|
||||
)
|
||||
except BaseException as e:
|
||||
logger.error(f"处理引用图片失败: {e}")
|
||||
|
||||
# 3. 将所有部分组合成文本并添加到 extra_user_content_parts 中
|
||||
# 确保引用内容被正确的标签包裹
|
||||
quoted_content = "\n".join(content_parts)
|
||||
# 确保所有内容都在<Quoted Message>标签内
|
||||
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
|
||||
|
||||
req.extra_user_content_parts.append(TextPart(text=quoted_text))
|
||||
|
||||
# 统一包裹所有系统提醒
|
||||
if system_parts:
|
||||
system_content = (
|
||||
"<system_reminder>" + "\n".join(system_parts) + "</system_reminder>"
|
||||
)
|
||||
req.extra_user_content_parts.append(TextPart(text=system_content))
|
||||
+1
@@ -71,6 +71,7 @@ class AdminCommands:
|
||||
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
|
||||
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
"""更新管理面板"""
|
||||
await event.send(MessageChain().message("正在尝试更新管理面板..."))
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
await event.send(MessageChain().message("管理面板更新完成。"))
|
||||
@@ -0,0 +1,88 @@
|
||||
import aiohttp
|
||||
|
||||
from astrbot.api import star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.star import command_management
|
||||
from astrbot.core.utils.io import get_dashboard_version
|
||||
|
||||
|
||||
class HelpCommand:
|
||||
def __init__(self, context: star.Context):
|
||||
self.context = context
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(
|
||||
"https://astrbot.app/notice.json",
|
||||
timeout=2,
|
||||
) as resp:
|
||||
return (await resp.json())["notice"]
|
||||
except BaseException:
|
||||
return ""
|
||||
|
||||
async def _build_reserved_command_lines(self) -> list[str]:
|
||||
"""
|
||||
使用实时指令配置生成内置指令清单,确保重命名/禁用后与实际生效状态保持一致。
|
||||
"""
|
||||
try:
|
||||
commands = await command_management.list_commands()
|
||||
except BaseException:
|
||||
return []
|
||||
|
||||
lines: list[str] = []
|
||||
hidden_commands = {"set", "unset", "websearch"}
|
||||
|
||||
def walk(items: list[dict], indent: int = 0):
|
||||
for item in items:
|
||||
if not item.get("reserved") or not item.get("enabled"):
|
||||
continue
|
||||
# 仅展示顶级指令或指令组
|
||||
if item.get("type") == "sub_command":
|
||||
continue
|
||||
if item.get("parent_signature"):
|
||||
continue
|
||||
|
||||
effective = (
|
||||
item.get("effective_command")
|
||||
or item.get("original_command")
|
||||
or item.get("handler_name")
|
||||
)
|
||||
if not effective:
|
||||
continue
|
||||
if effective in hidden_commands:
|
||||
continue
|
||||
|
||||
description = item.get("description") or ""
|
||||
desc_text = f" - {description}" if description else ""
|
||||
indent_prefix = " " * indent
|
||||
lines.append(f"{indent_prefix}/{effective}{desc_text}")
|
||||
|
||||
walk(commands)
|
||||
return lines
|
||||
|
||||
async def help(self, event: AstrMessageEvent):
|
||||
"""查看帮助"""
|
||||
notice = ""
|
||||
try:
|
||||
notice = await self._query_astrbot_notice()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
dashboard_version = await get_dashboard_version()
|
||||
command_lines = await self._build_reserved_command_lines()
|
||||
commands_section = (
|
||||
"\n".join(command_lines) if command_lines else "暂无启用的内置指令"
|
||||
)
|
||||
|
||||
msg_parts = [
|
||||
f"AstrBot v{VERSION}(WebUI: {dashboard_version})",
|
||||
"内置指令:",
|
||||
commands_section,
|
||||
]
|
||||
if notice:
|
||||
msg_parts.append(notice)
|
||||
msg = "\n".join(msg_parts)
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
+6
-4
@@ -184,7 +184,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_tts_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -198,7 +199,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message("请输入序号。"))
|
||||
return
|
||||
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_stt_providers()[idx2 - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -209,8 +211,8 @@ class ProviderCommands:
|
||||
event.set_result(MessageEventResult().message(f"成功切换到 {id_}。"))
|
||||
elif isinstance(idx, int):
|
||||
if idx > len(self.context.get_all_providers()) or idx < 1:
|
||||
event.set_result(MessageEventResult().message("无效的序号。"))
|
||||
|
||||
event.set_result(MessageEventResult().message("无效的提供商序号。"))
|
||||
return
|
||||
provider = self.context.get_all_providers()[idx - 1]
|
||||
id_ = provider.meta().id
|
||||
await self.context.provider_manager.set_provider(
|
||||
@@ -49,7 +49,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("tool")
|
||||
def tool(self):
|
||||
pass
|
||||
"""函数工具管理"""
|
||||
|
||||
@tool.command("ls")
|
||||
async def tool_ls(self, event: AstrMessageEvent):
|
||||
@@ -73,7 +73,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("plugin")
|
||||
def plugin(self):
|
||||
pass
|
||||
"""插件管理"""
|
||||
|
||||
@plugin.command("ls")
|
||||
async def plugin_ls(self, event: AstrMessageEvent):
|
||||
@@ -219,6 +219,7 @@ class Main(star.Star):
|
||||
@filter.permission_type(filter.PermissionType.ADMIN)
|
||||
@filter.command("dashboard_update")
|
||||
async def update_dashboard(self, event: AstrMessageEvent):
|
||||
"""更新管理面板"""
|
||||
await self.admin_c.update_dashboard(event)
|
||||
|
||||
@filter.command("set")
|
||||
+133
-134
@@ -157,9 +157,8 @@ class Main(star.Star):
|
||||
async def is_docker_available(self) -> bool:
|
||||
"""Check if docker is available"""
|
||||
try:
|
||||
docker = aiodocker.Docker()
|
||||
await docker.version()
|
||||
await docker.close()
|
||||
async with aiodocker.Docker() as docker:
|
||||
await docker.version()
|
||||
return True
|
||||
except BaseException as e:
|
||||
logger.info(f"检查 Docker 可用性: {e}")
|
||||
@@ -249,7 +248,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("pi")
|
||||
def pi(self):
|
||||
pass
|
||||
"""代码执行器配置"""
|
||||
|
||||
@pi.command("absdir")
|
||||
async def pi_absdir(self, event: AstrMessageEvent, path: str = ""):
|
||||
@@ -279,14 +278,14 @@ class Main(star.Star):
|
||||
@pi.command("repull")
|
||||
async def pi_repull(self, event: AstrMessageEvent):
|
||||
"""重新拉取沙箱镜像"""
|
||||
docker = aiodocker.Docker()
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
await docker.images.delete(image_name, force=True)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
pass
|
||||
await docker.images.pull(image_name)
|
||||
async with aiodocker.Docker() as docker:
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
await docker.images.delete(image_name, force=True)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
pass
|
||||
await docker.images.pull(image_name)
|
||||
yield event.plain_result("重新拉取沙箱镜像成功。")
|
||||
|
||||
@pi.command("file")
|
||||
@@ -371,137 +370,137 @@ class Main(star.Star):
|
||||
obs = ""
|
||||
n = 5
|
||||
|
||||
for i in range(n):
|
||||
if i > 0:
|
||||
logger.info(f"Try {i + 1}/{n}")
|
||||
async with aiodocker.Docker() as docker:
|
||||
for i in range(n):
|
||||
if i > 0:
|
||||
logger.info(f"Try {i + 1}/{n}")
|
||||
|
||||
PROMPT_ = PROMPT.format(
|
||||
prompt=plain_text,
|
||||
extra_input=extra_inputs,
|
||||
extra_prompt=obs,
|
||||
)
|
||||
provider = self.context.get_using_provider()
|
||||
llm_response = await provider.text_chat(
|
||||
prompt=PROMPT_,
|
||||
session_id=f"{event.session_id}_{magic_code}_{i!s}",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"code interpreter llm gened code:" + llm_response.completion_text,
|
||||
)
|
||||
|
||||
# 整理代码并保存
|
||||
code_clean = await self.tidy_code(llm_response.completion_text)
|
||||
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
|
||||
f.write(code_clean)
|
||||
|
||||
# 启动容器
|
||||
docker = aiodocker.Docker()
|
||||
|
||||
# 检查有没有image
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
# 拉取镜像
|
||||
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
|
||||
await docker.images.pull(image_name)
|
||||
|
||||
yield event.plain_result(
|
||||
f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})",
|
||||
)
|
||||
|
||||
self.docker_host_astrbot_abs_path = self.config.get(
|
||||
"docker_host_astrbot_abs_path",
|
||||
"",
|
||||
)
|
||||
if self.docker_host_astrbot_abs_path:
|
||||
host_shared = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
self.shared_path,
|
||||
PROMPT_ = PROMPT.format(
|
||||
prompt=plain_text,
|
||||
extra_input=extra_inputs,
|
||||
extra_prompt=obs,
|
||||
)
|
||||
host_output = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
output_path,
|
||||
)
|
||||
host_workplace = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
workplace_path,
|
||||
provider = self.context.get_using_provider()
|
||||
llm_response = await provider.text_chat(
|
||||
prompt=PROMPT_,
|
||||
session_id=f"{event.session_id}_{magic_code}_{i!s}",
|
||||
)
|
||||
|
||||
else:
|
||||
host_shared = os.path.abspath(self.shared_path)
|
||||
host_output = os.path.abspath(output_path)
|
||||
host_workplace = os.path.abspath(workplace_path)
|
||||
logger.debug(
|
||||
"code interpreter llm gened code:" + llm_response.completion_text,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}",
|
||||
)
|
||||
# 整理代码并保存
|
||||
code_clean = await self.tidy_code(llm_response.completion_text)
|
||||
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
|
||||
f.write(code_clean)
|
||||
|
||||
container = await docker.containers.run(
|
||||
{
|
||||
"Image": image_name,
|
||||
"Cmd": ["python", "exec.py"],
|
||||
"Memory": 512 * 1024 * 1024,
|
||||
"NanoCPUs": 1000000000,
|
||||
"HostConfig": {
|
||||
"Binds": [
|
||||
f"{host_shared}:/astrbot_sandbox/shared:ro",
|
||||
f"{host_output}:/astrbot_sandbox/output:rw",
|
||||
f"{host_workplace}:/astrbot_sandbox:rw",
|
||||
],
|
||||
},
|
||||
"Env": [f"MAGIC_CODE={magic_code}"],
|
||||
"AutoRemove": True,
|
||||
},
|
||||
)
|
||||
# 检查有没有image
|
||||
image_name = await self.get_image_name()
|
||||
try:
|
||||
await docker.images.get(image_name)
|
||||
except aiodocker.exceptions.DockerError:
|
||||
# 拉取镜像
|
||||
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
|
||||
await docker.images.pull(image_name)
|
||||
|
||||
logger.debug(f"Container {container.id} created.")
|
||||
logs = await self.run_container(container)
|
||||
yield event.plain_result(
|
||||
f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})",
|
||||
)
|
||||
|
||||
logger.debug(f"Container {container.id} finished.")
|
||||
logger.debug(f"Container {container.id} logs: {logs}")
|
||||
|
||||
# 发送结果
|
||||
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
|
||||
ok = False
|
||||
traceback = ""
|
||||
for idx, log in enumerate(logs):
|
||||
match = re.match(pattern, log)
|
||||
if match:
|
||||
ok = True
|
||||
if match.group(1) == "TEXT":
|
||||
yield event.plain_result(match.group(2))
|
||||
elif match.group(1) == "IMAGE":
|
||||
image_path = os.path.join(workplace_path, match.group(2))
|
||||
logger.debug(f"Sending image: {image_path}")
|
||||
yield event.image_result(image_path)
|
||||
elif match.group(1) == "FILE":
|
||||
file_path = os.path.join(workplace_path, match.group(2))
|
||||
# logger.debug(f"Sending file: {file_path}")
|
||||
# file_s3_url = await self.file_upload(file_path)
|
||||
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
||||
file_name = os.path.basename(file_path)
|
||||
chain: list[BaseMessageComponent] = [
|
||||
File(name=file_name, file=file_path)
|
||||
]
|
||||
yield event.set_result(MessageEventResult(chain=chain))
|
||||
|
||||
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
|
||||
traceback = "\n".join(logs[idx:])
|
||||
|
||||
if not ok:
|
||||
if traceback:
|
||||
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
|
||||
else:
|
||||
logger.warning(
|
||||
f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}",
|
||||
self.docker_host_astrbot_abs_path = self.config.get(
|
||||
"docker_host_astrbot_abs_path",
|
||||
"",
|
||||
)
|
||||
if self.docker_host_astrbot_abs_path:
|
||||
host_shared = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
self.shared_path,
|
||||
)
|
||||
break
|
||||
else:
|
||||
# 成功了
|
||||
self.user_file_msg_buffer.pop(event.get_session_id())
|
||||
return
|
||||
host_output = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
output_path,
|
||||
)
|
||||
host_workplace = os.path.join(
|
||||
self.docker_host_astrbot_abs_path,
|
||||
workplace_path,
|
||||
)
|
||||
|
||||
else:
|
||||
host_shared = os.path.abspath(self.shared_path)
|
||||
host_output = os.path.abspath(output_path)
|
||||
host_workplace = os.path.abspath(workplace_path)
|
||||
|
||||
logger.debug(
|
||||
f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}",
|
||||
)
|
||||
|
||||
container = await docker.containers.run(
|
||||
{
|
||||
"Image": image_name,
|
||||
"Cmd": ["python", "exec.py"],
|
||||
"Memory": 512 * 1024 * 1024,
|
||||
"NanoCPUs": 1000000000,
|
||||
"HostConfig": {
|
||||
"Binds": [
|
||||
f"{host_shared}:/astrbot_sandbox/shared:ro",
|
||||
f"{host_output}:/astrbot_sandbox/output:rw",
|
||||
f"{host_workplace}:/astrbot_sandbox:rw",
|
||||
],
|
||||
},
|
||||
"Env": [f"MAGIC_CODE={magic_code}"],
|
||||
"AutoRemove": True,
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(f"Container {container.id} created.")
|
||||
logs = await self.run_container(container)
|
||||
|
||||
logger.debug(f"Container {container.id} finished.")
|
||||
logger.debug(f"Container {container.id} logs: {logs}")
|
||||
|
||||
# 发送结果
|
||||
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
|
||||
ok = False
|
||||
traceback = ""
|
||||
for idx, log in enumerate(logs):
|
||||
match = re.match(pattern, log)
|
||||
if match:
|
||||
ok = True
|
||||
if match.group(1) == "TEXT":
|
||||
yield event.plain_result(match.group(2))
|
||||
elif match.group(1) == "IMAGE":
|
||||
image_path = os.path.join(workplace_path, match.group(2))
|
||||
logger.debug(f"Sending image: {image_path}")
|
||||
yield event.image_result(image_path)
|
||||
elif match.group(1) == "FILE":
|
||||
file_path = os.path.join(workplace_path, match.group(2))
|
||||
# logger.debug(f"Sending file: {file_path}")
|
||||
# file_s3_url = await self.file_upload(file_path)
|
||||
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
|
||||
file_name = os.path.basename(file_path)
|
||||
chain: list[BaseMessageComponent] = [
|
||||
File(name=file_name, file=file_path)
|
||||
]
|
||||
yield event.set_result(MessageEventResult(chain=chain))
|
||||
|
||||
elif (
|
||||
"Traceback (most recent call last)" in log or "[Error]: " in log
|
||||
):
|
||||
traceback = "\n".join(logs[idx:])
|
||||
|
||||
if not ok:
|
||||
if traceback:
|
||||
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
|
||||
else:
|
||||
logger.warning(
|
||||
f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}",
|
||||
)
|
||||
break
|
||||
else:
|
||||
# 成功了
|
||||
self.user_file_msg_buffer.pop(event.get_session_id())
|
||||
return
|
||||
|
||||
yield event.plain_result(
|
||||
"经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。",
|
||||
@@ -179,7 +179,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command_group("reminder")
|
||||
def reminder(self):
|
||||
"""The command group of the reminder."""
|
||||
"""待办提醒"""
|
||||
|
||||
async def get_upcoming_reminders(self, unified_msg_origin: str):
|
||||
"""Get upcoming reminders."""
|
||||
@@ -185,6 +185,7 @@ class Main(star.Star):
|
||||
|
||||
@filter.command("websearch")
|
||||
async def websearch(self, event: AstrMessageEvent, oper: str | None = None):
|
||||
"""网页搜索指令(已废弃)"""
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。",
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.9.2"
|
||||
__version__ = "4.10.3"
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class ContentPart(BaseModel):
|
||||
|
||||
__content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
|
||||
|
||||
type: str
|
||||
type: Literal["text", "think", "image_url", "audio_url"]
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
@@ -63,6 +63,28 @@ class TextPart(ContentPart):
|
||||
text: str
|
||||
|
||||
|
||||
class ThinkPart(ContentPart):
|
||||
"""
|
||||
>>> ThinkPart(think="I think I need to think about this.").model_dump()
|
||||
{'type': 'think', 'think': 'I think I need to think about this.', 'encrypted': None}
|
||||
"""
|
||||
|
||||
type: str = "think"
|
||||
think: str
|
||||
encrypted: str | None = None
|
||||
"""Encrypted thinking content, or signature."""
|
||||
|
||||
def merge_in_place(self, other: Any) -> bool:
|
||||
if not isinstance(other, ThinkPart):
|
||||
return False
|
||||
if self.encrypted:
|
||||
return False
|
||||
self.think += other.think
|
||||
if other.encrypted:
|
||||
self.encrypted = other.encrypted
|
||||
return True
|
||||
|
||||
|
||||
class ImageURLPart(ContentPart):
|
||||
"""
|
||||
>>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
|
||||
@@ -122,10 +144,12 @@ class ToolCall(BaseModel):
|
||||
extra_content: dict[str, Any] | None = None
|
||||
"""Extra metadata for the tool call."""
|
||||
|
||||
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.extra_content is None:
|
||||
kwargs.setdefault("exclude", set()).add("extra_content")
|
||||
return super().model_dump(**kwargs)
|
||||
data.pop("extra_content", None)
|
||||
return data
|
||||
|
||||
|
||||
class ToolCallPart(BaseModel):
|
||||
@@ -167,6 +191,15 @@ class Message(BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def serialize(self, handler):
|
||||
data = handler(self)
|
||||
if self.tool_calls is None:
|
||||
data.pop("tool_calls", None)
|
||||
if self.tool_call_id is None:
|
||||
data.pop("tool_call_id", None)
|
||||
return data
|
||||
|
||||
|
||||
class AssistantMessageSegment(Message):
|
||||
"""A message segment from the assistant."""
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import TokenUsage
|
||||
|
||||
|
||||
class AgentResponseData(T.TypedDict):
|
||||
@@ -12,3 +13,23 @@ class AgentResponseData(T.TypedDict):
|
||||
class AgentResponse:
|
||||
type: str
|
||||
data: AgentResponseData
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStats:
|
||||
token_usage: TokenUsage = field(default_factory=TokenUsage)
|
||||
start_time: float = 0.0
|
||||
end_time: float = 0.0
|
||||
time_to_first_token: float = 0.0
|
||||
|
||||
@property
|
||||
def duration(self) -> float:
|
||||
return self.end_time - self.start_time
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"token_usage": self.token_usage.__dict__,
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"time_to_first_token": self.time_to_first_token,
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ from .message import Message
|
||||
TContext = TypeVar("TContext", default=Any)
|
||||
|
||||
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
@dataclass
|
||||
class ContextWrapper(Generic[TContext]):
|
||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as T
|
||||
|
||||
@@ -12,6 +13,8 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import TextPart, ThinkPart
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
)
|
||||
@@ -24,7 +27,7 @@ from astrbot.core.provider.provider import Provider
|
||||
|
||||
from ..hooks import BaseAgentRunHooks
|
||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||
from ..response import AgentResponseData
|
||||
from ..response import AgentResponseData, AgentStats
|
||||
from ..run_context import ContextWrapper, TContext
|
||||
from ..tool_executor import BaseFunctionToolExecutor
|
||||
from .base import AgentResponse, AgentState, BaseAgentRunner
|
||||
@@ -69,14 +72,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
self.run_context.messages = messages
|
||||
|
||||
self.stats = AgentStats()
|
||||
self.stats.start_time = time.time()
|
||||
|
||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||
"""Yields chunks *and* a final LLMResponse."""
|
||||
payload = {
|
||||
"contexts": self.run_context.messages, # list[Message]
|
||||
"func_tool": self.req.func_tool,
|
||||
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
|
||||
"session_id": self.req.session_id,
|
||||
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
|
||||
}
|
||||
|
||||
if self.streaming:
|
||||
stream = self.provider.text_chat_stream(**self.req.__dict__)
|
||||
stream = self.provider.text_chat_stream(**payload)
|
||||
async for resp in stream: # type: ignore
|
||||
yield resp
|
||||
else:
|
||||
yield await self.provider.text_chat(**self.req.__dict__)
|
||||
yield await self.provider.text_chat(**payload)
|
||||
|
||||
@override
|
||||
async def step(self):
|
||||
@@ -98,6 +112,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
|
||||
async for llm_response in self._iter_llm_responses():
|
||||
if llm_response.is_chunk:
|
||||
# update ttft
|
||||
if self.stats.time_to_first_token == 0:
|
||||
self.stats.time_to_first_token = time.time() - self.stats.start_time
|
||||
|
||||
if llm_response.result_chain:
|
||||
yield AgentResponse(
|
||||
type="streaming_delta",
|
||||
@@ -121,6 +139,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
continue
|
||||
llm_resp_result = llm_response
|
||||
|
||||
if not llm_response.is_chunk and llm_response.usage:
|
||||
# only count the token usage of the final response for computation purpose
|
||||
self.stats.token_usage += llm_response.usage
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
@@ -132,6 +154,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
if llm_resp.role == "err":
|
||||
# 如果 LLM 响应错误,转换到错误状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self.stats.end_time = time.time()
|
||||
self._transition_state(AgentState.ERROR)
|
||||
yield AgentResponse(
|
||||
type="err",
|
||||
@@ -146,13 +169,21 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果没有工具调用,转换到完成状态
|
||||
self.final_llm_resp = llm_resp
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
|
||||
# record the final assistant message
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="assistant",
|
||||
content=llm_resp.completion_text or "",
|
||||
),
|
||||
)
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
parts.append(TextPart(text=llm_resp.completion_text or "*No response*"))
|
||||
self.run_context.messages.append(Message(role="assistant", content=parts))
|
||||
|
||||
# call the on_agent_done hook
|
||||
try:
|
||||
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
|
||||
except Exception as e:
|
||||
@@ -175,29 +206,35 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# 如果有工具调用,还需处理工具调用
|
||||
if llm_resp.tools_call_name:
|
||||
tool_call_result_blocks = []
|
||||
for tool_call_name in llm_resp.tools_call_name:
|
||||
yield AgentResponse(
|
||||
type="tool_call",
|
||||
data=AgentResponseData(
|
||||
chain=MessageChain(type="tool_call").message(
|
||||
f"🔨 调用工具: {tool_call_name}"
|
||||
),
|
||||
),
|
||||
)
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
result.type = "tool_call_result"
|
||||
if result.type is None:
|
||||
# should not happen
|
||||
continue
|
||||
if result.type == "tool_direct_result":
|
||||
ar_type = "tool_call_result"
|
||||
else:
|
||||
ar_type = result.type
|
||||
yield AgentResponse(
|
||||
type="tool_call_result",
|
||||
type=ar_type,
|
||||
data=AgentResponseData(chain=result),
|
||||
)
|
||||
# 将结果添加到上下文中
|
||||
parts = []
|
||||
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
|
||||
parts.append(
|
||||
ThinkPart(
|
||||
think=llm_resp.reasoning_content,
|
||||
encrypted=llm_resp.reasoning_signature,
|
||||
)
|
||||
)
|
||||
parts.append(TextPart(text=llm_resp.completion_text or "*No response*"))
|
||||
tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=AssistantMessageSegment(
|
||||
tool_calls=llm_resp.to_openai_to_calls_model(),
|
||||
content=llm_resp.completion_text,
|
||||
content=parts,
|
||||
),
|
||||
tool_calls_result=tool_call_result_blocks,
|
||||
)
|
||||
@@ -218,6 +255,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
# 如果循环结束了但是 agent 还没有完成,说明是达到了 max_step
|
||||
if not self.done():
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
# 拔掉所有工具
|
||||
if self.req:
|
||||
self.req.func_tool = None
|
||||
# 注入提示词
|
||||
self.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
# 再执行最后一步
|
||||
async for resp in self.step():
|
||||
yield resp
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
@@ -233,6 +289,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
yield MessageChain(
|
||||
type="tool_call",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"name": func_tool_name,
|
||||
"args": func_tool_args,
|
||||
"ts": time.time(),
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
try:
|
||||
if not req.func_tool:
|
||||
return
|
||||
@@ -306,7 +375,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content=res.content[0].text,
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -328,7 +396,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content=resource.text,
|
||||
),
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
@@ -352,20 +419,34 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
content="返回的数据类型不受支持",
|
||||
),
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
# 发送消息逻辑在 ToolExecutor 中处理了。
|
||||
logger.warning(
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户。"
|
||||
)
|
||||
self._transition_state(AgentState.DONE)
|
||||
self.stats.end_time = time.time()
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具没有返回值或者将结果直接发送给了用户*",
|
||||
),
|
||||
)
|
||||
else:
|
||||
# 不应该出现其他类型
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。",
|
||||
f"Tool 返回了不支持的类型: {type(resp)}。",
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="*工具返回了不支持的类型,请告诉用户检查这个工具的定义和实现。*",
|
||||
),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -387,6 +468,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
)
|
||||
|
||||
# yield the last tool call result
|
||||
if tool_call_result_blocks:
|
||||
last_tcr_content = str(tool_call_result_blocks[-1].content)
|
||||
yield MessageChain(
|
||||
type="tool_call_result",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"ts": time.time(),
|
||||
"result": last_tcr_content,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# 处理函数调用响应
|
||||
if tool_call_result_blocks:
|
||||
yield tool_call_result_blocks
|
||||
|
||||
@@ -6,8 +6,10 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
|
||||
@dataclass(config={"arbitrary_types_allowed": True})
|
||||
@dataclass
|
||||
class AstrAgentContext:
|
||||
__pydantic_config__ = {"arbitrary_types_allowed": True}
|
||||
|
||||
context: Context
|
||||
"""The star context instance"""
|
||||
event: AstrMessageEvent
|
||||
|
||||
@@ -13,6 +13,12 @@ from astrbot.core.star.star_handler import EventType
|
||||
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
async def on_agent_done(self, run_context, llm_response):
|
||||
# 执行事件钩子
|
||||
if llm_response and llm_response.reasoning_content:
|
||||
# we will use this in result_decorate stage to inject reasoning content to chain
|
||||
run_context.context.event.set_extra(
|
||||
"_llm_reasoning_content", llm_response.reasoning_content
|
||||
)
|
||||
|
||||
await call_event_hook(
|
||||
run_context.context.event,
|
||||
EventType.OnLLMResponseEvent,
|
||||
|
||||
@@ -2,8 +2,10 @@ import traceback
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
MessageEventResult,
|
||||
@@ -23,8 +25,25 @@ async def run_agent(
|
||||
) -> AsyncGenerator[MessageChain | None, None]:
|
||||
step_idx = 0
|
||||
astr_event = agent_runner.run_context.context.event
|
||||
while step_idx < max_step:
|
||||
while step_idx < max_step + 1:
|
||||
step_idx += 1
|
||||
|
||||
if step_idx == max_step + 1:
|
||||
logger.warning(
|
||||
f"Agent reached max steps ({max_step}), forcing a final response."
|
||||
)
|
||||
if not agent_runner.done():
|
||||
# 拔掉所有工具
|
||||
if agent_runner.req:
|
||||
agent_runner.req.func_tool = None
|
||||
# 注入提示词
|
||||
agent_runner.run_context.messages.append(
|
||||
Message(
|
||||
role="user",
|
||||
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
async for resp in agent_runner.step():
|
||||
if astr_event.is_stopped():
|
||||
@@ -33,16 +52,27 @@ async def run_agent(
|
||||
msg_chain = resp.data["chain"]
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
await astr_event.send(resp.data["chain"])
|
||||
await astr_event.send(msg_chain)
|
||||
continue
|
||||
if astr_event.get_platform_id() == "webchat":
|
||||
await astr_event.send(msg_chain)
|
||||
# 对于其他情况,暂时先不处理
|
||||
continue
|
||||
elif resp.type == "tool_call":
|
||||
if agent_runner.streaming:
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
if show_tool_use:
|
||||
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(resp.data["chain"])
|
||||
elif show_tool_use:
|
||||
json_comp = resp.data["chain"].chain[0]
|
||||
if isinstance(json_comp, Json):
|
||||
m = f"🔨 调用工具: {json_comp.data.get('name')}"
|
||||
else:
|
||||
m = "🔨 调用工具..."
|
||||
chain = MessageChain(type="tool_call").message(m)
|
||||
await astr_event.send(chain)
|
||||
continue
|
||||
|
||||
if stream_to_general and resp.type == "streaming_delta":
|
||||
@@ -69,6 +99,15 @@ async def run_agent(
|
||||
continue
|
||||
yield resp.data["chain"] # MessageChain
|
||||
if agent_runner.done():
|
||||
# send agent stats to webchat
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(
|
||||
MessageChain(
|
||||
type="agent_stats",
|
||||
chain=[Json(data=agent_runner.stats.to_dict())],
|
||||
)
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -209,12 +209,42 @@ async def call_local_llm_tool(
|
||||
else:
|
||||
raise ValueError(f"未知的方法名: {method_name}")
|
||||
except ValueError as e:
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
|
||||
except TypeError:
|
||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||
raise Exception(f"Tool execution ValueError: {e}") from e
|
||||
except TypeError as e:
|
||||
# 获取函数的签名(包括类型),除了第一个 event/context 参数。
|
||||
try:
|
||||
sig = inspect.signature(handler)
|
||||
params = list(sig.parameters.values())
|
||||
# 跳过第一个参数(event 或 context)
|
||||
if params:
|
||||
params = params[1:]
|
||||
|
||||
param_strs = []
|
||||
for param in params:
|
||||
param_str = param.name
|
||||
if param.annotation != inspect.Parameter.empty:
|
||||
# 获取类型注解的字符串表示
|
||||
if isinstance(param.annotation, type):
|
||||
type_str = param.annotation.__name__
|
||||
else:
|
||||
type_str = str(param.annotation)
|
||||
param_str += f": {type_str}"
|
||||
if param.default != inspect.Parameter.empty:
|
||||
param_str += f" = {param.default!r}"
|
||||
param_strs.append(param_str)
|
||||
|
||||
handler_param_str = (
|
||||
", ".join(param_strs) if param_strs else "(no additional parameters)"
|
||||
)
|
||||
except Exception:
|
||||
handler_param_str = "(unable to inspect signature)"
|
||||
|
||||
raise Exception(
|
||||
f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
trace_ = traceback.format_exc()
|
||||
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
|
||||
raise Exception(f"Tool execution error: {e}. Traceback: {trace_}") from e
|
||||
|
||||
if not ready_to_call:
|
||||
return
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""AstrBot 备份与恢复模块
|
||||
|
||||
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。
|
||||
"""
|
||||
|
||||
# 从 constants 模块导入共享常量
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
# 导入导出器和导入器
|
||||
from .exporter import AstrBotExporter
|
||||
from .importer import AstrBotImporter, ImportPreCheckResult
|
||||
|
||||
__all__ = [
|
||||
"AstrBotExporter",
|
||||
"AstrBotImporter",
|
||||
"ImportPreCheckResult",
|
||||
"MAIN_DB_MODELS",
|
||||
"KB_METADATA_MODELS",
|
||||
"get_backup_directories",
|
||||
"BACKUP_MANIFEST_VERSION",
|
||||
]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""AstrBot 备份模块共享常量
|
||||
|
||||
此文件定义了导出器和导入器共享的常量,确保两端配置一致。
|
||||
"""
|
||||
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
PlatformSession,
|
||||
PlatformStat,
|
||||
Preference,
|
||||
)
|
||||
from astrbot.core.knowledge_base.models import (
|
||||
KBDocument,
|
||||
KBMedia,
|
||||
KnowledgeBase,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_plugin_data_path,
|
||||
get_astrbot_plugin_path,
|
||||
get_astrbot_t2i_templates_path,
|
||||
get_astrbot_temp_path,
|
||||
get_astrbot_webchat_path,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# 共享常量 - 确保导出和导入端配置一致
|
||||
# ============================================================
|
||||
|
||||
# 主数据库模型类映射
|
||||
MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
|
||||
"platform_stats": PlatformStat,
|
||||
"conversations": ConversationV2,
|
||||
"personas": Persona,
|
||||
"preferences": Preference,
|
||||
"platform_message_history": PlatformMessageHistory,
|
||||
"platform_sessions": PlatformSession,
|
||||
"attachments": Attachment,
|
||||
"command_configs": CommandConfig,
|
||||
"command_conflicts": CommandConflict,
|
||||
}
|
||||
|
||||
# 知识库元数据模型类映射
|
||||
KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
|
||||
"knowledge_bases": KnowledgeBase,
|
||||
"kb_documents": KBDocument,
|
||||
"kb_media": KBMedia,
|
||||
}
|
||||
|
||||
|
||||
def get_backup_directories() -> dict[str, str]:
|
||||
"""获取需要备份的目录列表
|
||||
|
||||
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。
|
||||
|
||||
Returns:
|
||||
dict: 键为备份文件中的目录名称,值为目录的绝对路径
|
||||
"""
|
||||
return {
|
||||
"plugins": get_astrbot_plugin_path(), # 插件本体
|
||||
"plugin_data": get_astrbot_plugin_data_path(), # 插件数据
|
||||
"config": get_astrbot_config_path(), # 配置目录
|
||||
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
|
||||
"webchat": get_astrbot_webchat_path(), # WebChat 数据
|
||||
"temp": get_astrbot_temp_path(), # 临时文件
|
||||
}
|
||||
|
||||
|
||||
# 备份清单版本号
|
||||
BACKUP_MANIFEST_VERSION = "1.1"
|
||||
@@ -0,0 +1,477 @@
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
负责将所有数据导出为 ZIP 备份文件。
|
||||
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zipfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_backups_path,
|
||||
get_astrbot_data_path,
|
||||
)
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
BACKUP_MANIFEST_VERSION,
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
|
||||
|
||||
class AstrBotExporter:
|
||||
"""AstrBot 数据导出器
|
||||
|
||||
导出内容:
|
||||
- 主数据库所有表(data/data_v4.db)
|
||||
- 知识库元数据(data/knowledge_base/kb.db)
|
||||
- 每个知识库的向量文档数据
|
||||
- 配置文件(data/cmd_config.json)
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self._checksums: dict[str, str] = {}
|
||||
|
||||
async def export_all(
|
||||
self,
|
||||
output_dir: str | None = None,
|
||||
progress_callback: Any | None = None,
|
||||
) -> str:
|
||||
"""导出所有数据到 ZIP 文件
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
str: 生成的 ZIP 文件路径
|
||||
"""
|
||||
if output_dir is None:
|
||||
output_dir = get_astrbot_backups_path()
|
||||
|
||||
# 确保输出目录存在
|
||||
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
zip_filename = f"astrbot_backup_{timestamp}.zip"
|
||||
zip_path = os.path.join(output_dir, zip_filename)
|
||||
|
||||
logger.info(f"开始导出备份到 {zip_path}")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
|
||||
# 1. 导出主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
|
||||
main_data = await self._export_main_database()
|
||||
main_db_json = json.dumps(
|
||||
main_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/main_db.json", main_db_json)
|
||||
self._add_checksum("databases/main_db.json", main_db_json)
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导出完成")
|
||||
|
||||
# 2. 导出知识库数据
|
||||
kb_meta_data: dict[str, Any] = {
|
||||
"knowledge_bases": [],
|
||||
"kb_documents": [],
|
||||
"kb_media": [],
|
||||
}
|
||||
if self.kb_manager:
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 0, 100, "正在导出知识库元数据..."
|
||||
)
|
||||
kb_meta_data = await self._export_kb_metadata()
|
||||
kb_meta_json = json.dumps(
|
||||
kb_meta_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
zf.writestr("databases/kb_metadata.json", kb_meta_json)
|
||||
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_metadata", 100, 100, "知识库元数据导出完成"
|
||||
)
|
||||
|
||||
# 导出每个知识库的文档数据
|
||||
kb_insts = self.kb_manager.kb_insts
|
||||
total_kbs = len(kb_insts)
|
||||
for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()):
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents",
|
||||
idx,
|
||||
total_kbs,
|
||||
f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...",
|
||||
)
|
||||
doc_data = await self._export_kb_documents(kb_helper)
|
||||
doc_json = json.dumps(
|
||||
doc_data, ensure_ascii=False, indent=2, default=str
|
||||
)
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
zf.writestr(doc_path, doc_json)
|
||||
self._add_checksum(doc_path, doc_json)
|
||||
|
||||
# 导出 FAISS 索引文件
|
||||
await self._export_faiss_index(zf, kb_helper, kb_id)
|
||||
|
||||
# 导出知识库多媒体文件
|
||||
await self._export_kb_media_files(zf, kb_helper, kb_id)
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
|
||||
)
|
||||
|
||||
# 3. 导出配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导出配置文件...")
|
||||
if os.path.exists(self.config_path):
|
||||
with open(self.config_path, encoding="utf-8") as f:
|
||||
config_content = f.read()
|
||||
zf.writestr("config/cmd_config.json", config_content)
|
||||
self._add_checksum("config/cmd_config.json", config_content)
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导出完成")
|
||||
|
||||
# 4. 导出附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导出附件...")
|
||||
await self._export_attachments(zf, main_data.get("attachments", []))
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导出完成")
|
||||
|
||||
# 5. 导出插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导出插件和数据目录..."
|
||||
)
|
||||
dir_stats = await self._export_directories(zf)
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导出完成")
|
||||
|
||||
# 6. 生成 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 0, 100, "正在生成清单...")
|
||||
manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats)
|
||||
manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2)
|
||||
zf.writestr("manifest.json", manifest_json)
|
||||
if progress_callback:
|
||||
await progress_callback("manifest", 100, 100, "清单生成完成")
|
||||
|
||||
logger.info(f"备份导出完成: {zip_path}")
|
||||
return zip_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"备份导出失败: {e}")
|
||||
# 清理失败的文件
|
||||
if os.path.exists(zip_path):
|
||||
os.remove(zip_path)
|
||||
raise
|
||||
|
||||
async def _export_main_database(self) -> dict[str, list[dict]]:
|
||||
"""导出主数据库所有表"""
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_metadata(self) -> dict[str, list[dict]]:
|
||||
"""导出知识库元数据库"""
|
||||
if not self.kb_manager:
|
||||
return {"knowledge_bases": [], "kb_documents": [], "kb_media": []}
|
||||
|
||||
export_data: dict[str, list[dict]] = {}
|
||||
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
result = await session.execute(select(model_class))
|
||||
records = result.scalars().all()
|
||||
export_data[table_name] = [
|
||||
self._model_to_dict(record) for record in records
|
||||
]
|
||||
logger.debug(
|
||||
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
|
||||
export_data[table_name] = []
|
||||
|
||||
return export_data
|
||||
|
||||
async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]:
|
||||
"""导出知识库的文档块数据"""
|
||||
try:
|
||||
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
|
||||
|
||||
vec_db: FaissVecDB = kb_helper.vec_db
|
||||
if not vec_db or not vec_db.document_storage:
|
||||
return {"documents": []}
|
||||
|
||||
# 获取所有文档
|
||||
docs = await vec_db.document_storage.get_documents(
|
||||
metadata_filters={},
|
||||
offset=0,
|
||||
limit=None, # 获取全部
|
||||
)
|
||||
|
||||
return {"documents": docs}
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库文档失败: {e}")
|
||||
return {"documents": []}
|
||||
|
||||
async def _export_faiss_index(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_helper: Any,
|
||||
kb_id: str,
|
||||
) -> None:
|
||||
"""导出 FAISS 索引文件"""
|
||||
try:
|
||||
index_path = kb_helper.kb_dir / "index.faiss"
|
||||
if index_path.exists():
|
||||
archive_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
zf.write(str(index_path), archive_path)
|
||||
logger.debug(f"导出 FAISS 索引: {archive_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"导出 FAISS 索引失败: {e}")
|
||||
|
||||
async def _export_kb_media_files(
|
||||
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
|
||||
) -> None:
|
||||
"""导出知识库的多媒体文件"""
|
||||
try:
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if not media_dir.exists():
|
||||
return
|
||||
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
file_path = Path(root) / file
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(kb_helper.kb_dir)
|
||||
archive_path = f"files/kb_media/{kb_id}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出知识库媒体文件失败: {e}")
|
||||
|
||||
async def _export_directories(
|
||||
self, zf: zipfile.ZipFile
|
||||
) -> dict[str, dict[str, int]]:
|
||||
"""导出插件和其他数据目录
|
||||
|
||||
Returns:
|
||||
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
|
||||
"""
|
||||
stats: dict[str, dict[str, int]] = {}
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name, dir_path in backup_directories.items():
|
||||
full_path = Path(dir_path)
|
||||
if not full_path.exists():
|
||||
logger.debug(f"目录不存在,跳过: {full_path}")
|
||||
continue
|
||||
|
||||
file_count = 0
|
||||
total_size = 0
|
||||
|
||||
try:
|
||||
for root, dirs, files in os.walk(full_path):
|
||||
# 跳过 __pycache__ 目录
|
||||
dirs[:] = [d for d in dirs if d != "__pycache__"]
|
||||
|
||||
for file in files:
|
||||
# 跳过 .pyc 文件
|
||||
if file.endswith(".pyc"):
|
||||
continue
|
||||
|
||||
file_path = Path(root) / file
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = file_path.relative_to(full_path)
|
||||
archive_path = f"directories/{dir_name}/{rel_path}"
|
||||
zf.write(str(file_path), archive_path)
|
||||
file_count += 1
|
||||
total_size += file_path.stat().st_size
|
||||
except Exception as e:
|
||||
logger.warning(f"导出文件 {file_path} 失败: {e}")
|
||||
|
||||
stats[dir_name] = {"files": file_count, "size": total_size}
|
||||
logger.debug(
|
||||
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出目录 {dir_path} 失败: {e}")
|
||||
stats[dir_name] = {"files": 0, "size": 0}
|
||||
|
||||
return stats
|
||||
|
||||
async def _export_attachments(
|
||||
self, zf: zipfile.ZipFile, attachments: list[dict]
|
||||
) -> None:
|
||||
"""导出附件文件"""
|
||||
for attachment in attachments:
|
||||
try:
|
||||
file_path = attachment.get("path", "")
|
||||
if file_path and os.path.exists(file_path):
|
||||
# 使用 attachment_id 作为文件名
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
ext = os.path.splitext(file_path)[1]
|
||||
archive_path = f"files/attachments/{attachment_id}{ext}"
|
||||
zf.write(file_path, archive_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"导出附件失败: {e}")
|
||||
|
||||
def _model_to_dict(self, record: Any) -> dict:
|
||||
"""将 SQLModel 实例转换为字典
|
||||
|
||||
这是数据库无关的序列化方式,支持未来迁移到其他数据库。
|
||||
"""
|
||||
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
|
||||
if hasattr(record, "model_dump"):
|
||||
data = record.model_dump(mode="python")
|
||||
# 处理 datetime 类型
|
||||
for key, value in data.items():
|
||||
if isinstance(value, datetime):
|
||||
data[key] = value.isoformat()
|
||||
return data
|
||||
|
||||
# 回退到手动提取
|
||||
data = {}
|
||||
# 使用 inspect 获取表信息
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
mapper = sa_inspect(record.__class__)
|
||||
for column in mapper.columns:
|
||||
value = getattr(record, column.name)
|
||||
# 处理 datetime 类型 - 统一转为 ISO 格式字符串
|
||||
if isinstance(value, datetime):
|
||||
value = value.isoformat()
|
||||
data[column.name] = value
|
||||
return data
|
||||
|
||||
def _add_checksum(self, path: str, content: str | bytes) -> None:
|
||||
"""计算并添加文件校验和"""
|
||||
if isinstance(content, str):
|
||||
content = content.encode("utf-8")
|
||||
checksum = hashlib.sha256(content).hexdigest()
|
||||
self._checksums[path] = f"sha256:{checksum}"
|
||||
|
||||
def _generate_manifest(
|
||||
self,
|
||||
main_data: dict[str, list[dict]],
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
dir_stats: dict[str, dict[str, int]] | None = None,
|
||||
) -> dict:
|
||||
"""生成备份清单"""
|
||||
if dir_stats is None:
|
||||
dir_stats = {}
|
||||
# 收集知识库 ID
|
||||
kb_document_tables = {}
|
||||
if self.kb_manager:
|
||||
for kb_id in self.kb_manager.kb_insts.keys():
|
||||
kb_document_tables[kb_id] = "documents"
|
||||
|
||||
# 收集附件文件列表
|
||||
attachment_files = []
|
||||
for attachment in main_data.get("attachments", []):
|
||||
attachment_id = attachment.get("attachment_id", "")
|
||||
path = attachment.get("path", "")
|
||||
if attachment_id and path:
|
||||
ext = os.path.splitext(path)[1]
|
||||
attachment_files.append(f"{attachment_id}{ext}")
|
||||
|
||||
# 收集知识库媒体文件
|
||||
kb_media_files: dict[str, list[str]] = {}
|
||||
if self.kb_manager:
|
||||
for kb_id, kb_helper in self.kb_manager.kb_insts.items():
|
||||
media_files: list[str] = []
|
||||
media_dir = kb_helper.kb_medias_dir
|
||||
if media_dir.exists():
|
||||
for root, _, files in os.walk(media_dir):
|
||||
for file in files:
|
||||
media_files.append(file)
|
||||
if media_files:
|
||||
kb_media_files[kb_id] = media_files
|
||||
|
||||
manifest = {
|
||||
"version": BACKUP_MANIFEST_VERSION,
|
||||
"astrbot_version": VERSION,
|
||||
"exported_at": datetime.now(timezone.utc).isoformat(),
|
||||
"origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传
|
||||
"schema_version": {
|
||||
"main_db": "v4",
|
||||
"kb_db": "v1",
|
||||
},
|
||||
"tables": {
|
||||
"main_db": list(main_data.keys()),
|
||||
"kb_metadata": list(kb_meta_data.keys()),
|
||||
"kb_documents": kb_document_tables,
|
||||
},
|
||||
"files": {
|
||||
"attachments": attachment_files,
|
||||
"kb_media": kb_media_files,
|
||||
},
|
||||
"directories": list(dir_stats.keys()),
|
||||
"checksums": self._checksums,
|
||||
"statistics": {
|
||||
"main_db": {
|
||||
table: len(records) for table, records in main_data.items()
|
||||
},
|
||||
"kb_metadata": {
|
||||
table: len(records) for table, records in kb_meta_data.items()
|
||||
},
|
||||
"directories": dir_stats,
|
||||
},
|
||||
}
|
||||
|
||||
return manifest
|
||||
@@ -0,0 +1,761 @@
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
负责从 ZIP 备份文件恢复所有数据。
|
||||
导入时进行版本校验:
|
||||
- 主版本(前两位)不同时直接拒绝导入
|
||||
- 小版本(第三位)不同时提示警告,用户可选择强制导入
|
||||
- 版本匹配时也需要用户确认
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import delete
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_data_path,
|
||||
get_astrbot_knowledge_base_path,
|
||||
)
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
|
||||
# 从共享常量模块导入
|
||||
from .constants import (
|
||||
KB_METADATA_MODELS,
|
||||
MAIN_DB_MODELS,
|
||||
get_backup_directories,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
|
||||
|
||||
|
||||
def _get_major_version(version_str: str) -> str:
|
||||
"""提取版本的主版本部分(前两位)
|
||||
|
||||
Args:
|
||||
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
|
||||
|
||||
Returns:
|
||||
主版本字符串,如 "4.9", "4.10"
|
||||
"""
|
||||
if not version_str:
|
||||
return "0.0"
|
||||
# 移除 v 前缀和预发布标签
|
||||
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
|
||||
parts = [p for p in version.split(".") if p] # 过滤空字符串
|
||||
if len(parts) >= 2:
|
||||
return f"{parts[0]}.{parts[1]}"
|
||||
elif len(parts) == 1 and parts[0]:
|
||||
return f"{parts[0]}.0"
|
||||
return "0.0"
|
||||
|
||||
|
||||
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
KB_PATH = get_astrbot_knowledge_base_path()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImportPreCheckResult:
|
||||
"""导入预检查结果
|
||||
|
||||
用于在实际导入前检查备份文件的版本兼容性,
|
||||
并返回确认信息让用户决定是否继续导入。
|
||||
"""
|
||||
|
||||
# 检查是否通过(文件有效且版本可导入)
|
||||
valid: bool = False
|
||||
# 是否可以导入(版本兼容)
|
||||
can_import: bool = False
|
||||
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
|
||||
version_status: str = ""
|
||||
# 备份文件中的 AstrBot 版本
|
||||
backup_version: str = ""
|
||||
# 当前运行的 AstrBot 版本
|
||||
current_version: str = VERSION
|
||||
# 备份创建时间
|
||||
backup_time: str = ""
|
||||
# 确认消息(显示给用户)
|
||||
confirm_message: str = ""
|
||||
# 警告消息列表
|
||||
warnings: list[str] = field(default_factory=list)
|
||||
# 错误消息(如果检查失败)
|
||||
error: str = ""
|
||||
# 备份包含的内容摘要
|
||||
backup_summary: dict = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"valid": self.valid,
|
||||
"can_import": self.can_import,
|
||||
"version_status": self.version_status,
|
||||
"backup_version": self.backup_version,
|
||||
"current_version": self.current_version,
|
||||
"backup_time": self.backup_time,
|
||||
"confirm_message": self.confirm_message,
|
||||
"warnings": self.warnings,
|
||||
"error": self.error,
|
||||
"backup_summary": self.backup_summary,
|
||||
}
|
||||
|
||||
|
||||
class ImportResult:
|
||||
"""导入结果"""
|
||||
|
||||
def __init__(self):
|
||||
self.success = True
|
||||
self.imported_tables: dict[str, int] = {}
|
||||
self.imported_files: dict[str, int] = {}
|
||||
self.imported_directories: dict[str, int] = {}
|
||||
self.warnings: list[str] = []
|
||||
self.errors: list[str] = []
|
||||
|
||||
def add_warning(self, msg: str) -> None:
|
||||
self.warnings.append(msg)
|
||||
logger.warning(msg)
|
||||
|
||||
def add_error(self, msg: str) -> None:
|
||||
self.errors.append(msg)
|
||||
self.success = False
|
||||
logger.error(msg)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"success": self.success,
|
||||
"imported_tables": self.imported_tables,
|
||||
"imported_files": self.imported_files,
|
||||
"imported_directories": self.imported_directories,
|
||||
"warnings": self.warnings,
|
||||
"errors": self.errors,
|
||||
}
|
||||
|
||||
|
||||
class AstrBotImporter:
|
||||
"""AstrBot 数据导入器
|
||||
|
||||
导入备份文件中的所有数据,包括:
|
||||
- 主数据库所有表
|
||||
- 知识库元数据和文档
|
||||
- 配置文件
|
||||
- 附件文件
|
||||
- 知识库多媒体文件
|
||||
- 插件目录(data/plugins)
|
||||
- 插件数据目录(data/plugin_data)
|
||||
- 配置目录(data/config)
|
||||
- T2I 模板目录(data/t2i_templates)
|
||||
- WebChat 数据目录(data/webchat)
|
||||
- 临时文件目录(data/temp)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
main_db: BaseDatabase,
|
||||
kb_manager: "KnowledgeBaseManager | None" = None,
|
||||
config_path: str = CMD_CONFIG_FILE_PATH,
|
||||
kb_root_dir: str = KB_PATH,
|
||||
):
|
||||
self.main_db = main_db
|
||||
self.kb_manager = kb_manager
|
||||
self.config_path = config_path
|
||||
self.kb_root_dir = kb_root_dir
|
||||
|
||||
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
|
||||
"""预检查备份文件
|
||||
|
||||
在实际导入前检查备份文件的有效性和版本兼容性。
|
||||
返回检查结果供前端显示确认对话框。
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
|
||||
Returns:
|
||||
ImportPreCheckResult: 预检查结果
|
||||
"""
|
||||
result = ImportPreCheckResult()
|
||||
result.current_version = VERSION
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.error = f"备份文件不存在: {zip_path}"
|
||||
return result
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 读取 manifest
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.error = f"manifest.json 格式错误: {e}"
|
||||
return result
|
||||
|
||||
# 提取基本信息
|
||||
result.backup_version = manifest.get("astrbot_version", "未知")
|
||||
result.backup_time = manifest.get("exported_at", "未知")
|
||||
result.valid = True
|
||||
|
||||
# 构建备份摘要
|
||||
result.backup_summary = {
|
||||
"tables": list(manifest.get("tables", {}).keys()),
|
||||
"has_knowledge_bases": manifest.get("has_knowledge_bases", False),
|
||||
"has_config": manifest.get("has_config", False),
|
||||
"directories": manifest.get("directories", []),
|
||||
}
|
||||
|
||||
# 检查版本兼容性
|
||||
version_check = self._check_version_compatibility(result.backup_version)
|
||||
result.version_status = version_check["status"]
|
||||
result.can_import = version_check["can_import"]
|
||||
|
||||
# 版本信息由前端根据 version_status 和 i18n 生成显示
|
||||
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
|
||||
# warnings 列表保留用于其他非版本相关的警告
|
||||
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.error = "无效的 ZIP 文件"
|
||||
return result
|
||||
except Exception as e:
|
||||
result.error = f"检查备份文件失败: {e}"
|
||||
return result
|
||||
|
||||
def _check_version_compatibility(self, backup_version: str) -> dict:
|
||||
"""检查版本兼容性
|
||||
|
||||
规则:
|
||||
- 主版本(前两位,如 4.9)必须一致,否则拒绝
|
||||
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
|
||||
|
||||
Returns:
|
||||
dict: {status, can_import, message}
|
||||
"""
|
||||
if not backup_version:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": "备份文件缺少版本信息",
|
||||
}
|
||||
|
||||
# 提取主版本(前两位)进行比较
|
||||
backup_major = _get_major_version(backup_version)
|
||||
current_major = _get_major_version(VERSION)
|
||||
|
||||
# 比较主版本
|
||||
if VersionComparator.compare_version(backup_major, current_major) != 0:
|
||||
return {
|
||||
"status": "major_diff",
|
||||
"can_import": False,
|
||||
"message": (
|
||||
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
|
||||
),
|
||||
}
|
||||
|
||||
# 比较完整版本
|
||||
version_cmp = VersionComparator.compare_version(backup_version, VERSION)
|
||||
if version_cmp != 0:
|
||||
return {
|
||||
"status": "minor_diff",
|
||||
"can_import": True,
|
||||
"message": (
|
||||
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}。"
|
||||
),
|
||||
}
|
||||
|
||||
return {
|
||||
"status": "match",
|
||||
"can_import": True,
|
||||
"message": "版本匹配",
|
||||
}
|
||||
|
||||
async def import_all(
|
||||
self,
|
||||
zip_path: str,
|
||||
mode: str = "replace", # "replace" 清空后导入
|
||||
progress_callback: Any | None = None,
|
||||
) -> ImportResult:
|
||||
"""从 ZIP 文件导入所有数据
|
||||
|
||||
Args:
|
||||
zip_path: ZIP 备份文件路径
|
||||
mode: 导入模式,目前仅支持 "replace"(清空后导入)
|
||||
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
|
||||
|
||||
Returns:
|
||||
ImportResult: 导入结果
|
||||
"""
|
||||
result = ImportResult()
|
||||
|
||||
if not os.path.exists(zip_path):
|
||||
result.add_error(f"备份文件不存在: {zip_path}")
|
||||
return result
|
||||
|
||||
logger.info(f"开始从 {zip_path} 导入备份")
|
||||
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zf:
|
||||
# 1. 读取并验证 manifest
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 0, 100, "正在验证备份文件...")
|
||||
|
||||
try:
|
||||
manifest_data = zf.read("manifest.json")
|
||||
manifest = json.loads(manifest_data)
|
||||
except KeyError:
|
||||
result.add_error("备份文件缺少 manifest.json")
|
||||
return result
|
||||
except json.JSONDecodeError as e:
|
||||
result.add_error(f"manifest.json 格式错误: {e}")
|
||||
return result
|
||||
|
||||
# 版本校验
|
||||
try:
|
||||
self._validate_version(manifest)
|
||||
except ValueError as e:
|
||||
result.add_error(str(e))
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("validate", 100, 100, "验证完成")
|
||||
|
||||
# 2. 导入主数据库
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 0, 100, "正在导入主数据库...")
|
||||
|
||||
try:
|
||||
main_data_content = zf.read("databases/main_db.json")
|
||||
main_data = json.loads(main_data_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_main_db()
|
||||
|
||||
imported = await self._import_main_database(main_data)
|
||||
result.imported_tables.update(imported)
|
||||
except Exception as e:
|
||||
result.add_error(f"导入主数据库失败: {e}")
|
||||
return result
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("main_db", 100, 100, "主数据库导入完成")
|
||||
|
||||
# 3. 导入知识库
|
||||
if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 0, 100, "正在导入知识库...")
|
||||
|
||||
try:
|
||||
kb_meta_content = zf.read("databases/kb_metadata.json")
|
||||
kb_meta_data = json.loads(kb_meta_content)
|
||||
|
||||
if mode == "replace":
|
||||
await self._clear_kb_data()
|
||||
|
||||
await self._import_knowledge_bases(zf, kb_meta_data, result)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("kb", 100, 100, "知识库导入完成")
|
||||
|
||||
# 4. 导入配置文件
|
||||
if progress_callback:
|
||||
await progress_callback("config", 0, 100, "正在导入配置文件...")
|
||||
|
||||
if "config/cmd_config.json" in zf.namelist():
|
||||
try:
|
||||
config_content = zf.read("config/cmd_config.json")
|
||||
# 备份现有配置
|
||||
if os.path.exists(self.config_path):
|
||||
backup_path = f"{self.config_path}.bak"
|
||||
shutil.copy2(self.config_path, backup_path)
|
||||
|
||||
with open(self.config_path, "wb") as f:
|
||||
f.write(config_content)
|
||||
result.imported_files["config"] = 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入配置文件失败: {e}")
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("config", 100, 100, "配置文件导入完成")
|
||||
|
||||
# 5. 导入附件文件
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 0, 100, "正在导入附件...")
|
||||
|
||||
attachment_count = await self._import_attachments(
|
||||
zf, main_data.get("attachments", [])
|
||||
)
|
||||
result.imported_files["attachments"] = attachment_count
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("attachments", 100, 100, "附件导入完成")
|
||||
|
||||
# 6. 导入插件和其他目录
|
||||
if progress_callback:
|
||||
await progress_callback(
|
||||
"directories", 0, 100, "正在导入插件和数据目录..."
|
||||
)
|
||||
|
||||
dir_stats = await self._import_directories(zf, manifest, result)
|
||||
result.imported_directories = dir_stats
|
||||
|
||||
if progress_callback:
|
||||
await progress_callback("directories", 100, 100, "目录导入完成")
|
||||
|
||||
logger.info(f"备份导入完成: {result.to_dict()}")
|
||||
return result
|
||||
|
||||
except zipfile.BadZipFile:
|
||||
result.add_error("无效的 ZIP 文件")
|
||||
return result
|
||||
except Exception as e:
|
||||
result.add_error(f"导入失败: {e}")
|
||||
return result
|
||||
|
||||
def _validate_version(self, manifest: dict) -> None:
|
||||
"""验证版本兼容性 - 仅允许相同主版本导入
|
||||
|
||||
注意:此方法仅在 import_all 中调用,用于双重校验。
|
||||
前端应先调用 pre_check 获取详细的版本信息并让用户确认。
|
||||
"""
|
||||
backup_version = manifest.get("astrbot_version")
|
||||
if not backup_version:
|
||||
raise ValueError("备份文件缺少版本信息")
|
||||
|
||||
# 使用新的版本兼容性检查
|
||||
version_check = self._check_version_compatibility(backup_version)
|
||||
|
||||
if version_check["status"] == "major_diff":
|
||||
raise ValueError(version_check["message"])
|
||||
|
||||
# minor_diff 和 match 都允许导入
|
||||
if version_check["status"] == "minor_diff":
|
||||
logger.warning(f"版本差异警告: {version_check['message']}")
|
||||
|
||||
async def _clear_main_db(self) -> None:
|
||||
"""清空主数据库所有表"""
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in MAIN_DB_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空表 {table_name} 失败: {e}")
|
||||
|
||||
async def _clear_kb_data(self) -> None:
|
||||
"""清空知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 清空知识库元数据表
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, model_class in KB_METADATA_MODELS.items():
|
||||
try:
|
||||
await session.execute(delete(model_class))
|
||||
logger.debug(f"已清空知识库表 {table_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
|
||||
|
||||
# 删除知识库文件目录
|
||||
for kb_id in list(self.kb_manager.kb_insts.keys()):
|
||||
try:
|
||||
kb_helper = self.kb_manager.kb_insts[kb_id]
|
||||
await kb_helper.terminate()
|
||||
if kb_helper.kb_dir.exists():
|
||||
shutil.rmtree(kb_helper.kb_dir)
|
||||
except Exception as e:
|
||||
logger.warning(f"清理知识库 {kb_id} 失败: {e}")
|
||||
|
||||
self.kb_manager.kb_insts.clear()
|
||||
|
||||
async def _import_main_database(
|
||||
self, data: dict[str, list[dict]]
|
||||
) -> dict[str, int]:
|
||||
"""导入主数据库数据"""
|
||||
imported: dict[str, int] = {}
|
||||
|
||||
async with self.main_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in data.items():
|
||||
model_class = MAIN_DB_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
logger.warning(f"未知的表: {table_name}")
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
# 转换 datetime 字符串为 datetime 对象
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入记录到 {table_name} 失败: {e}")
|
||||
|
||||
imported[table_name] = count
|
||||
logger.debug(f"导入表 {table_name}: {count} 条记录")
|
||||
|
||||
return imported
|
||||
|
||||
async def _import_knowledge_bases(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
kb_meta_data: dict[str, list[dict]],
|
||||
result: ImportResult,
|
||||
) -> None:
|
||||
"""导入知识库数据"""
|
||||
if not self.kb_manager:
|
||||
return
|
||||
|
||||
# 1. 导入知识库元数据
|
||||
async with self.kb_manager.kb_db.get_db() as session:
|
||||
async with session.begin():
|
||||
for table_name, rows in kb_meta_data.items():
|
||||
model_class = KB_METADATA_MODELS.get(table_name)
|
||||
if not model_class:
|
||||
continue
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
row = self._convert_datetime_fields(row, model_class)
|
||||
obj = model_class(**row)
|
||||
session.add(obj)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
|
||||
|
||||
result.imported_tables[f"kb_{table_name}"] = count
|
||||
|
||||
# 2. 导入每个知识库的文档和文件
|
||||
for kb_data in kb_meta_data.get("knowledge_bases", []):
|
||||
kb_id = kb_data.get("kb_id")
|
||||
if not kb_id:
|
||||
continue
|
||||
|
||||
# 创建知识库目录
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
kb_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 导入文档数据
|
||||
doc_path = f"databases/kb_{kb_id}/documents.json"
|
||||
if doc_path in zf.namelist():
|
||||
try:
|
||||
doc_content = zf.read(doc_path)
|
||||
doc_data = json.loads(doc_content)
|
||||
|
||||
# 导入到文档存储数据库
|
||||
await self._import_kb_documents(kb_id, doc_data)
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")
|
||||
|
||||
# 导入 FAISS 索引
|
||||
faiss_path = f"databases/kb_{kb_id}/index.faiss"
|
||||
if faiss_path in zf.namelist():
|
||||
try:
|
||||
target_path = kb_dir / "index.faiss"
|
||||
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
|
||||
|
||||
# 导入媒体文件
|
||||
media_prefix = f"files/kb_media/{kb_id}/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(media_prefix):
|
||||
try:
|
||||
rel_path = name[len(media_prefix) :]
|
||||
target_path = kb_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
|
||||
|
||||
# 3. 重新加载知识库实例
|
||||
await self.kb_manager.load_kbs()
|
||||
|
||||
async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
|
||||
"""导入知识库文档到向量数据库"""
|
||||
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
|
||||
|
||||
kb_dir = Path(self.kb_root_dir) / kb_id
|
||||
doc_db_path = kb_dir / "doc.db"
|
||||
|
||||
# 初始化文档存储
|
||||
doc_storage = DocumentStorage(str(doc_db_path))
|
||||
await doc_storage.initialize()
|
||||
|
||||
try:
|
||||
documents = doc_data.get("documents", [])
|
||||
for doc in documents:
|
||||
try:
|
||||
await doc_storage.insert_document(
|
||||
doc_id=doc.get("doc_id", ""),
|
||||
text=doc.get("text", ""),
|
||||
metadata=json.loads(doc.get("metadata", "{}")),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"导入文档块失败: {e}")
|
||||
finally:
|
||||
await doc_storage.close()
|
||||
|
||||
async def _import_attachments(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
attachments: list[dict],
|
||||
) -> int:
|
||||
"""导入附件文件"""
|
||||
count = 0
|
||||
|
||||
attachments_dir = Path(self.config_path).parent / "attachments"
|
||||
attachments_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
attachment_prefix = "files/attachments/"
|
||||
for name in zf.namelist():
|
||||
if name.startswith(attachment_prefix) and name != attachment_prefix:
|
||||
try:
|
||||
# 从附件记录中找到原始路径
|
||||
attachment_id = os.path.splitext(os.path.basename(name))[0]
|
||||
original_path = None
|
||||
for att in attachments:
|
||||
if att.get("attachment_id") == attachment_id:
|
||||
original_path = att.get("path")
|
||||
break
|
||||
|
||||
if original_path:
|
||||
target_path = Path(original_path)
|
||||
else:
|
||||
target_path = attachments_dir / os.path.basename(name)
|
||||
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"导入附件 {name} 失败: {e}")
|
||||
|
||||
return count
|
||||
|
||||
async def _import_directories(
|
||||
self,
|
||||
zf: zipfile.ZipFile,
|
||||
manifest: dict,
|
||||
result: ImportResult,
|
||||
) -> dict[str, int]:
|
||||
"""导入插件和其他数据目录
|
||||
|
||||
Args:
|
||||
zf: ZIP 文件对象
|
||||
manifest: 备份清单
|
||||
result: 导入结果对象
|
||||
|
||||
Returns:
|
||||
dict: 每个目录导入的文件数量
|
||||
"""
|
||||
dir_stats: dict[str, int] = {}
|
||||
|
||||
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
|
||||
backup_version = manifest.get("version", "1.0")
|
||||
if VersionComparator.compare_version(backup_version, "1.1") < 0:
|
||||
logger.info("备份版本不支持目录备份,跳过目录导入")
|
||||
return dir_stats
|
||||
|
||||
backed_up_dirs = manifest.get("directories", [])
|
||||
backup_directories = get_backup_directories()
|
||||
|
||||
for dir_name in backed_up_dirs:
|
||||
if dir_name not in backup_directories:
|
||||
result.add_warning(f"未知的目录类型: {dir_name}")
|
||||
continue
|
||||
|
||||
target_dir = Path(backup_directories[dir_name])
|
||||
archive_prefix = f"directories/{dir_name}/"
|
||||
|
||||
file_count = 0
|
||||
|
||||
try:
|
||||
# 获取该目录下的所有文件
|
||||
dir_files = [
|
||||
name
|
||||
for name in zf.namelist()
|
||||
if name.startswith(archive_prefix) and name != archive_prefix
|
||||
]
|
||||
|
||||
if not dir_files:
|
||||
continue
|
||||
|
||||
# 备份现有目录(如果存在)
|
||||
if target_dir.exists():
|
||||
backup_path = Path(f"{target_dir}.bak")
|
||||
if backup_path.exists():
|
||||
shutil.rmtree(backup_path)
|
||||
shutil.move(str(target_dir), str(backup_path))
|
||||
logger.debug(f"已备份现有目录 {target_dir} 到 {backup_path}")
|
||||
|
||||
# 创建目标目录
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 解压文件
|
||||
for name in dir_files:
|
||||
try:
|
||||
# 计算相对路径
|
||||
rel_path = name[len(archive_prefix) :]
|
||||
if not rel_path: # 跳过目录条目
|
||||
continue
|
||||
|
||||
target_path = target_dir / rel_path
|
||||
target_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with zf.open(name) as src, open(target_path, "wb") as dst:
|
||||
dst.write(src.read())
|
||||
file_count += 1
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入文件 {name} 失败: {e}")
|
||||
|
||||
dir_stats[dir_name] = file_count
|
||||
logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")
|
||||
|
||||
except Exception as e:
|
||||
result.add_warning(f"导入目录 {dir_name} 失败: {e}")
|
||||
dir_stats[dir_name] = 0
|
||||
|
||||
return dir_stats
|
||||
|
||||
def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
|
||||
"""转换 datetime 字符串字段为 datetime 对象"""
|
||||
result = row.copy()
|
||||
|
||||
# 获取模型的 datetime 字段
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
|
||||
try:
|
||||
mapper = sa_inspect(model_class)
|
||||
for column in mapper.columns:
|
||||
if column.name in result and result[column.name] is not None:
|
||||
# 检查是否是 datetime 类型的列
|
||||
from sqlalchemy import DateTime
|
||||
|
||||
if isinstance(column.type, DateTime):
|
||||
value = result[column.name]
|
||||
if isinstance(value, str):
|
||||
# 解析 ISO 格式的日期时间字符串
|
||||
result[column.name] = datetime.fromisoformat(value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return result
|
||||
@@ -80,6 +80,8 @@ class AstrBotConfig(dict):
|
||||
if v["type"] == "object":
|
||||
conf[k] = {}
|
||||
_parse_schema(v["items"], conf[k])
|
||||
elif v["type"] == "template_list":
|
||||
conf[k] = default
|
||||
else:
|
||||
conf[k] = default
|
||||
|
||||
|
||||
+174
-215
@@ -1,10 +1,11 @@
|
||||
"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""
|
||||
|
||||
import os
|
||||
from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.9.2"
|
||||
VERSION = "4.10.3"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -61,7 +62,8 @@ DEFAULT_CONFIG = {
|
||||
"ignore_bot_self_message": False,
|
||||
"ignore_at_all": False,
|
||||
},
|
||||
"provider": [],
|
||||
"provider_sources": [], # provider sources
|
||||
"provider": [], # models from provider_sources
|
||||
"provider_settings": {
|
||||
"enable": True,
|
||||
"default_provider_id": "",
|
||||
@@ -171,6 +173,22 @@ DEFAULT_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
class ChatProviderTemplate(TypedDict):
|
||||
id: str
|
||||
provider_source_id: str
|
||||
model: str
|
||||
modalities: list
|
||||
custom_extra_body: dict[str, Any]
|
||||
|
||||
|
||||
CHAT_PROVIDER_TEMPLATE = {
|
||||
"id": "",
|
||||
"provide_source_id": "",
|
||||
"model": "",
|
||||
"modalities": [],
|
||||
"custom_extra_body": {},
|
||||
}
|
||||
|
||||
"""
|
||||
AstrBot v3 时代的配置元数据,目前仅承担以下功能:
|
||||
|
||||
@@ -209,7 +227,7 @@ CONFIG_METADATA_2 = {
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6196,
|
||||
},
|
||||
"QQ 个人号(OneBot v11)": {
|
||||
"OneBot v11": {
|
||||
"id": "default",
|
||||
"type": "aiocqhttp",
|
||||
"enable": False,
|
||||
@@ -844,6 +862,7 @@ CONFIG_METADATA_2 = {
|
||||
"metadata": {
|
||||
"provider": {
|
||||
"type": "list",
|
||||
# provider sources templates
|
||||
"config_template": {
|
||||
"OpenAI": {
|
||||
"id": "openai",
|
||||
@@ -854,107 +873,10 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
"id": "azure",
|
||||
"provider": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"xai_native_search": False,
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Anthropic": {
|
||||
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
|
||||
"id": "claude",
|
||||
"provider": "anthropic",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "claude-3-5-sonnet-latest",
|
||||
"max_tokens": 4096,
|
||||
"temperature": 0.2,
|
||||
},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Ollama": {
|
||||
"hint": "启用前请确保已正确安装并运行 Ollama 服务端,Ollama默认不带鉴权,无需修改key",
|
||||
"id": "ollama_default",
|
||||
"provider": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"provider": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://localhost:1234/v1",
|
||||
"model_config": {
|
||||
"model": "llama-3.1-8b",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Gemini(OpenAI兼容)": {
|
||||
"id": "gemini_default",
|
||||
"provider": "google",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Gemini": {
|
||||
"id": "gemini_default",
|
||||
"Google Gemini": {
|
||||
"id": "google_gemini",
|
||||
"provider": "google",
|
||||
"type": "googlegenai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -962,10 +884,6 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_native_search": False,
|
||||
"gm_native_coderunner": False,
|
||||
@@ -976,13 +894,44 @@ CONFIG_METADATA_2 = {
|
||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"budget": 0,
|
||||
},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
"gm_thinking_config": {"budget": 0, "level": "HIGH"},
|
||||
},
|
||||
"Anthropic": {
|
||||
"id": "anthropic",
|
||||
"provider": "anthropic",
|
||||
"type": "anthropic_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
"anth_thinking_config": {"budget": 0},
|
||||
},
|
||||
"Moonshot": {
|
||||
"id": "moonshot",
|
||||
"provider": "moonshot",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"xAI": {
|
||||
"id": "xai",
|
||||
"provider": "xai",
|
||||
"type": "xai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
"xai_native_search": False,
|
||||
},
|
||||
"DeepSeek": {
|
||||
"id": "deepseek_default",
|
||||
"id": "deepseek",
|
||||
"provider": "deepseek",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -990,13 +939,75 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"Zhipu": {
|
||||
"id": "zhipu",
|
||||
"provider": "zhipu",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
"id": "azure_openai",
|
||||
"provider": "azure",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"api_version": "2024-05-01-preview",
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Ollama": {
|
||||
"id": "ollama",
|
||||
"provider": "ollama",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://127.0.0.1:11434/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"LM Studio": {
|
||||
"id": "lm_studio",
|
||||
"provider": "lm_studio",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://127.0.0.1:1234/v1",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"ModelStack": {
|
||||
"id": "modelstack",
|
||||
"provider": "modelstack",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://modelstack.app/v1",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Gemini_OpenAI_API": {
|
||||
"id": "google_gemini_openai",
|
||||
"provider": "google",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Groq": {
|
||||
"id": "groq_default",
|
||||
"id": "groq",
|
||||
"provider": "groq",
|
||||
"type": "groq_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
@@ -1004,13 +1015,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.groq.com/openai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "openai/gpt-oss-20b",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "tool_use"],
|
||||
},
|
||||
"302.AI": {
|
||||
"id": "302ai",
|
||||
@@ -1021,12 +1026,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.302.ai/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"硅基流动": {
|
||||
"SiliconFlow": {
|
||||
"id": "siliconflow",
|
||||
"provider": "siliconflow",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1035,15 +1037,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.siliconflow.cn/v1",
|
||||
"model_config": {
|
||||
"model": "deepseek-ai/DeepSeek-V3",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"PPIO派欧云": {
|
||||
"PPIO": {
|
||||
"id": "ppio",
|
||||
"provider": "ppio",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1052,14 +1048,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.ppinfra.com/v3/openai",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "deepseek/deepseek-r1",
|
||||
"temperature": 0.4,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"小马算力": {
|
||||
"TokenPony": {
|
||||
"id": "tokenpony",
|
||||
"provider": "tokenpony",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1068,14 +1059,9 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.tokenpony.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "kimi-k2-instruct-0905",
|
||||
"temperature": 0.7,
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
"优云智算": {
|
||||
"Compshare": {
|
||||
"id": "compshare",
|
||||
"provider": "compshare",
|
||||
"type": "openai_chat_completion",
|
||||
@@ -1084,42 +1070,18 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.modelverse.cn/v1",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "moonshotai/Kimi-K2-Instruct",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Kimi": {
|
||||
"id": "moonshot",
|
||||
"provider": "moonshot",
|
||||
"ModelScope": {
|
||||
"id": "modelscope",
|
||||
"provider": "modelscope",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"智谱 AI": {
|
||||
"id": "zhipu_default",
|
||||
"provider": "zhipu",
|
||||
"type": "zhipu_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"model_config": {
|
||||
"model": "glm-4-flash",
|
||||
},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"Dify": {
|
||||
"id": "dify_app_default",
|
||||
@@ -1134,7 +1096,6 @@ CONFIG_METADATA_2 = {
|
||||
"dify_query_input_key": "astrbot_text_query",
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
|
||||
},
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
@@ -1165,20 +1126,6 @@ CONFIG_METADATA_2 = {
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
},
|
||||
"ModelScope": {
|
||||
"id": "modelscope",
|
||||
"provider": "modelscope",
|
||||
"type": "openai_chat_completion",
|
||||
"provider_type": "chat_completion",
|
||||
"enable": True,
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
"modalities": ["text", "image", "tool_use"],
|
||||
},
|
||||
"FastGPT": {
|
||||
"id": "fastgpt",
|
||||
"provider": "fastgpt",
|
||||
@@ -1202,7 +1149,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "whisper-1",
|
||||
},
|
||||
"Whisper(Local)": {
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"provider": "openai",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"provider_type": "speech_to_text",
|
||||
@@ -1211,7 +1157,6 @@ CONFIG_METADATA_2 = {
|
||||
"model": "tiny",
|
||||
},
|
||||
"SenseVoice(Local)": {
|
||||
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"type": "sensevoice_stt_selfhost",
|
||||
"provider": "sensevoice",
|
||||
"provider_type": "speech_to_text",
|
||||
@@ -1233,7 +1178,6 @@ CONFIG_METADATA_2 = {
|
||||
"timeout": "20",
|
||||
},
|
||||
"Edge TTS": {
|
||||
"hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
|
||||
"id": "edge_tts",
|
||||
"provider": "microsoft",
|
||||
"type": "edge_tts",
|
||||
@@ -1343,7 +1287,7 @@ CONFIG_METADATA_2 = {
|
||||
"minimax-is-timber-weight": False,
|
||||
"minimax-voice-id": "female-shaonv",
|
||||
"minimax-timber-weight": '[\n {\n "voice_id": "Chinese (Mandarin)_Warm_Girl",\n "weight": 25\n },\n {\n "voice_id": "Chinese (Mandarin)_BashfulGirl",\n "weight": 50\n }\n]',
|
||||
"minimax-voice-emotion": "neutral",
|
||||
"minimax-voice-emotion": "auto",
|
||||
"minimax-voice-latex": False,
|
||||
"minimax-voice-english-normalization": False,
|
||||
"timeout": 20,
|
||||
@@ -1449,6 +1393,10 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"provider_source_id": {
|
||||
"invisible": True,
|
||||
"type": "string",
|
||||
},
|
||||
"xai_native_search": {
|
||||
"description": "启用原生搜索功能",
|
||||
"type": "bool",
|
||||
@@ -1819,13 +1767,35 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"description": "Gemini思考设置",
|
||||
"description": "Thinking Config",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"budget": {
|
||||
"description": "思考预算",
|
||||
"description": "Thinking Budget",
|
||||
"type": "int",
|
||||
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
|
||||
"hint": "Guides the model on the specific number of thinking tokens to use for reasoning. See: https://ai.google.dev/gemini-api/docs/thinking#set-budget",
|
||||
},
|
||||
"level": {
|
||||
"description": "Thinking Level",
|
||||
"type": "string",
|
||||
"hint": "Recommended for Gemini 3 models and onwards, lets you control reasoning behavior.See: https://ai.google.dev/gemini-api/docs/thinking#thinking-levels",
|
||||
"options": [
|
||||
"MINIMAL",
|
||||
"LOW",
|
||||
"MEDIUM",
|
||||
"HIGH",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
"anth_thinking_config": {
|
||||
"description": "Thinking Config",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"budget": {
|
||||
"description": "Thinking Budget",
|
||||
"type": "int",
|
||||
"hint": "Anthropic thinking.budget_tokens param. Must >= 1024. See: https://platform.claude.com/docs/en/build-with-claude/extended-thinking",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -1900,15 +1870,18 @@ CONFIG_METADATA_2 = {
|
||||
"minimax-voice-emotion": {
|
||||
"type": "string",
|
||||
"description": "情绪",
|
||||
"hint": "控制合成语音的情绪",
|
||||
"hint": "控制合成语音的情绪。当为 auto 时,将根据文本内容自动选择情绪。",
|
||||
"options": [
|
||||
"auto",
|
||||
"happy",
|
||||
"sad",
|
||||
"angry",
|
||||
"fearful",
|
||||
"disgusted",
|
||||
"surprised",
|
||||
"neutral",
|
||||
"calm",
|
||||
"fluent",
|
||||
"whisper",
|
||||
],
|
||||
},
|
||||
"minimax-voice-latex": {
|
||||
@@ -2006,7 +1979,6 @@ CONFIG_METADATA_2 = {
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
"hint": "模型提供商名字。",
|
||||
},
|
||||
"type": {
|
||||
"description": "模型提供商种类",
|
||||
@@ -2026,29 +1998,15 @@ CONFIG_METADATA_2 = {
|
||||
"description": "API Key",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "提供商 API Key。",
|
||||
},
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
|
||||
},
|
||||
"model_config": {
|
||||
"description": "模型配置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"model": {
|
||||
"description": "模型名称",
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"max_tokens": {
|
||||
"description": "模型最大输出长度(tokens)",
|
||||
"type": "int",
|
||||
},
|
||||
"temperature": {"description": "温度", "type": "float"},
|
||||
"top_p": {"description": "Top P值", "type": "float"},
|
||||
},
|
||||
"model": {
|
||||
"description": "模型 ID",
|
||||
"type": "string",
|
||||
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
|
||||
},
|
||||
"dify_api_key": {
|
||||
"description": "API Key",
|
||||
@@ -3106,4 +3064,5 @@ DEFAULT_VALUE_MAP = {
|
||||
"text": "",
|
||||
"list": [],
|
||||
"object": {},
|
||||
"template_list": [],
|
||||
}
|
||||
|
||||
@@ -33,6 +33,7 @@ from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
|
||||
from astrbot.core.umop_config_router import UmopConfigRouter
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core.utils.llm_metadata import update_llm_metadata
|
||||
from astrbot.core.utils.migra_helper import migra
|
||||
|
||||
from . import astrbot_config, html_renderer
|
||||
@@ -185,6 +186,8 @@ class AstrBotCoreLifecycle:
|
||||
# 初始化关闭控制面板的事件
|
||||
self.dashboard_shutdown_event = asyncio.Event()
|
||||
|
||||
asyncio.create_task(update_llm_metadata())
|
||||
|
||||
def _load(self) -> None:
|
||||
"""加载事件总线和任务并初始化."""
|
||||
# 创建一个异步任务来执行事件总线的 dispatch() 方法
|
||||
|
||||
@@ -9,6 +9,8 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
||||
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
@@ -314,6 +316,76 @@ class BaseDatabase(abc.ABC):
|
||||
"""Clear all preferences for a specific scope ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_command_configs(self) -> list[CommandConfig]:
|
||||
"""Get all stored command configurations."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
|
||||
"""Fetch a single command configuration by handler."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
"""Create or update a command configuration."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||
"""Delete a single command configuration."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||
"""Bulk delete command configurations."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_command_conflicts(
|
||||
self,
|
||||
status: str | None = None,
|
||||
) -> list[CommandConflict]:
|
||||
"""List recorded command conflict entries."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def upsert_command_conflict(
|
||||
self,
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
"""Create or update a conflict record."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||
"""Delete conflict records."""
|
||||
...
|
||||
|
||||
# @abc.abstractmethod
|
||||
# async def insert_llm_message(
|
||||
# self,
|
||||
|
||||
@@ -234,6 +234,65 @@ class Attachment(SQLModel, table=True):
|
||||
)
|
||||
|
||||
|
||||
class CommandConfig(SQLModel, table=True):
|
||||
"""Per-command configuration overrides for dashboard management."""
|
||||
|
||||
__tablename__ = "command_configs" # type: ignore
|
||||
|
||||
handler_full_name: str = Field(
|
||||
primary_key=True,
|
||||
max_length=512,
|
||||
)
|
||||
plugin_name: str = Field(nullable=False, max_length=255)
|
||||
module_path: str = Field(nullable=False, max_length=255)
|
||||
original_command: str = Field(nullable=False, max_length=255)
|
||||
resolved_command: str | None = Field(default=None, max_length=255)
|
||||
enabled: bool = Field(default=True, nullable=False)
|
||||
keep_original_alias: bool = Field(default=False, nullable=False)
|
||||
conflict_key: str | None = Field(default=None, max_length=255)
|
||||
resolution_strategy: str | None = Field(default=None, max_length=64)
|
||||
note: str | None = Field(default=None, sa_type=Text)
|
||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||
auto_managed: bool = Field(default=False, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
|
||||
class CommandConflict(SQLModel, table=True):
|
||||
"""Conflict tracking for duplicated command names."""
|
||||
|
||||
__tablename__ = "command_conflicts" # type: ignore
|
||||
|
||||
id: int | None = Field(
|
||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
conflict_key: str = Field(nullable=False, max_length=255)
|
||||
handler_full_name: str = Field(nullable=False, max_length=512)
|
||||
plugin_name: str = Field(nullable=False, max_length=255)
|
||||
status: str = Field(default="pending", max_length=32)
|
||||
resolution: str | None = Field(default=None, max_length=64)
|
||||
resolved_command: str | None = Field(default=None, max_length=255)
|
||||
note: str | None = Field(default=None, sa_type=Text)
|
||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||
auto_generated: bool = Field(default=False, nullable=False)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"conflict_key",
|
||||
"handler_full_name",
|
||||
name="uix_conflict_handler",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Conversation:
|
||||
"""LLM 对话类
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import threading
|
||||
import typing as T
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import CursorResult
|
||||
@@ -10,6 +11,8 @@ from sqlmodel import col, delete, desc, func, or_, select, text, update
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import (
|
||||
Attachment,
|
||||
CommandConfig,
|
||||
CommandConflict,
|
||||
ConversationV2,
|
||||
Persona,
|
||||
PlatformMessageHistory,
|
||||
@@ -26,6 +29,7 @@ from astrbot.core.db.po import (
|
||||
)
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
TxResult = T.TypeVar("TxResult")
|
||||
|
||||
|
||||
class SQLiteDatabase(BaseDatabase):
|
||||
@@ -670,6 +674,242 @@ class SQLiteDatabase(BaseDatabase):
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# ====
|
||||
# Command Configuration & Conflict Tracking
|
||||
# ====
|
||||
|
||||
async def _run_in_tx(
|
||||
self,
|
||||
fn: Callable[[AsyncSession], Awaitable[TxResult]],
|
||||
) -> TxResult:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
return await fn(session)
|
||||
|
||||
@staticmethod
|
||||
def _apply_updates(model, **updates) -> None:
|
||||
for field, value in updates.items():
|
||||
if value is not None:
|
||||
setattr(model, field, value)
|
||||
|
||||
@staticmethod
|
||||
def _new_command_config(
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
return CommandConfig(
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=plugin_name,
|
||||
module_path=module_path,
|
||||
original_command=original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=True if enabled is None else enabled,
|
||||
keep_original_alias=False
|
||||
if keep_original_alias is None
|
||||
else keep_original_alias,
|
||||
conflict_key=conflict_key or original_command,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=bool(auto_managed),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _new_command_conflict(
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
return CommandConflict(
|
||||
conflict_key=conflict_key,
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=plugin_name,
|
||||
status=status or "pending",
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=bool(auto_generated),
|
||||
)
|
||||
|
||||
async def get_command_configs(self) -> list[CommandConfig]:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
result = await session.execute(select(CommandConfig))
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
) -> CommandConfig | None:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
return await session.get(CommandConfig, handler_full_name)
|
||||
|
||||
async def upsert_command_config(
|
||||
self,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
module_path: str,
|
||||
original_command: str,
|
||||
*,
|
||||
resolved_command: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
keep_original_alias: bool | None = None,
|
||||
conflict_key: str | None = None,
|
||||
resolution_strategy: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_managed: bool | None = None,
|
||||
) -> CommandConfig:
|
||||
async def _op(session: AsyncSession) -> CommandConfig:
|
||||
config = await session.get(CommandConfig, handler_full_name)
|
||||
if not config:
|
||||
config = self._new_command_config(
|
||||
handler_full_name,
|
||||
plugin_name,
|
||||
module_path,
|
||||
original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=enabled,
|
||||
keep_original_alias=keep_original_alias,
|
||||
conflict_key=conflict_key,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=auto_managed,
|
||||
)
|
||||
session.add(config)
|
||||
else:
|
||||
self._apply_updates(
|
||||
config,
|
||||
plugin_name=plugin_name,
|
||||
module_path=module_path,
|
||||
original_command=original_command,
|
||||
resolved_command=resolved_command,
|
||||
enabled=enabled,
|
||||
keep_original_alias=keep_original_alias,
|
||||
conflict_key=conflict_key,
|
||||
resolution_strategy=resolution_strategy,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_managed=auto_managed,
|
||||
)
|
||||
await session.flush()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
return await self._run_in_tx(_op)
|
||||
|
||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||
await self.delete_command_configs([handler_full_name])
|
||||
|
||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||
if not handler_full_names:
|
||||
return
|
||||
|
||||
async def _op(session: AsyncSession) -> None:
|
||||
await session.execute(
|
||||
delete(CommandConfig).where(
|
||||
col(CommandConfig.handler_full_name).in_(handler_full_names),
|
||||
),
|
||||
)
|
||||
|
||||
await self._run_in_tx(_op)
|
||||
|
||||
async def list_command_conflicts(
|
||||
self,
|
||||
status: str | None = None,
|
||||
) -> list[CommandConflict]:
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(CommandConflict)
|
||||
if status:
|
||||
query = query.where(CommandConflict.status == status)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def upsert_command_conflict(
|
||||
self,
|
||||
conflict_key: str,
|
||||
handler_full_name: str,
|
||||
plugin_name: str,
|
||||
*,
|
||||
status: str | None = None,
|
||||
resolution: str | None = None,
|
||||
resolved_command: str | None = None,
|
||||
note: str | None = None,
|
||||
extra_data: dict | None = None,
|
||||
auto_generated: bool | None = None,
|
||||
) -> CommandConflict:
|
||||
async def _op(session: AsyncSession) -> CommandConflict:
|
||||
result = await session.execute(
|
||||
select(CommandConflict).where(
|
||||
CommandConflict.conflict_key == conflict_key,
|
||||
CommandConflict.handler_full_name == handler_full_name,
|
||||
),
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if not record:
|
||||
record = self._new_command_conflict(
|
||||
conflict_key,
|
||||
handler_full_name,
|
||||
plugin_name,
|
||||
status=status,
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=auto_generated,
|
||||
)
|
||||
session.add(record)
|
||||
else:
|
||||
self._apply_updates(
|
||||
record,
|
||||
plugin_name=plugin_name,
|
||||
status=status,
|
||||
resolution=resolution,
|
||||
resolved_command=resolved_command,
|
||||
note=note,
|
||||
extra_data=extra_data,
|
||||
auto_generated=auto_generated,
|
||||
)
|
||||
await session.flush()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
return await self._run_in_tx(_op)
|
||||
|
||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||
if not ids:
|
||||
return
|
||||
|
||||
async def _op(session: AsyncSession) -> None:
|
||||
await session.execute(
|
||||
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
|
||||
)
|
||||
|
||||
await self._run_in_tx(_op)
|
||||
|
||||
# ====
|
||||
# Deprecated Methods
|
||||
# ====
|
||||
|
||||
+1
-1
@@ -58,7 +58,7 @@ def is_plugin_path(pathname):
|
||||
return False
|
||||
|
||||
norm_path = os.path.normpath(pathname)
|
||||
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
|
||||
return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path)
|
||||
|
||||
|
||||
def get_short_level_name(level_name):
|
||||
|
||||
@@ -629,12 +629,11 @@ class Nodes(BaseMessageComponent):
|
||||
|
||||
class Json(BaseMessageComponent):
|
||||
type = ComponentType.Json
|
||||
data: str | dict
|
||||
resid: int | None = 0
|
||||
data: dict
|
||||
|
||||
def __init__(self, data, **_):
|
||||
if isinstance(data, dict):
|
||||
data = json.dumps(data)
|
||||
def __init__(self, data: str | dict, **_):
|
||||
if isinstance(data, str):
|
||||
data = json.loads(data)
|
||||
super().__init__(data=data, **_)
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||
from astrbot.core.conversation_mgr import Conversation
|
||||
@@ -294,6 +295,7 @@ class InternalAgentSubStage(Stage):
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse | None,
|
||||
all_messages: list[Message],
|
||||
):
|
||||
if (
|
||||
not req
|
||||
@@ -307,26 +309,23 @@ class InternalAgentSubStage(Stage):
|
||||
logger.debug("LLM 响应为空,不保存记录。")
|
||||
return
|
||||
|
||||
if req.contexts is None:
|
||||
req.contexts = []
|
||||
# using agent context messages to save to history
|
||||
message_to_save = []
|
||||
for message in all_messages:
|
||||
if message.role == "system":
|
||||
# we do not save system messages to history
|
||||
continue
|
||||
if message.role in ["assistant", "user"] and getattr(
|
||||
message, "_no_save", None
|
||||
):
|
||||
# we do not save user and assistant messages that are marked as _no_save
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
# 历史上下文
|
||||
messages = copy.deepcopy(req.contexts)
|
||||
# 这一轮对话请求的用户输入
|
||||
messages.append(await req.assemble_context())
|
||||
# 这一轮对话的 LLM 响应
|
||||
if req.tool_calls_result:
|
||||
if not isinstance(req.tool_calls_result, list):
|
||||
messages.extend(req.tool_calls_result.to_openai_messages())
|
||||
elif isinstance(req.tool_calls_result, list):
|
||||
for tcr in req.tool_calls_result:
|
||||
messages.extend(tcr.to_openai_messages())
|
||||
messages.append({"role": "assistant", "content": llm_response.completion_text})
|
||||
messages = list(filter(lambda item: "_no_save" not in item, messages))
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
req.conversation.cid,
|
||||
history=messages,
|
||||
history=message_to_save,
|
||||
)
|
||||
|
||||
def _fix_messages(self, messages: list[dict]) -> list[dict]:
|
||||
@@ -350,174 +349,190 @@ class InternalAgentSubStage(Stage):
|
||||
) -> AsyncGenerator[None, None]:
|
||||
req: ProviderRequest | None = None
|
||||
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
|
||||
return
|
||||
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
logger.debug("ready to request llm provider")
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
try:
|
||||
provider = self._select_provider(event)
|
||||
if provider is None:
|
||||
return
|
||||
if not isinstance(provider, Provider):
|
||||
logger.error(
|
||||
f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。"
|
||||
)
|
||||
return
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
streaming_response = self.streaming_response
|
||||
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
|
||||
streaming_response = bool(enable_streaming)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
logger.debug("ready to request llm provider")
|
||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||
logger.debug("acquired session lock for llm request")
|
||||
if event.get_extra("provider_request"):
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), (
|
||||
"provider_request 必须是 ProviderRequest 类型。"
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
|
||||
else:
|
||||
req = ProviderRequest()
|
||||
req.prompt = ""
|
||||
req.image_urls = []
|
||||
if sel_model := event.get_extra("selected_model"):
|
||||
req.model = sel_model
|
||||
if provider_wake_prefix and not event.message_str.startswith(
|
||||
provider_wake_prefix
|
||||
):
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
|
||||
req.prompt = event.message_str[len(provider_wake_prefix) :]
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
req.image_urls.append(image_path)
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
conversation = await self._get_session_conv(event)
|
||||
req.conversation = conversation
|
||||
req.contexts = json.loads(conversation.history)
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
event.set_extra("provider_request", req)
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
|
||||
# fix contexts json str
|
||||
if isinstance(req.contexts, str):
|
||||
req.contexts = json.loads(req.contexts)
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# apply file extract
|
||||
if self.file_extract_enabled:
|
||||
try:
|
||||
await self._apply_file_extract(event, req)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while applying file extract: {e}")
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
if not req.prompt and not req.image_urls:
|
||||
return
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
# call event hook
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
# apply knowledge base feature
|
||||
await self._apply_kb(event, req)
|
||||
|
||||
# truncate contexts to fit max length
|
||||
if req.contexts:
|
||||
req.contexts = self._truncate_contexts(req.contexts)
|
||||
self._fix_messages(req.contexts)
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
# check provider modalities, if provider does not support image/tool_use, clear them in request.
|
||||
self._modalities_fix(provider, req)
|
||||
|
||||
# filter tools, only keep tools from this pipeline's selected plugins
|
||||
self._plugin_tool_fix(event, req)
|
||||
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
),
|
||||
stream_to_general = (
|
||||
self.unsupported_streaming_strategy == "turn_off"
|
||||
and not event.platform_meta.support_streaming_message
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
# 备份 req.contexts
|
||||
backup_contexts = copy.deepcopy(req.contexts)
|
||||
|
||||
# run agent
|
||||
agent_runner = AgentRunner()
|
||||
logger.debug(
|
||||
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
|
||||
)
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
context=self.ctx.plugin_manager.context,
|
||||
event=event,
|
||||
)
|
||||
await agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
context=astr_agent_ctx,
|
||||
tool_call_timeout=self.tool_call_timeout,
|
||||
),
|
||||
tool_executor=FunctionToolExecutor(),
|
||||
agent_hooks=MAIN_AGENT_HOOKS,
|
||||
streaming=streaming_response,
|
||||
)
|
||||
|
||||
if streaming_response and not stream_to_general:
|
||||
# 流式响应
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(
|
||||
run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
show_reasoning=self.show_reasoning,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
),
|
||||
)
|
||||
yield
|
||||
if agent_runner.done():
|
||||
if final_llm_resp := agent_runner.get_final_llm_resp():
|
||||
if final_llm_resp.completion_text:
|
||||
chain = (
|
||||
MessageChain()
|
||||
.message(final_llm_resp.completion_text)
|
||||
.chain
|
||||
)
|
||||
elif final_llm_resp.result_chain:
|
||||
chain = final_llm_resp.result_chain.chain
|
||||
else:
|
||||
chain = MessageChain().chain
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=chain,
|
||||
result_content_type=ResultContentType.STREAMING_FINISH,
|
||||
),
|
||||
)
|
||||
else:
|
||||
async for _ in run_agent(
|
||||
agent_runner,
|
||||
self.max_step,
|
||||
self.show_tool_use,
|
||||
stream_to_general,
|
||||
show_reasoning=self.show_reasoning,
|
||||
):
|
||||
yield
|
||||
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
# 恢复备份的 contexts
|
||||
req.contexts = backup_contexts
|
||||
|
||||
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
|
||||
await self._save_to_history(
|
||||
event,
|
||||
req,
|
||||
agent_runner.get_final_llm_resp(),
|
||||
agent_runner.run_context.messages,
|
||||
)
|
||||
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
# 异步处理 WebChat 特殊情况
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=agent_runner.provider.get_model(),
|
||||
provider_type=agent_runner.provider.meta().type,
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred while processing agent: {e}")
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
f"Error occurred while processing agent request: {e}"
|
||||
)
|
||||
)
|
||||
|
||||
@@ -119,7 +119,7 @@ class RespondStage(Stage):
|
||||
|
||||
if (result := event.get_result()) is None:
|
||||
return False
|
||||
if self.only_llm_result and result.is_llm_result():
|
||||
if self.only_llm_result and not result.is_llm_result():
|
||||
return False
|
||||
|
||||
if event.get_platform_name() in [
|
||||
|
||||
@@ -98,6 +98,9 @@ class ResultDecorateStage(Stage):
|
||||
self.content_safe_check_stage = stage_cls()
|
||||
await self.content_safe_check_stage.initialize(ctx)
|
||||
|
||||
provider_cfg = ctx.astrbot_config.get("provider_settings", {})
|
||||
self.show_reasoning = provider_cfg.get("display_reasoning_text", False)
|
||||
|
||||
def _split_text_by_words(self, text: str) -> list[str]:
|
||||
"""使用分段词列表分段文本"""
|
||||
if not self.split_words_pattern:
|
||||
@@ -254,70 +257,75 @@ class ResultDecorateStage(Stage):
|
||||
event.unified_msg_origin,
|
||||
)
|
||||
|
||||
if (
|
||||
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
|
||||
should_tts = (
|
||||
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"])
|
||||
and result.is_llm_result()
|
||||
and SessionServiceManager.should_process_tts_request(event)
|
||||
):
|
||||
should_tts = self.tts_trigger_probability >= 1.0 or (
|
||||
self.tts_trigger_probability > 0.0
|
||||
and random.random() <= self.tts_trigger_probability
|
||||
and random.random() <= self.tts_trigger_probability
|
||||
and tts_provider
|
||||
)
|
||||
if should_tts and not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
|
||||
if not should_tts:
|
||||
logger.debug("跳过 TTS:触发概率未命中。")
|
||||
elif not tts_provider:
|
||||
logger.warning(
|
||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||
)
|
||||
else:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
|
||||
)
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
if (
|
||||
not should_tts
|
||||
and self.show_reasoning
|
||||
and event.get_extra("_llm_reasoning_content")
|
||||
):
|
||||
# inject reasoning content to chain
|
||||
reasoning_content = event.get_extra("_llm_reasoning_content")
|
||||
result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n"))
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path,
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
),
|
||||
if should_tts and tts_provider:
|
||||
new_chain = []
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Plain) and len(comp.text) > 1:
|
||||
try:
|
||||
logger.info(f"TTS 请求: {comp.text}")
|
||||
audio_path = await tts_provider.get_audio(comp.text)
|
||||
logger.info(f"TTS 结果: {audio_path}")
|
||||
if not audio_path:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
continue
|
||||
|
||||
use_file_service = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["use_file_service"]
|
||||
callback_api_base = self.ctx.astrbot_config[
|
||||
"callback_api_base"
|
||||
]
|
||||
dual_output = self.ctx.astrbot_config[
|
||||
"provider_tts_settings"
|
||||
]["dual_output"]
|
||||
|
||||
url = None
|
||||
if use_file_service and callback_api_base:
|
||||
token = await file_token_service.register_file(
|
||||
audio_path,
|
||||
)
|
||||
url = f"{callback_api_base}/api/file/{token}"
|
||||
logger.debug(f"已注册:{url}")
|
||||
|
||||
new_chain.append(
|
||||
Record(
|
||||
file=url or audio_path,
|
||||
url=url or audio_path,
|
||||
),
|
||||
)
|
||||
if dual_output:
|
||||
new_chain.append(comp)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error("TTS 失败,使用文本发送。")
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
result.chain = new_chain
|
||||
|
||||
# 文本转图片
|
||||
elif (
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import At, AtAll, Reply
|
||||
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
@@ -13,6 +14,23 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
|
||||
from ..context import PipelineContext
|
||||
from ..stage import Stage, register_stage
|
||||
|
||||
UNIQUE_SESSION_ID_BUILDERS: dict[str, Callable[[AstrMessageEvent], str | None]] = {
|
||||
"aiocqhttp": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
|
||||
"slack": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
|
||||
"dingtalk": lambda e: e.get_sender_id(),
|
||||
"qq_official": lambda e: e.get_sender_id(),
|
||||
"qq_official_webhook": lambda e: e.get_sender_id(),
|
||||
"lark": lambda e: f"{e.get_sender_id()}%{e.get_group_id()}",
|
||||
"misskey": lambda e: f"{e.get_session_id()}_{e.get_sender_id()}",
|
||||
"wechatpadpro": lambda e: f"{e.get_group_id()}#{e.get_sender_id()}",
|
||||
}
|
||||
|
||||
|
||||
def build_unique_session_id(event: AstrMessageEvent) -> str | None:
|
||||
platform = event.get_platform_name()
|
||||
builder = UNIQUE_SESSION_ID_BUILDERS.get(platform)
|
||||
return builder(event) if builder else None
|
||||
|
||||
|
||||
@register_stage
|
||||
class WakingCheckStage(Stage):
|
||||
@@ -53,18 +71,27 @@ class WakingCheckStage(Stage):
|
||||
self.disable_builtin_commands = self.ctx.astrbot_config.get(
|
||||
"disable_builtin_commands", False
|
||||
)
|
||||
platform_settings = self.ctx.astrbot_config.get("platform_settings", {})
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
# apply unique session
|
||||
if self.unique_session and event.message_obj.type == MessageType.GROUP_MESSAGE:
|
||||
sid = build_unique_session_id(event)
|
||||
if sid:
|
||||
event.session_id = sid
|
||||
|
||||
# ignore bot self message
|
||||
if (
|
||||
self.ignore_bot_self_message
|
||||
and event.get_self_id() == event.get_sender_id()
|
||||
):
|
||||
# 忽略机器人自己发送的消息
|
||||
event.stop_event()
|
||||
return
|
||||
|
||||
# 设置 sender 身份
|
||||
event.message_str = event.message_str.strip()
|
||||
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
||||
@@ -136,7 +163,8 @@ class WakingCheckStage(Stage):
|
||||
):
|
||||
if (
|
||||
self.disable_builtin_commands
|
||||
and handler.handler_module_path == "packages.builtin_commands.main"
|
||||
and handler.handler_module_path
|
||||
== "astrbot.builtin_stars.builtin_commands.main"
|
||||
):
|
||||
logger.debug("skipping builtin command")
|
||||
continue
|
||||
|
||||
@@ -41,7 +41,6 @@ class AiocqhttpAdapter(Platform):
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.host = platform_config["ws_reverse_host"]
|
||||
self.port = platform_config["ws_reverse_port"]
|
||||
|
||||
@@ -136,14 +135,11 @@ class AiocqhttpAdapter(Platform):
|
||||
abm.group_id = str(event.group_id)
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.timestamp = int(time.time())
|
||||
@@ -164,16 +160,11 @@ class AiocqhttpAdapter(Platform):
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = (
|
||||
str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||
) # 也保留群组 id
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.raw_message = event
|
||||
@@ -210,16 +201,11 @@ class AiocqhttpAdapter(Platform):
|
||||
abm.group.group_name = event.get("group_name", "N/A")
|
||||
elif event["message_type"] == "private":
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = (
|
||||
abm.sender.user_id + "_" + str(event.group_id)
|
||||
) # 也保留群组 id
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
|
||||
abm.message_id = str(event.message_id)
|
||||
abm.message = []
|
||||
@@ -385,10 +371,25 @@ class AiocqhttpAdapter(Platform):
|
||||
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
|
||||
|
||||
message_str += "".join(at_parts)
|
||||
elif t == "markdown":
|
||||
text = m["data"].get("markdown") or m["data"].get("content", "")
|
||||
abm.message.append(Plain(text=text))
|
||||
message_str += text
|
||||
else:
|
||||
for m in m_group:
|
||||
a = ComponentTypes[t](**m["data"])
|
||||
abm.message.append(a)
|
||||
try:
|
||||
if t not in ComponentTypes:
|
||||
logger.warning(
|
||||
f"不支持的消息段类型,已忽略: {t}, data={m['data']}"
|
||||
)
|
||||
continue
|
||||
a = ComponentTypes[t](**m["data"])
|
||||
abm.message.append(a)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"消息段解析失败: type={t}, data={m['data']}. {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
|
||||
@@ -50,8 +50,6 @@ class DingtalkPlatformAdapter(Platform):
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.client_id = platform_config["client_id"]
|
||||
self.client_secret = platform_config["client_secret"]
|
||||
|
||||
@@ -129,10 +127,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
if id := self._id_to_sid(user.dingtalk_id):
|
||||
abm.message.append(At(qq=id))
|
||||
abm.group_id = message.conversation_id
|
||||
if self.unique_session:
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
abm.session_id = abm.group_id
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
|
||||
@@ -25,6 +25,20 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
client: dingtalk_stream.ChatbotHandler,
|
||||
message: MessageChain,
|
||||
):
|
||||
icm = cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message)
|
||||
ats = []
|
||||
# fixes: #4218
|
||||
# 钉钉 at 机器人需要使用 sender_staff_id 而不是 sender_id
|
||||
for i in message.chain:
|
||||
if isinstance(i, Comp.At):
|
||||
print(i.qq, icm.sender_id, icm.sender_staff_id)
|
||||
if str(i.qq) in str(icm.sender_id or ""):
|
||||
# 适配器会将开头的 $:LWCP_v1:$ 去掉,因此我们用 in 判断
|
||||
ats.append(f"@{icm.sender_staff_id}")
|
||||
else:
|
||||
ats.append(f"@{i.qq}")
|
||||
at_str = " ".join(ats)
|
||||
|
||||
for segment in message.chain:
|
||||
if isinstance(segment, Comp.Plain):
|
||||
segment.text = segment.text.strip()
|
||||
@@ -32,7 +46,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
None,
|
||||
client.reply_markdown,
|
||||
segment.text,
|
||||
segment.text,
|
||||
f"{at_str} {segment.text}".strip(),
|
||||
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
|
||||
@@ -44,8 +44,6 @@ class LarkPlatformAdapter(Platform):
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.appid = platform_config["app_id"]
|
||||
self.appsecret = platform_config["app_secret"]
|
||||
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
|
||||
@@ -317,14 +315,8 @@ class LarkPlatformAdapter(Platform):
|
||||
user_id=event.event.sender.sender_id.open_id,
|
||||
nickname=event.event.sender.sender_id.open_id[:8],
|
||||
)
|
||||
# 独立会话
|
||||
if not self.unique_session:
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
elif abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
|
||||
@@ -91,8 +91,6 @@ class MisskeyPlatformAdapter(Platform):
|
||||
except Exception:
|
||||
self.max_download_bytes = None
|
||||
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
|
||||
self.api: MisskeyAPI | None = None
|
||||
self._running = False
|
||||
self.client_self_id = ""
|
||||
@@ -641,7 +639,6 @@ class MisskeyPlatformAdapter(Platform):
|
||||
sender_info,
|
||||
self.client_self_id,
|
||||
is_chat=False,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
cache_user_info(
|
||||
self._user_cache,
|
||||
@@ -690,7 +687,6 @@ class MisskeyPlatformAdapter(Platform):
|
||||
sender_info,
|
||||
self.client_self_id,
|
||||
is_chat=True,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
cache_user_info(
|
||||
self._user_cache,
|
||||
@@ -720,7 +716,6 @@ class MisskeyPlatformAdapter(Platform):
|
||||
self.client_self_id,
|
||||
is_chat=False,
|
||||
room_id=room_id,
|
||||
unique_session=self.unique_session,
|
||||
)
|
||||
|
||||
cache_user_info(
|
||||
|
||||
@@ -338,7 +338,6 @@ def create_base_message(
|
||||
client_self_id: str,
|
||||
is_chat: bool = False,
|
||||
room_id: str | None = None,
|
||||
unique_session: bool = False,
|
||||
) -> AstrBotMessage:
|
||||
"""创建基础消息对象"""
|
||||
message = AstrBotMessage()
|
||||
@@ -353,8 +352,6 @@ def create_base_message(
|
||||
if room_id:
|
||||
session_prefix = "room"
|
||||
session_id = f"{session_prefix}%{room_id}"
|
||||
if unique_session:
|
||||
session_id += f"_{sender_info['sender_id']}"
|
||||
message.type = MessageType.GROUP_MESSAGE
|
||||
message.group_id = room_id
|
||||
elif is_chat:
|
||||
|
||||
@@ -44,11 +44,8 @@ class botClient(Client):
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id
|
||||
if self.platform.unique_session
|
||||
else cast(str, message.group_openid)
|
||||
)
|
||||
abm.group_id = cast(str, message.group_openid)
|
||||
abm.session_id = abm.group_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到频道消息
|
||||
@@ -57,9 +54,8 @@ class botClient(Client):
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id if self.platform.unique_session else message.channel_id
|
||||
)
|
||||
abm.group_id = message.channel_id
|
||||
abm.session_id = abm.group_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到私聊消息
|
||||
@@ -104,7 +100,6 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
|
||||
self.appid = platform_config["appid"]
|
||||
self.secret = platform_config["secret"]
|
||||
self.unique_session: bool = platform_settings["unique_session"]
|
||||
qq_group = platform_config["enable_group_c2c"]
|
||||
guild_dm = platform_config["enable_guild_direct_message"]
|
||||
|
||||
|
||||
@@ -35,11 +35,8 @@ class botClient(Client):
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id
|
||||
if self.platform.unique_session
|
||||
else cast(str, message.group_openid)
|
||||
)
|
||||
abm.group_id = cast(str, message.group_openid)
|
||||
abm.session_id = abm.group_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到频道消息
|
||||
@@ -48,9 +45,8 @@ class botClient(Client):
|
||||
message,
|
||||
MessageType.GROUP_MESSAGE,
|
||||
)
|
||||
abm.session_id = (
|
||||
abm.sender.user_id if self.platform.unique_session else message.channel_id
|
||||
)
|
||||
abm.group_id = message.channel_id
|
||||
abm.session_id = abm.group_id
|
||||
self._commit(abm)
|
||||
|
||||
# 收到私聊消息
|
||||
@@ -95,7 +91,6 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
|
||||
self.appid = platform_config["appid"]
|
||||
self.secret = platform_config["secret"]
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
|
||||
|
||||
intents = botpy.Intents(
|
||||
|
||||
@@ -142,7 +142,12 @@ class SatoriPlatformAdapter(Platform):
|
||||
raise ValueError(f"WebSocket URL必须以ws://或wss://开头: {self.endpoint}")
|
||||
|
||||
try:
|
||||
websocket = await connect(self.endpoint, additional_headers={})
|
||||
websocket = await connect(
|
||||
self.endpoint,
|
||||
additional_headers={},
|
||||
max_size=10 * 1024 * 1024, # 10MB
|
||||
)
|
||||
|
||||
self.ws = websocket
|
||||
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
@@ -41,7 +41,6 @@ class SlackAdapter(Platform):
|
||||
) -> None:
|
||||
super().__init__(platform_config, event_queue)
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
self.bot_token = platform_config.get("bot_token")
|
||||
self.app_token = platform_config.get("app_token")
|
||||
@@ -147,12 +146,10 @@ class SlackAdapter(Platform):
|
||||
abm.group_id = channel_id
|
||||
|
||||
# 设置会话ID
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = f"{user_id}_{channel_id}"
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = (
|
||||
channel_id if abm.type == MessageType.GROUP_MESSAGE else user_id
|
||||
)
|
||||
abm.session_id = user_id
|
||||
|
||||
abm.message_id = event.get("client_msg_id", uuid.uuid4().hex)
|
||||
abm.timestamp = int(float(event.get("ts", time.time())))
|
||||
|
||||
@@ -200,6 +200,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
if isinstance(chain, MessageChain):
|
||||
if chain.type == "break":
|
||||
# 分割符
|
||||
if message_id:
|
||||
try:
|
||||
await self.client.edit_message_text(
|
||||
text=delta,
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败(streaming-break): {e!s}")
|
||||
message_id = None # 重置消息 ID
|
||||
delta = "" # 重置 delta
|
||||
continue
|
||||
|
||||
@@ -79,7 +79,6 @@ class WebChatAdapter(Platform):
|
||||
super().__init__(platform_config, event_queue)
|
||||
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import File, Image, Plain, Record
|
||||
from astrbot.api.message_components import File, Image, Json, Plain, Record
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from .webchat_queue_mgr import webchat_queue_mgr
|
||||
@@ -41,12 +42,20 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
"chain_type": message.type,
|
||||
},
|
||||
)
|
||||
elif isinstance(comp, Json):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"data": json.dumps(comp.data, ensure_ascii=False),
|
||||
"streaming": streaming,
|
||||
"chain_type": message.type,
|
||||
},
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = f"{str(uuid.uuid4())}.jpg"
|
||||
@@ -58,7 +67,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "image",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
},
|
||||
@@ -74,7 +82,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "record",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
},
|
||||
@@ -91,7 +98,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "file",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
},
|
||||
@@ -111,18 +117,17 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
cid = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
async for chain in generator:
|
||||
if chain.type == "break" and final_data:
|
||||
# 分割符
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "break", # break means a segment end
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
},
|
||||
)
|
||||
final_data = ""
|
||||
continue
|
||||
# if chain.type == "break" and final_data:
|
||||
# # 分割符
|
||||
# await web_chat_back_queue.put(
|
||||
# {
|
||||
# "type": "break", # break means a segment end
|
||||
# "data": final_data,
|
||||
# "streaming": True,
|
||||
# },
|
||||
# )
|
||||
# final_data = ""
|
||||
# continue
|
||||
|
||||
r = await WebChatMessageEvent._send(
|
||||
chain,
|
||||
@@ -142,7 +147,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
"data": final_data,
|
||||
"reasoning": reasoning_content,
|
||||
"streaming": True,
|
||||
"cid": cid,
|
||||
},
|
||||
)
|
||||
await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -47,7 +47,6 @@ class WeChatPadProAdapter(Platform):
|
||||
self._shutdown_event = None
|
||||
self.wxnewpass = None
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings.get("unique_session", False)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
name="wechatpadpro",
|
||||
@@ -509,11 +508,10 @@ class WeChatPadProAdapter(Platform):
|
||||
if accurate_nickname:
|
||||
abm.sender.nickname = accurate_nickname
|
||||
|
||||
# 对于群聊,session_id 可以是群聊 ID 或发送者 ID + 群聊 ID (如果 unique_session 为 True)
|
||||
if self.unique_session:
|
||||
abm.session_id = f"{from_user_name}#{abm.sender.user_id}"
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = abm.group_id
|
||||
else:
|
||||
abm.session_id = from_user_name
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
msg_source = raw_message.get("msg_source", "")
|
||||
if self.wxid in msg_source:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import enum
|
||||
import json
|
||||
@@ -12,6 +14,7 @@ import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import (
|
||||
AssistantMessageSegment,
|
||||
ContentPart,
|
||||
ToolCall,
|
||||
ToolCallMessageSegment,
|
||||
)
|
||||
@@ -90,6 +93,8 @@ class ProviderRequest:
|
||||
"""会话 ID"""
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
|
||||
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。支持 dict 或 ContentPart 对象"""
|
||||
func_tool: ToolSet | None = None
|
||||
"""可用的函数工具"""
|
||||
contexts: list[dict] = field(default_factory=list)
|
||||
@@ -164,13 +169,23 @@ class ProviderRequest:
|
||||
|
||||
async def assemble_context(self) -> dict:
|
||||
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if self.prompt and self.prompt.strip():
|
||||
content_blocks.append({"type": "text", "text": self.prompt})
|
||||
elif self.image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if self.extra_user_content_parts:
|
||||
for part in self.extra_user_content_parts:
|
||||
content_blocks.append(part.model_dump())
|
||||
|
||||
# 3. 图片内容
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": self.prompt if self.prompt else "[图片]"},
|
||||
],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -183,11 +198,21 @@ class ProviderRequest:
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
content_blocks.append(
|
||||
{"type": "image_url", "image_url": {"url": image_data}},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": self.prompt}
|
||||
|
||||
# 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容
|
||||
if (
|
||||
len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
and not self.extra_user_content_parts
|
||||
and not self.image_urls
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
@@ -199,6 +224,38 @@ class ProviderRequest:
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenUsage:
|
||||
input_other: int = 0
|
||||
"""The number of input tokens, excluding cached tokens."""
|
||||
input_cached: int = 0
|
||||
"""The number of input cached tokens."""
|
||||
output: int = 0
|
||||
"""The number of output tokens."""
|
||||
|
||||
@property
|
||||
def total(self) -> int:
|
||||
return self.input_other + self.input_cached + self.output
|
||||
|
||||
@property
|
||||
def input(self) -> int:
|
||||
return self.input_other + self.input_cached
|
||||
|
||||
def __add__(self, other: TokenUsage) -> TokenUsage:
|
||||
return TokenUsage(
|
||||
input_other=self.input_other + other.input_other,
|
||||
input_cached=self.input_cached + other.input_cached,
|
||||
output=self.output + other.output,
|
||||
)
|
||||
|
||||
def __sub__(self, other: TokenUsage) -> TokenUsage:
|
||||
return TokenUsage(
|
||||
input_other=self.input_other - other.input_other,
|
||||
input_cached=self.input_cached - other.input_cached,
|
||||
output=self.output - other.output,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
@@ -215,6 +272,8 @@ class LLMResponse:
|
||||
"""Tool call extra content. tool_call_id -> extra_content dict"""
|
||||
reasoning_content: str = ""
|
||||
"""The reasoning content extracted from the LLM, if any."""
|
||||
reasoning_signature: str | None = None
|
||||
"""The signature of the reasoning content, if any."""
|
||||
|
||||
raw_completion: (
|
||||
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
|
||||
@@ -227,20 +286,29 @@ class LLMResponse:
|
||||
is_chunk: bool = False
|
||||
"""Indicates if the response is a chunked response."""
|
||||
|
||||
id: str | None = None
|
||||
"""The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response."""
|
||||
usage: TokenUsage | None = None
|
||||
"""The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: str,
|
||||
completion_text: str = "",
|
||||
completion_text: str | None = None,
|
||||
result_chain: MessageChain | None = None,
|
||||
tools_call_args: list[dict[str, Any]] | None = None,
|
||||
tools_call_name: list[str] | None = None,
|
||||
tools_call_ids: list[str] | None = None,
|
||||
tools_call_extra_content: dict[str, dict[str, Any]] | None = None,
|
||||
reasoning_content: str | None = None,
|
||||
reasoning_signature: str | None = None,
|
||||
raw_completion: ChatCompletion
|
||||
| GenerateContentResponse
|
||||
| AnthropicMessage
|
||||
| None = None,
|
||||
is_chunk: bool = False,
|
||||
id: str | None = None,
|
||||
usage: TokenUsage | None = None,
|
||||
):
|
||||
"""初始化 LLMResponse
|
||||
|
||||
@@ -253,6 +321,8 @@ class LLMResponse:
|
||||
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
|
||||
|
||||
"""
|
||||
if reasoning_content is None:
|
||||
reasoning_content = ""
|
||||
if tools_call_args is None:
|
||||
tools_call_args = []
|
||||
if tools_call_name is None:
|
||||
@@ -269,6 +339,8 @@ class LLMResponse:
|
||||
self.tools_call_name = tools_call_name
|
||||
self.tools_call_ids = tools_call_ids
|
||||
self.tools_call_extra_content = tools_call_extra_content
|
||||
self.reasoning_content = reasoning_content
|
||||
self.reasoning_signature = reasoning_signature
|
||||
self.raw_completion = raw_completion
|
||||
self.is_chunk = is_chunk
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import traceback
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
@@ -32,10 +33,12 @@ class ProviderManager:
|
||||
persona_mgr: PersonaManager,
|
||||
):
|
||||
self.reload_lock = asyncio.Lock()
|
||||
self.resource_lock = asyncio.Lock()
|
||||
self.persona_mgr = persona_mgr
|
||||
self.acm = acm
|
||||
config = acm.confs["default"]
|
||||
self.providers_config: list = config["provider"]
|
||||
self.provider_sources_config: list = config.get("provider_sources", [])
|
||||
self.provider_settings: dict = config["provider_settings"]
|
||||
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||||
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
||||
@@ -148,6 +151,7 @@ class ProviderManager:
|
||||
|
||||
"""
|
||||
provider = None
|
||||
provider_id = None
|
||||
if umo:
|
||||
provider_id = sp.get(
|
||||
f"provider_perf_{provider_type.value}",
|
||||
@@ -185,6 +189,12 @@ class ProviderManager:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
if not provider and provider_id:
|
||||
logger.warning(
|
||||
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
|
||||
)
|
||||
|
||||
return provider
|
||||
|
||||
async def initialize(self):
|
||||
@@ -251,7 +261,136 @@ class ProviderManager:
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||
|
||||
def dynamic_import_provider(self, type: str):
|
||||
"""动态导入提供商适配器模块
|
||||
|
||||
Args:
|
||||
type (str): 提供商请求类型。
|
||||
|
||||
Raises:
|
||||
ImportError: 如果提供商类型未知或无法导入对应模块,则抛出异常。
|
||||
"""
|
||||
match type:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import (
|
||||
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
|
||||
)
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||
case "groq_chat_completion":
|
||||
from .sources.groq_source import ProviderGroq as ProviderGroq
|
||||
case "anthropic_chat_completion":
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
)
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import (
|
||||
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
||||
)
|
||||
case "sensevoice_stt_selfhost":
|
||||
from .sources.sensevoice_selfhosted_source import (
|
||||
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
|
||||
)
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import (
|
||||
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
|
||||
)
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import (
|
||||
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
||||
)
|
||||
case "xinference_stt":
|
||||
from .sources.xinference_stt_provider import (
|
||||
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
||||
)
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import (
|
||||
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
||||
)
|
||||
case "edge_tts":
|
||||
from .sources.edge_tts_source import (
|
||||
ProviderEdgeTTS as ProviderEdgeTTS,
|
||||
)
|
||||
case "gsv_tts_selfhost":
|
||||
from .sources.gsv_selfhosted_source import (
|
||||
ProviderGSVTTS as ProviderGSVTTS,
|
||||
)
|
||||
case "gsvi_tts_api":
|
||||
from .sources.gsvi_tts_source import (
|
||||
ProviderGSVITTS as ProviderGSVITTS,
|
||||
)
|
||||
case "fishaudio_tts_api":
|
||||
from .sources.fishaudio_tts_api_source import (
|
||||
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
||||
)
|
||||
case "dashscope_tts":
|
||||
from .sources.dashscope_tts import (
|
||||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||
)
|
||||
case "azure_tts":
|
||||
from .sources.azure_tts_source import (
|
||||
AzureTTSProvider as AzureTTSProvider,
|
||||
)
|
||||
case "minimax_tts_api":
|
||||
from .sources.minimax_tts_api_source import (
|
||||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "gemini_tts":
|
||||
from .sources.gemini_tts_source import (
|
||||
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
)
|
||||
case "gemini_embedding":
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
case "vllm_rerank":
|
||||
from .sources.vllm_rerank_source import (
|
||||
VLLMRerankProvider as VLLMRerankProvider,
|
||||
)
|
||||
case "xinference_rerank":
|
||||
from .sources.xinference_rerank_source import (
|
||||
XinferenceRerankProvider as XinferenceRerankProvider,
|
||||
)
|
||||
case "bailian_rerank":
|
||||
from .sources.bailian_rerank_source import (
|
||||
BailianRerankProvider as BailianRerankProvider,
|
||||
)
|
||||
|
||||
def get_merged_provider_config(self, provider_config: dict) -> dict:
|
||||
"""获取 provider 配置和 provider_source 配置合并后的结果
|
||||
|
||||
Returns:
|
||||
dict: 合并后的 provider 配置,key 为 provider id,value 为合并后的配置字典
|
||||
"""
|
||||
pc = copy.deepcopy(provider_config)
|
||||
provider_source_id = pc.get("provider_source_id", "")
|
||||
if provider_source_id:
|
||||
provider_source = None
|
||||
for ps in self.provider_sources_config:
|
||||
if ps.get("id") == provider_source_id:
|
||||
provider_source = ps
|
||||
break
|
||||
|
||||
if provider_source:
|
||||
# 合并配置,provider 的配置优先级更高
|
||||
merged_config = {**provider_source, **pc}
|
||||
# 保持 id 为 provider 的 id,而不是 source 的 id
|
||||
merged_config["id"] = pc["id"]
|
||||
pc = merged_config
|
||||
return pc
|
||||
|
||||
async def load_provider(self, provider_config: dict):
|
||||
# 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并
|
||||
provider_config = self.get_merged_provider_config(provider_config)
|
||||
|
||||
if not provider_config["enable"]:
|
||||
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
|
||||
return
|
||||
@@ -264,99 +403,7 @@ class ProviderManager:
|
||||
|
||||
# 动态导入
|
||||
try:
|
||||
match provider_config["type"]:
|
||||
case "openai_chat_completion":
|
||||
from .sources.openai_source import (
|
||||
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
|
||||
)
|
||||
case "zhipu_chat_completion":
|
||||
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
|
||||
case "groq_chat_completion":
|
||||
from .sources.groq_source import ProviderGroq as ProviderGroq
|
||||
case "anthropic_chat_completion":
|
||||
from .sources.anthropic_source import (
|
||||
ProviderAnthropic as ProviderAnthropic,
|
||||
)
|
||||
case "googlegenai_chat_completion":
|
||||
from .sources.gemini_source import (
|
||||
ProviderGoogleGenAI as ProviderGoogleGenAI,
|
||||
)
|
||||
case "sensevoice_stt_selfhost":
|
||||
from .sources.sensevoice_selfhosted_source import (
|
||||
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
|
||||
)
|
||||
case "openai_whisper_api":
|
||||
from .sources.whisper_api_source import (
|
||||
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
|
||||
)
|
||||
case "openai_whisper_selfhost":
|
||||
from .sources.whisper_selfhosted_source import (
|
||||
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
|
||||
)
|
||||
case "xinference_stt":
|
||||
from .sources.xinference_stt_provider import (
|
||||
ProviderXinferenceSTT as ProviderXinferenceSTT,
|
||||
)
|
||||
case "openai_tts_api":
|
||||
from .sources.openai_tts_api_source import (
|
||||
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
|
||||
)
|
||||
case "edge_tts":
|
||||
from .sources.edge_tts_source import (
|
||||
ProviderEdgeTTS as ProviderEdgeTTS,
|
||||
)
|
||||
case "gsv_tts_selfhost":
|
||||
from .sources.gsv_selfhosted_source import (
|
||||
ProviderGSVTTS as ProviderGSVTTS,
|
||||
)
|
||||
case "gsvi_tts_api":
|
||||
from .sources.gsvi_tts_source import (
|
||||
ProviderGSVITTS as ProviderGSVITTS,
|
||||
)
|
||||
case "fishaudio_tts_api":
|
||||
from .sources.fishaudio_tts_api_source import (
|
||||
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
||||
)
|
||||
case "dashscope_tts":
|
||||
from .sources.dashscope_tts import (
|
||||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||
)
|
||||
case "azure_tts":
|
||||
from .sources.azure_tts_source import (
|
||||
AzureTTSProvider as AzureTTSProvider,
|
||||
)
|
||||
case "minimax_tts_api":
|
||||
from .sources.minimax_tts_api_source import (
|
||||
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
|
||||
)
|
||||
case "volcengine_tts":
|
||||
from .sources.volcengine_tts import (
|
||||
ProviderVolcengineTTS as ProviderVolcengineTTS,
|
||||
)
|
||||
case "gemini_tts":
|
||||
from .sources.gemini_tts_source import (
|
||||
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
|
||||
)
|
||||
case "openai_embedding":
|
||||
from .sources.openai_embedding_source import (
|
||||
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
|
||||
)
|
||||
case "gemini_embedding":
|
||||
from .sources.gemini_embedding_source import (
|
||||
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
|
||||
)
|
||||
case "vllm_rerank":
|
||||
from .sources.vllm_rerank_source import (
|
||||
VLLMRerankProvider as VLLMRerankProvider,
|
||||
)
|
||||
case "xinference_rerank":
|
||||
from .sources.xinference_rerank_source import (
|
||||
XinferenceRerankProvider as XinferenceRerankProvider,
|
||||
)
|
||||
case "bailian_rerank":
|
||||
from .sources.bailian_rerank_source import (
|
||||
BailianRerankProvider as BailianRerankProvider,
|
||||
)
|
||||
self.dynamic_import_provider(provider_config["type"])
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
|
||||
@@ -499,6 +546,7 @@ class ProviderManager:
|
||||
|
||||
# 和配置文件保持同步
|
||||
self.providers_config = astrbot_config["provider"]
|
||||
self.provider_sources_config = astrbot_config.get("provider_sources", [])
|
||||
config_ids = [provider["id"] for provider in self.providers_config]
|
||||
logger.info(f"providers in user's config: {config_ids}")
|
||||
for key in list(self.inst_map.keys()):
|
||||
@@ -570,6 +618,68 @@ class ProviderManager:
|
||||
)
|
||||
del self.inst_map[provider_id]
|
||||
|
||||
async def delete_provider(
|
||||
self, provider_id: str | None = None, provider_source_id: str | None = None
|
||||
):
|
||||
"""Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion."""
|
||||
async with self.resource_lock:
|
||||
# delete from config
|
||||
target_prov_ids = []
|
||||
if provider_id:
|
||||
target_prov_ids.append(provider_id)
|
||||
else:
|
||||
for prov in self.providers_config:
|
||||
if prov.get("provider_source_id") == provider_source_id:
|
||||
target_prov_ids.append(prov.get("id"))
|
||||
config = self.acm.default_conf
|
||||
for tpid in target_prov_ids:
|
||||
await self.terminate_provider(tpid)
|
||||
config["provider"] = [
|
||||
prov for prov in config["provider"] if prov.get("id") != tpid
|
||||
]
|
||||
config.save_config()
|
||||
logger.info(f"Provider {target_prov_ids} 已从配置中删除。")
|
||||
|
||||
async def update_provider(self, origin_provider_id: str, new_config: dict):
|
||||
"""Update provider config and reload the instance. Config will be saved after update."""
|
||||
async with self.resource_lock:
|
||||
npid = new_config.get("id", None)
|
||||
if not npid:
|
||||
raise ValueError("New provider config must have an 'id' field")
|
||||
config = self.acm.default_conf
|
||||
for provider in config["provider"]:
|
||||
if (
|
||||
provider.get("id", None) == npid
|
||||
and provider.get("id", None) != origin_provider_id
|
||||
):
|
||||
raise ValueError(f"Provider ID {npid} already exists")
|
||||
# update config
|
||||
for idx, provider in enumerate(config["provider"]):
|
||||
if provider.get("id", None) == origin_provider_id:
|
||||
config["provider"][idx] = new_config
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"Provider ID {origin_provider_id} not found")
|
||||
config.save_config()
|
||||
# reload instance
|
||||
await self.reload(new_config)
|
||||
|
||||
async def create_provider(self, new_config: dict):
|
||||
"""Add new provider config and load the instance. Config will be saved after addition."""
|
||||
async with self.resource_lock:
|
||||
npid = new_config.get("id", None)
|
||||
if not npid:
|
||||
raise ValueError("New provider config must have an 'id' field")
|
||||
config = self.acm.default_conf
|
||||
for provider in config["provider"]:
|
||||
if provider.get("id", None) == npid:
|
||||
raise ValueError(f"Provider ID {npid} already exists")
|
||||
# add to config
|
||||
config["provider"].append(new_config)
|
||||
config.save_config()
|
||||
# load instance
|
||||
await self.load_provider(new_config)
|
||||
|
||||
async def terminate(self):
|
||||
for provider_inst in self.provider_insts:
|
||||
if hasattr(provider_inst, "terminate"):
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TypeAlias, Union
|
||||
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import ContentPart, Message
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.provider.entities import (
|
||||
LLMResponse,
|
||||
@@ -103,6 +103,7 @@ class Provider(AbstractProvider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||
@@ -114,6 +115,7 @@ class Provider(AbstractProvider):
|
||||
tools: tool set
|
||||
contexts: 上下文,和 prompt 二选一使用
|
||||
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||
extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等)
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
|
||||
@@ -6,10 +6,13 @@ from mimetypes import guess_type
|
||||
import anthropic
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message
|
||||
from anthropic.types.message_delta_usage import MessageDeltaUsage
|
||||
from anthropic.types.usage import Usage
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
@@ -45,7 +48,9 @@ class ProviderAnthropic(Provider):
|
||||
base_url=self.base_url,
|
||||
)
|
||||
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
self.thinking_config = provider_config.get("anth_thinking_config", {})
|
||||
|
||||
self.set_model(provider_config.get("model", "unknown"))
|
||||
|
||||
def _prepare_payload(self, messages: list[dict]):
|
||||
"""准备 Anthropic API 的请求 payload
|
||||
@@ -61,12 +66,33 @@ class ProviderAnthropic(Provider):
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
system_prompt = message["content"]
|
||||
system_prompt = message["content"] or "<empty system prompt>"
|
||||
elif message["role"] == "assistant":
|
||||
blocks = []
|
||||
if isinstance(message["content"], str):
|
||||
reasoning_content = ""
|
||||
thinking_signature = ""
|
||||
if isinstance(message["content"], str) and message["content"].strip():
|
||||
blocks.append({"type": "text", "text": message["content"]})
|
||||
if "tool_calls" in message:
|
||||
elif isinstance(message["content"], list):
|
||||
for part in message["content"]:
|
||||
if part.get("type") == "think":
|
||||
# only pick the last think part for now
|
||||
reasoning_content = part.get("think")
|
||||
thinking_signature = part.get("encrypted")
|
||||
else:
|
||||
blocks.append(part)
|
||||
|
||||
if reasoning_content and thinking_signature:
|
||||
blocks.insert(
|
||||
0,
|
||||
{
|
||||
"type": "thinking",
|
||||
"thinking": reasoning_content,
|
||||
"signature": thinking_signature,
|
||||
},
|
||||
)
|
||||
|
||||
if "tool_calls" in message and isinstance(message["tool_calls"], list):
|
||||
for tool_call in message["tool_calls"]:
|
||||
blocks.append( # noqa: PERF401
|
||||
{
|
||||
@@ -97,7 +123,7 @@ class ProviderAnthropic(Provider):
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
"content": message["content"] or "<empty response>",
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -107,12 +133,40 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
return system_prompt, new_messages
|
||||
|
||||
def _extract_usage(self, usage: Usage) -> TokenUsage:
|
||||
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
|
||||
return TokenUsage(
|
||||
input_other=usage.input_tokens or 0,
|
||||
input_cached=usage.cache_read_input_tokens or 0,
|
||||
output=usage.output_tokens,
|
||||
)
|
||||
|
||||
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
|
||||
if usage.input_tokens is not None:
|
||||
token_usage.input_other = usage.input_tokens
|
||||
if usage.cache_read_input_tokens is not None:
|
||||
token_usage.input_cached = usage.cache_read_input_tokens
|
||||
if usage.output_tokens is not None:
|
||||
token_usage.output = usage.output_tokens
|
||||
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
if tools:
|
||||
if tool_list := tools.get_func_desc_anthropic_style():
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
completion = await self.client.messages.create(**payloads, stream=False)
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
|
||||
if "max_tokens" not in payloads:
|
||||
payloads["max_tokens"] = 1024
|
||||
if self.thinking_config.get("budget"):
|
||||
payloads["thinking"] = {
|
||||
"budget_tokens": self.thinking_config.get("budget"),
|
||||
"type": "enabled",
|
||||
}
|
||||
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
|
||||
assert isinstance(completion, Message)
|
||||
logger.debug(f"completion: {completion}")
|
||||
@@ -127,10 +181,19 @@ class ProviderAnthropic(Provider):
|
||||
completion_text = str(content_block.text).strip()
|
||||
llm_response.completion_text = completion_text
|
||||
|
||||
if content_block.type == "thinking":
|
||||
reasoning_content = str(content_block.thinking).strip()
|
||||
llm_response.reasoning_content = reasoning_content
|
||||
llm_response.reasoning_signature = content_block.signature
|
||||
|
||||
if content_block.type == "tool_use":
|
||||
llm_response.tools_call_args.append(content_block.input)
|
||||
llm_response.tools_call_name.append(content_block.name)
|
||||
llm_response.tools_call_ids.append(content_block.id)
|
||||
|
||||
llm_response.id = completion.id
|
||||
llm_response.usage = self._extract_usage(completion.usage)
|
||||
|
||||
# TODO(Soulter): 处理 end_turn 情况
|
||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
||||
@@ -151,10 +214,29 @@ class ProviderAnthropic(Provider):
|
||||
# 用于累积最终结果
|
||||
final_text = ""
|
||||
final_tool_calls = []
|
||||
id = None
|
||||
usage = TokenUsage()
|
||||
extra_body = self.provider_config.get("custom_extra_body", {})
|
||||
reasoning_content = ""
|
||||
reasoning_signature = ""
|
||||
|
||||
async with self.client.messages.stream(**payloads) as stream:
|
||||
if "max_tokens" not in payloads:
|
||||
payloads["max_tokens"] = 1024
|
||||
if self.thinking_config.get("budget"):
|
||||
payloads["thinking"] = {
|
||||
"budget_tokens": self.thinking_config.get("budget"),
|
||||
"type": "enabled",
|
||||
}
|
||||
|
||||
async with self.client.messages.stream(
|
||||
**payloads, extra_body=extra_body
|
||||
) as stream:
|
||||
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||
async for event in stream:
|
||||
if event.type == "message_start":
|
||||
# the usage contains input token usage
|
||||
id = event.message.id
|
||||
usage = self._extract_usage(event.message.usage)
|
||||
if event.type == "content_block_start":
|
||||
if event.content_block.type == "text":
|
||||
# 文本块开始
|
||||
@@ -162,6 +244,8 @@ class ProviderAnthropic(Provider):
|
||||
role="assistant",
|
||||
completion_text="",
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
elif event.content_block.type == "tool_use":
|
||||
# 工具使用块开始,初始化缓冲区
|
||||
@@ -179,7 +263,24 @@ class ProviderAnthropic(Provider):
|
||||
role="assistant",
|
||||
completion_text=event.delta.text,
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
elif event.delta.type == "thinking_delta":
|
||||
# 思考增量
|
||||
reasoning = event.delta.thinking
|
||||
if reasoning:
|
||||
yield LLMResponse(
|
||||
role="assistant",
|
||||
reasoning_content=reasoning,
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
reasoning_signature=reasoning_signature or None,
|
||||
)
|
||||
reasoning_content += reasoning
|
||||
elif event.delta.type == "signature_delta":
|
||||
reasoning_signature = event.delta.signature
|
||||
elif event.delta.type == "input_json_delta":
|
||||
# 工具调用参数增量
|
||||
if event.index in tool_use_buffer:
|
||||
@@ -215,6 +316,8 @@ class ProviderAnthropic(Provider):
|
||||
tools_call_name=[tool_info["name"]],
|
||||
tools_call_ids=[tool_info["id"]],
|
||||
is_chunk=True,
|
||||
usage=usage,
|
||||
id=id,
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# JSON 解析失败,跳过这个工具调用
|
||||
@@ -223,11 +326,19 @@ class ProviderAnthropic(Provider):
|
||||
# 清理缓冲区
|
||||
del tool_use_buffer[event.index]
|
||||
|
||||
elif event.type == "message_delta":
|
||||
if event.usage:
|
||||
self._update_usage(usage, event.usage)
|
||||
|
||||
# 返回最终的完整结果
|
||||
final_response = LLMResponse(
|
||||
role="assistant",
|
||||
completion_text=final_text,
|
||||
is_chunk=False,
|
||||
usage=usage,
|
||||
id=id,
|
||||
reasoning_content=reasoning_content,
|
||||
reasoning_signature=reasoning_signature or None,
|
||||
)
|
||||
|
||||
if final_tool_calls:
|
||||
@@ -249,13 +360,16 @@ class ProviderAnthropic(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -277,10 +391,9 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": new_messages, **model_config}
|
||||
payloads = {"messages": new_messages, "model": model}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
@@ -290,28 +403,30 @@ class ProviderAnthropic(Provider):
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
# logger.error(f"发生了错误。Provider 配置如下: {model_config}")
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
prompt=None,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
image_urls=None,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
):
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -332,10 +447,9 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
system_prompt, new_messages = self._prepare_payload(context_query)
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": new_messages, **model_config}
|
||||
payloads = {"messages": new_messages, "model": model}
|
||||
|
||||
# Anthropic has a different way of handling system prompts
|
||||
if system_prompt:
|
||||
@@ -344,15 +458,15 @@ class ProviderAnthropic(Provider):
|
||||
async for llm_response in self._query_stream(payloads, func_tool):
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
async def assemble_context(
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
if not image_urls:
|
||||
return {"role": "user", "content": text}
|
||||
|
||||
content = []
|
||||
content.append({"type": "text", "text": text})
|
||||
|
||||
for image_url in image_urls:
|
||||
async def resolve_image_url(image_url: str) -> dict | None:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
@@ -364,28 +478,68 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
return None
|
||||
|
||||
# Get mime type for the image
|
||||
mime_type, _ = guess_type(image_url)
|
||||
if not mime_type:
|
||||
mime_type = "image/jpeg" # Default to JPEG if can't determine
|
||||
|
||||
content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
},
|
||||
return {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": mime_type,
|
||||
"data": (
|
||||
image_data.split("base64,")[1]
|
||||
if "base64," in image_data
|
||||
else image_data
|
||||
),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
content = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for block in extra_user_content_parts:
|
||||
if isinstance(block, TextPart):
|
||||
content.append({"type": "text", "text": block.text})
|
||||
elif isinstance(block, ImageURLPart):
|
||||
image_dict = await resolve_image_url(block.image_url.url)
|
||||
if image_dict:
|
||||
content.append(image_dict)
|
||||
else:
|
||||
raise ValueError(f"不支持的额外内容块类型: {type(block)}")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
image_dict = await resolve_image_url(image_url)
|
||||
if image_dict:
|
||||
content.append(image_dict)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content) == 1
|
||||
and content[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
|
||||
@@ -56,10 +56,14 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
"api_base",
|
||||
"https://api.fish-audio.cn/v1",
|
||||
)
|
||||
try:
|
||||
self.timeout: int = int(provider_config.get("timeout", 20))
|
||||
except ValueError:
|
||||
self.timeout = 20
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
self.set_model(provider_config["model"])
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
||||
"""获取角色的reference_id
|
||||
@@ -135,17 +139,21 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base).stream(
|
||||
async with AsyncClient(base_url=self.api_base, timeout=self.timeout).stream(
|
||||
"POST",
|
||||
"/tts",
|
||||
headers=self.headers,
|
||||
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
|
||||
) as response:
|
||||
if response.headers["content-type"] == "audio/wav":
|
||||
if response.status_code == 200 and response.headers.get(
|
||||
"content-type", ""
|
||||
).startswith("audio/"):
|
||||
with open(path, "wb") as f:
|
||||
async for chunk in response.aiter_bytes():
|
||||
f.write(chunk)
|
||||
return path
|
||||
body = await response.aread()
|
||||
text = body.decode("utf-8", errors="replace")
|
||||
raise Exception(f"Fish Audio API请求失败: {text}")
|
||||
error_bytes = await response.aread()
|
||||
error_text = error_bytes.decode("utf-8", errors="replace")[:1024]
|
||||
raise Exception(
|
||||
f"Fish Audio API请求失败: 状态码 {response.status_code}, 响应内容: {error_text}"
|
||||
)
|
||||
|
||||
@@ -13,8 +13,9 @@ from google.genai.errors import APIError
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
@@ -68,7 +69,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
self.api_base = self.api_base[:-1]
|
||||
|
||||
self._init_client()
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
self.set_model(provider_config.get("model", "unknown"))
|
||||
self._init_safety_settings()
|
||||
|
||||
def _init_client(self) -> None:
|
||||
@@ -138,7 +139,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
modalities = ["TEXT"]
|
||||
|
||||
tool_list: list[types.Tool] | None = []
|
||||
model_name = self.get_model()
|
||||
model_name = cast(str, payloads.get("model", self.get_model()))
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
url_context = self.provider_config.get("gm_url_context", False)
|
||||
@@ -197,6 +198,53 @@ class ProviderGoogleGenAI(Provider):
|
||||
types.Tool(function_declarations=func_desc["function_declarations"]),
|
||||
]
|
||||
|
||||
# oper thinking config
|
||||
thinking_config = None
|
||||
if model_name in [
|
||||
"gemini-2.5-pro",
|
||||
"gemini-2.5-pro-preview",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-flash-preview",
|
||||
"gemini-2.5-flash-lite",
|
||||
"gemini-2.5-flash-lite-preview",
|
||||
"gemini-robotics-er-1.5-preview",
|
||||
"gemini-live-2.5-flash-preview-native-audio-09-2025",
|
||||
]:
|
||||
# The thinkingBudget parameter, introduced with the Gemini 2.5 series
|
||||
thinking_budget = self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget", 0
|
||||
)
|
||||
if thinking_budget is not None:
|
||||
thinking_config = types.ThinkingConfig(
|
||||
thinking_budget=thinking_budget,
|
||||
)
|
||||
elif model_name in [
|
||||
"gemini-3-pro",
|
||||
"gemini-3-pro-preview",
|
||||
"gemini-3-flash",
|
||||
"gemini-3-flash-preview",
|
||||
"gemini-3-flash-lite",
|
||||
"gemini-3-flash-lite-preview",
|
||||
]:
|
||||
# The thinkingLevel parameter, recommended for Gemini 3 models and onwards
|
||||
# Gemini 2.5 series models don't support thinkingLevel; use thinkingBudget instead.
|
||||
thinking_level = self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"level", "HIGH"
|
||||
)
|
||||
if thinking_level and isinstance(thinking_level, str):
|
||||
thinking_level = thinking_level.upper()
|
||||
if thinking_level not in ["MINIMAL", "LOW", "MEDIUM", "HIGH"]:
|
||||
logger.warning(
|
||||
f"Invalid thinking level: {thinking_level}, using HIGH"
|
||||
)
|
||||
thinking_level = "HIGH"
|
||||
level = types.ThinkingLevel(thinking_level)
|
||||
thinking_config = types.ThinkingConfig()
|
||||
if not hasattr(types.ThinkingConfig, "thinking_level"):
|
||||
setattr(types.ThinkingConfig, "thinking_level", level)
|
||||
else:
|
||||
thinking_config.thinking_level = level
|
||||
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
@@ -216,22 +264,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
response_modalities=modalities,
|
||||
tools=cast(types.ToolListUnion | None, tool_list),
|
||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||
thinking_config=(
|
||||
types.ThinkingConfig(
|
||||
thinking_budget=min(
|
||||
int(
|
||||
self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget",
|
||||
0,
|
||||
),
|
||||
),
|
||||
24576,
|
||||
),
|
||||
)
|
||||
if "gemini-2.5-flash" in self.get_model()
|
||||
and hasattr(types.ThinkingConfig, "thinking_budget")
|
||||
else None
|
||||
),
|
||||
thinking_config=thinking_config,
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True,
|
||||
),
|
||||
@@ -288,9 +321,37 @@ class ProviderGoogleGenAI(Provider):
|
||||
append_or_extend(gemini_contents, parts, types.UserContent)
|
||||
|
||||
elif role == "assistant":
|
||||
if content:
|
||||
if isinstance(content, str):
|
||||
parts = [types.Part.from_text(text=content)]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
elif isinstance(content, list):
|
||||
parts = []
|
||||
thinking_signature = None
|
||||
text = ""
|
||||
for part in content:
|
||||
# for most cases, assistant content only contains two parts: think and text
|
||||
if part.get("type") == "think":
|
||||
thinking_signature = part.get("encrypted") or None
|
||||
else:
|
||||
text += str(part.get("text"))
|
||||
|
||||
if thinking_signature and isinstance(thinking_signature, str):
|
||||
try:
|
||||
thinking_signature = base64.b64decode(thinking_signature)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to decode google gemini thinking signature: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
thinking_signature = None
|
||||
parts.append(
|
||||
types.Part(
|
||||
text=text,
|
||||
thought_signature=thinking_signature,
|
||||
)
|
||||
)
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
|
||||
elif not native_tool_enabled and "tool_calls" in message:
|
||||
parts = []
|
||||
for tool in message["tool_calls"]:
|
||||
@@ -347,6 +408,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
]
|
||||
return "".join(thought_buf).strip()
|
||||
|
||||
def _extract_usage(
|
||||
self, usage_metadata: types.GenerateContentResponseUsageMetadata
|
||||
) -> TokenUsage:
|
||||
"""Extract usage from candidate"""
|
||||
return TokenUsage(
|
||||
input_other=usage_metadata.prompt_token_count or 0,
|
||||
input_cached=usage_metadata.cached_content_token_count or 0,
|
||||
output=usage_metadata.candidates_token_count or 0,
|
||||
)
|
||||
|
||||
def _process_content_parts(
|
||||
self,
|
||||
candidate: types.Candidate,
|
||||
@@ -398,7 +469,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
for part in result_parts:
|
||||
if part.text:
|
||||
chain.append(Comp.Plain(part.text))
|
||||
elif (
|
||||
|
||||
if (
|
||||
part.function_call
|
||||
and part.function_call.name is not None
|
||||
and part.function_call.args is not None
|
||||
@@ -415,13 +487,18 @@ class ProviderGoogleGenAI(Provider):
|
||||
llm_response.tools_call_extra_content[tool_call_id] = {
|
||||
"google": {"thought_signature": ts_bs64}
|
||||
}
|
||||
elif (
|
||||
|
||||
if (
|
||||
part.inline_data
|
||||
and part.inline_data.mime_type
|
||||
and part.inline_data.mime_type.startswith("image/")
|
||||
and part.inline_data.data
|
||||
):
|
||||
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
||||
|
||||
if ts := part.thought_signature:
|
||||
# only keep the last thinking signature
|
||||
llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8")
|
||||
return MessageChain(chain=chain)
|
||||
|
||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||
@@ -431,6 +508,8 @@ class ProviderGoogleGenAI(Provider):
|
||||
None,
|
||||
)
|
||||
|
||||
model = payloads.get("model", self.get_model())
|
||||
|
||||
modalities = ["TEXT"]
|
||||
if self.provider_config.get("gm_resp_image_modal", False):
|
||||
modalities.append("IMAGE")
|
||||
@@ -449,7 +528,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
temperature,
|
||||
)
|
||||
result = await self.client.models.generate_content(
|
||||
model=self.get_model(),
|
||||
model=model,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
)
|
||||
@@ -475,11 +554,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
e.message = ""
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
)
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
logger.warning(f"{model} 不支持函数调用,已自动去除")
|
||||
tools = None
|
||||
elif (
|
||||
"Multi-modal output is not supported" in e.message
|
||||
@@ -488,7 +567,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
or "only supports text output" in e.message
|
||||
):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态",
|
||||
f"{model} 不支持多模态输出,降级为文本模态",
|
||||
)
|
||||
modalities = ["TEXT"]
|
||||
else:
|
||||
@@ -501,6 +580,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
result.candidates[0],
|
||||
llm_response,
|
||||
)
|
||||
llm_response.id = result.response_id
|
||||
if result.usage_metadata:
|
||||
llm_response.usage = self._extract_usage(result.usage_metadata)
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
@@ -513,7 +595,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
None,
|
||||
)
|
||||
|
||||
model = payloads.get("model", self.get_model())
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
|
||||
result = None
|
||||
@@ -525,7 +607,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_instruction,
|
||||
)
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=self.get_model(),
|
||||
model=model,
|
||||
contents=cast(types.ContentListUnion, conversation),
|
||||
config=config,
|
||||
)
|
||||
@@ -535,11 +617,11 @@ class ProviderGoogleGenAI(Provider):
|
||||
e.message = ""
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||
)
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
logger.warning(f"{model} 不支持函数调用,已自动去除")
|
||||
tools = None
|
||||
else:
|
||||
raise
|
||||
@@ -569,6 +651,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
chunk.candidates[0],
|
||||
llm_response,
|
||||
)
|
||||
llm_response.id = chunk.response_id
|
||||
if chunk.usage_metadata:
|
||||
llm_response.usage = self._extract_usage(chunk.usage_metadata)
|
||||
yield llm_response
|
||||
return
|
||||
|
||||
@@ -596,6 +681,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
chunk.candidates[0],
|
||||
final_response,
|
||||
)
|
||||
final_response.id = chunk.response_id
|
||||
if chunk.usage_metadata:
|
||||
final_response.usage = self._extract_usage(chunk.usage_metadata)
|
||||
break
|
||||
|
||||
# Yield final complete response with accumulated text
|
||||
@@ -627,13 +715,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -652,10 +743,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
@@ -680,13 +770,16 @@ class ProviderGoogleGenAI(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
if contexts is None:
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -705,10 +798,9 @@ class ProviderGoogleGenAI(Provider):
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
@@ -746,33 +838,75 @@ class ProviderGoogleGenAI(Provider):
|
||||
self.chosen_api_key = key
|
||||
self._init_client()
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
|
||||
async def assemble_context(
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
):
|
||||
"""组装上下文。"""
|
||||
if image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
|
||||
async def resolve_image_part(image_url: str) -> dict | None:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
return None
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content_blocks.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for part in extra_user_content_parts:
|
||||
if isinstance(part, TextPart):
|
||||
content_blocks.append({"type": "text", "text": part.text})
|
||||
elif isinstance(part, ImageURLPart):
|
||||
image_part = await resolve_image_part(part.image_url.url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": text}
|
||||
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
image_part = await resolve_image_part(image_url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -51,7 +51,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
"voice_id": ""
|
||||
if self.is_timber_weight
|
||||
else provider_config.get("minimax-voice-id", ""),
|
||||
"emotion": provider_config.get("minimax-voice-emotion", "neutral"),
|
||||
"emotion": provider_config.get("minimax-voice-emotion", "auto"),
|
||||
"latex_read": provider_config.get("minimax-voice-latex", False),
|
||||
"english_normalization": provider_config.get(
|
||||
"minimax-voice-english-normalization",
|
||||
@@ -59,6 +59,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
|
||||
),
|
||||
}
|
||||
|
||||
if self.voice_setting["emotion"] == "auto":
|
||||
self.voice_setting.pop("emotion", None)
|
||||
|
||||
self.audio_setting: dict = {
|
||||
"sample_rate": 32000,
|
||||
"bitrate": 128000,
|
||||
|
||||
@@ -12,14 +12,15 @@ from openai._exceptions import NotFoundError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Provider
|
||||
from astrbot.core.agent.message import Message
|
||||
from astrbot.core.agent.message import ContentPart, ImageURLPart, Message, TextPart
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
@@ -68,34 +69,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.client.chat.completions.create,
|
||||
).parameters.keys()
|
||||
|
||||
model_config = provider_config.get("model_config", {})
|
||||
model = model_config.get("model", "unknown")
|
||||
model = provider_config.get("model", "unknown")
|
||||
self.set_model(model)
|
||||
|
||||
self.reasoning_key = "reasoning_content"
|
||||
|
||||
def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
|
||||
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
|
||||
|
||||
- 仅在 provider_config.xai_native_search 为 True 时生效
|
||||
- 默认注入 {"mode": "auto"}
|
||||
- 允许通过 kwargs 使用 xai_search_mode 覆盖(on/auto/off)
|
||||
"""
|
||||
if not bool(self.provider_config.get("xai_native_search", False)):
|
||||
return
|
||||
|
||||
mode = kwargs.get("xai_search_mode", "auto")
|
||||
mode = str(mode).lower()
|
||||
if mode not in ("auto", "on", "off"):
|
||||
mode = "auto"
|
||||
|
||||
# off 时不注入,保持与未开启一致
|
||||
if mode == "off":
|
||||
return
|
||||
|
||||
# OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body
|
||||
payloads["search_parameters"] = {"mode": mode}
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
models_str = []
|
||||
@@ -134,10 +112,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
model = payloads.get("model", "").lower()
|
||||
|
||||
# 针对 deepseek 模型的特殊处理:deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
|
||||
if model == "deepseek-reasoner" and "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
|
||||
completion = await self.client.chat.completions.create(
|
||||
**payloads,
|
||||
stream=False,
|
||||
@@ -208,6 +182,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
# handle the content delta
|
||||
reasoning = self._extract_reasoning_content(chunk)
|
||||
_y = False
|
||||
llm_response.id = chunk.id
|
||||
if reasoning:
|
||||
llm_response.reasoning_content = reasoning
|
||||
_y = True
|
||||
@@ -217,6 +192,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
chain=[Comp.Plain(completion_text)],
|
||||
)
|
||||
_y = True
|
||||
if chunk.usage:
|
||||
llm_response.usage = self._extract_usage(chunk.usage)
|
||||
if _y:
|
||||
yield llm_response
|
||||
|
||||
@@ -245,6 +222,19 @@ class ProviderOpenAIOfficial(Provider):
|
||||
reasoning_text = str(reasoning_attr)
|
||||
return reasoning_text
|
||||
|
||||
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
|
||||
ptd = usage.prompt_tokens_details
|
||||
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
|
||||
prompt_tokens = 0 if usage.prompt_tokens is None else usage.prompt_tokens
|
||||
completion_tokens = (
|
||||
0 if usage.completion_tokens is None else usage.completion_tokens
|
||||
)
|
||||
return TokenUsage(
|
||||
input_other=prompt_tokens - cached,
|
||||
input_cached=cached,
|
||||
output=completion_tokens,
|
||||
)
|
||||
|
||||
async def _parse_openai_completion(
|
||||
self, completion: ChatCompletion, tools: ToolSet | None
|
||||
) -> LLMResponse:
|
||||
@@ -321,6 +311,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||
|
||||
llm_response.raw_completion = completion
|
||||
llm_response.id = completion.id
|
||||
|
||||
if completion.usage:
|
||||
llm_response.usage = self._extract_usage(completion.usage)
|
||||
|
||||
return llm_response
|
||||
|
||||
@@ -332,6 +326,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt: str | None = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
|
||||
model: str | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
**kwargs,
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
@@ -339,7 +334,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
contexts = []
|
||||
new_record = None
|
||||
if prompt is not None:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
new_record = await self.assemble_context(
|
||||
prompt, image_urls, extra_user_content_parts
|
||||
)
|
||||
context_query = self._ensure_message_to_dicts(contexts)
|
||||
if new_record:
|
||||
context_query.append(new_record)
|
||||
@@ -358,16 +355,31 @@ class ProviderOpenAIOfficial(Provider):
|
||||
for tcr in tool_calls_result:
|
||||
context_query.extend(tcr.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = model or self.get_model()
|
||||
model = model or self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
payloads = {"messages": context_query, "model": model}
|
||||
|
||||
# xAI origin search tool inject
|
||||
self._maybe_inject_xai_search(payloads, **kwargs)
|
||||
self._finally_convert_payload(payloads)
|
||||
|
||||
return payloads, context_query
|
||||
|
||||
def _finally_convert_payload(self, payloads: dict):
|
||||
"""Finally convert the payload. Such as think part conversion, tool inject."""
|
||||
for message in payloads.get("messages", []):
|
||||
if message.get("role") == "assistant" and isinstance(
|
||||
message.get("content"), list
|
||||
):
|
||||
reasoning_content = ""
|
||||
new_content = [] # not including think part
|
||||
for part in message["content"]:
|
||||
if part.get("type") == "think":
|
||||
reasoning_content += str(part.get("think"))
|
||||
else:
|
||||
new_content.append(part)
|
||||
message["content"] = new_content
|
||||
# reasoning key is "reasoning_content"
|
||||
message["reasoning_content"] = reasoning_content
|
||||
|
||||
async def _handle_api_error(
|
||||
self,
|
||||
e: Exception,
|
||||
@@ -461,6 +473,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
model=None,
|
||||
extra_user_content_parts=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
payloads, context_query = await self._prepare_chat_payload(
|
||||
@@ -470,6 +483,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
model=model,
|
||||
extra_user_content_parts=extra_user_content_parts,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -609,33 +623,71 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self,
|
||||
text: str,
|
||||
image_urls: list[str] | None = None,
|
||||
extra_user_content_parts: list[ContentPart] | None = None,
|
||||
) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
if image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
|
||||
async def resolve_image_part(image_url: str) -> dict | None:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
return None
|
||||
return {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self.encode_image_bs64(image_path)
|
||||
|
||||
# 构建内容块列表
|
||||
content_blocks = []
|
||||
|
||||
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
|
||||
if text:
|
||||
content_blocks.append({"type": "text", "text": text})
|
||||
elif image_urls:
|
||||
# 如果没有文本但有图片,添加占位文本
|
||||
content_blocks.append({"type": "text", "text": "[图片]"})
|
||||
elif extra_user_content_parts:
|
||||
# 如果只有额外内容块,也需要添加占位文本
|
||||
content_blocks.append({"type": "text", "text": " "})
|
||||
|
||||
# 2. 额外的内容块(系统提醒、指令等)
|
||||
if extra_user_content_parts:
|
||||
for part in extra_user_content_parts:
|
||||
if isinstance(part, TextPart):
|
||||
content_blocks.append({"type": "text", "text": part.text})
|
||||
elif isinstance(part, ImageURLPart):
|
||||
image_part = await resolve_image_part(part.image_url.url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
else:
|
||||
image_data = await self.encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": image_data},
|
||||
},
|
||||
)
|
||||
return user_content
|
||||
return {"role": "user", "content": text}
|
||||
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
|
||||
|
||||
# 3. 图片内容
|
||||
if image_urls:
|
||||
for image_url in image_urls:
|
||||
image_part = await resolve_image_part(image_url)
|
||||
if image_part:
|
||||
content_blocks.append(image_part)
|
||||
|
||||
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
|
||||
if (
|
||||
text
|
||||
and not extra_user_content_parts
|
||||
and not image_urls
|
||||
and len(content_blocks) == 1
|
||||
and content_blocks[0]["type"] == "text"
|
||||
):
|
||||
return {"role": "user", "content": content_blocks[0]["text"]}
|
||||
|
||||
# 否则返回多模态格式
|
||||
return {"role": "user", "content": content_blocks}
|
||||
|
||||
async def encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
|
||||
@@ -0,0 +1,29 @@
|
||||
from ..register import register_provider_adapter
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"xai_chat_completion", "xAI Chat Completion Provider Adapter"
|
||||
)
|
||||
class ProviderXAI(ProviderOpenAIOfficial):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
|
||||
def _maybe_inject_xai_search(self, payloads: dict):
|
||||
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
|
||||
|
||||
- 仅在 provider_config.xai_native_search 为 True 时生效
|
||||
- 默认注入 {"mode": "auto"}
|
||||
"""
|
||||
if not bool(self.provider_config.get("xai_native_search", False)):
|
||||
return
|
||||
# OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body
|
||||
payloads["search_parameters"] = {"mode": "auto"}
|
||||
|
||||
def _finally_convert_payload(self, payloads: dict):
|
||||
self._maybe_inject_xai_search(payloads)
|
||||
super()._finally_convert_payload(payloads)
|
||||
@@ -8,7 +8,10 @@ from xinference_client.client.restful.async_restful_client import (
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
from astrbot.core.utils.tencent_record_helper import (
|
||||
convert_to_pcm_wav,
|
||||
tencent_silk_to_wav,
|
||||
)
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import STTProvider
|
||||
@@ -111,17 +114,22 @@ class ProviderXinferenceSTT(STTProvider):
|
||||
return ""
|
||||
|
||||
# 2. Check for conversion
|
||||
needs_conversion = False
|
||||
if (
|
||||
audio_url.endswith((".amr", ".silk"))
|
||||
or is_tencent
|
||||
or b"SILK" in audio_bytes[:8]
|
||||
):
|
||||
needs_conversion = True
|
||||
conversion_type = None
|
||||
|
||||
if b"SILK" in audio_bytes[:8]:
|
||||
conversion_type = "silk"
|
||||
elif b"#!AMR" in audio_bytes[:6]:
|
||||
conversion_type = "amr"
|
||||
elif audio_url.endswith(".silk") or is_tencent:
|
||||
conversion_type = "silk"
|
||||
elif audio_url.endswith(".amr"):
|
||||
conversion_type = "amr"
|
||||
|
||||
# 3. Perform conversion if needed
|
||||
if needs_conversion:
|
||||
logger.info("Audio requires conversion, using temporary files...")
|
||||
if conversion_type:
|
||||
logger.info(
|
||||
f"Audio requires conversion ({conversion_type}), using temporary files..."
|
||||
)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
@@ -132,8 +140,12 @@ class ProviderXinferenceSTT(STTProvider):
|
||||
with open(input_path, "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
logger.info("Converting silk/amr file to wav ...")
|
||||
await tencent_silk_to_wav(input_path, output_path)
|
||||
if conversion_type == "silk":
|
||||
logger.info("Converting silk to wav ...")
|
||||
await tencent_silk_to_wav(input_path, output_path)
|
||||
elif conversion_type == "amr":
|
||||
logger.info("Converting amr to wav ...")
|
||||
await convert_to_pcm_wav(input_path, output_path)
|
||||
|
||||
with open(output_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
|
||||
@@ -0,0 +1,496 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from astrbot.core import db_helper, logger
|
||||
from astrbot.core.db.po import CommandConfig
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandDescriptor:
|
||||
handler: StarHandlerMetadata = field(repr=False)
|
||||
filter_ref: CommandFilter | CommandGroupFilter | None = field(
|
||||
default=None,
|
||||
repr=False,
|
||||
)
|
||||
handler_full_name: str = ""
|
||||
handler_name: str = ""
|
||||
plugin_name: str = ""
|
||||
plugin_display_name: str | None = None
|
||||
module_path: str = ""
|
||||
description: str = ""
|
||||
command_type: str = "command" # "command" | "group" | "sub_command"
|
||||
raw_command_name: str | None = None
|
||||
current_fragment: str | None = None
|
||||
parent_signature: str = ""
|
||||
parent_group_handler: str = ""
|
||||
original_command: str | None = None
|
||||
effective_command: str | None = None
|
||||
aliases: list[str] = field(default_factory=list)
|
||||
permission: str = "everyone"
|
||||
enabled: bool = True
|
||||
is_group: bool = False
|
||||
is_sub_command: bool = False
|
||||
reserved: bool = False
|
||||
config: CommandConfig | None = None
|
||||
has_conflict: bool = False
|
||||
sub_commands: list[CommandDescriptor] = field(default_factory=list)
|
||||
|
||||
|
||||
async def sync_command_configs() -> None:
|
||||
"""同步指令配置,清理过期配置。"""
|
||||
descriptors = _collect_descriptors(include_sub_commands=False)
|
||||
config_records = await db_helper.get_command_configs()
|
||||
config_map = _bind_configs_to_descriptors(descriptors, config_records)
|
||||
live_handlers = {desc.handler_full_name for desc in descriptors}
|
||||
|
||||
stale_configs = [key for key in config_map if key not in live_handlers]
|
||||
if stale_configs:
|
||||
await db_helper.delete_command_configs(stale_configs)
|
||||
|
||||
|
||||
async def toggle_command(handler_full_name: str, enabled: bool) -> CommandDescriptor:
|
||||
descriptor = _build_descriptor_by_full_name(handler_full_name)
|
||||
if not descriptor:
|
||||
raise ValueError("指定的处理函数不存在或不是指令。")
|
||||
|
||||
existing_cfg = await db_helper.get_command_config(handler_full_name)
|
||||
config = await db_helper.upsert_command_config(
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=descriptor.plugin_name or "",
|
||||
module_path=descriptor.module_path,
|
||||
original_command=descriptor.original_command or descriptor.handler_name,
|
||||
resolved_command=(
|
||||
existing_cfg.resolved_command
|
||||
if existing_cfg
|
||||
else descriptor.current_fragment
|
||||
),
|
||||
enabled=enabled,
|
||||
keep_original_alias=False,
|
||||
conflict_key=existing_cfg.conflict_key
|
||||
if existing_cfg and existing_cfg.conflict_key
|
||||
else descriptor.original_command,
|
||||
resolution_strategy=existing_cfg.resolution_strategy if existing_cfg else None,
|
||||
note=existing_cfg.note if existing_cfg else None,
|
||||
extra_data=existing_cfg.extra_data if existing_cfg else None,
|
||||
auto_managed=False,
|
||||
)
|
||||
_bind_descriptor_with_config(descriptor, config)
|
||||
await sync_command_configs()
|
||||
return descriptor
|
||||
|
||||
|
||||
async def rename_command(
|
||||
handler_full_name: str,
|
||||
new_fragment: str,
|
||||
aliases: list[str] | None = None,
|
||||
) -> CommandDescriptor:
|
||||
descriptor = _build_descriptor_by_full_name(handler_full_name)
|
||||
if not descriptor:
|
||||
raise ValueError("指定的处理函数不存在或不是指令。")
|
||||
|
||||
new_fragment = new_fragment.strip()
|
||||
if not new_fragment:
|
||||
raise ValueError("指令名不能为空。")
|
||||
|
||||
# 校验主指令名
|
||||
candidate_full = _compose_command(descriptor.parent_signature, new_fragment)
|
||||
if _is_command_in_use(handler_full_name, candidate_full):
|
||||
raise ValueError(f"指令名 '{candidate_full}' 已被其他指令占用。")
|
||||
|
||||
# 校验别名
|
||||
if aliases:
|
||||
for alias in aliases:
|
||||
alias = alias.strip()
|
||||
if not alias:
|
||||
continue
|
||||
alias_full = _compose_command(descriptor.parent_signature, alias)
|
||||
if _is_command_in_use(handler_full_name, alias_full):
|
||||
raise ValueError(f"别名 '{alias_full}' 已被其他指令占用。")
|
||||
|
||||
existing_cfg = await db_helper.get_command_config(handler_full_name)
|
||||
merged_extra = dict(existing_cfg.extra_data or {}) if existing_cfg else {}
|
||||
merged_extra["resolved_aliases"] = aliases or []
|
||||
|
||||
config = await db_helper.upsert_command_config(
|
||||
handler_full_name=handler_full_name,
|
||||
plugin_name=descriptor.plugin_name or "",
|
||||
module_path=descriptor.module_path,
|
||||
original_command=descriptor.original_command or descriptor.handler_name,
|
||||
resolved_command=new_fragment,
|
||||
enabled=True if descriptor.enabled else False,
|
||||
keep_original_alias=False,
|
||||
conflict_key=descriptor.original_command,
|
||||
resolution_strategy="manual_rename",
|
||||
note=None,
|
||||
extra_data=merged_extra,
|
||||
auto_managed=False,
|
||||
)
|
||||
_bind_descriptor_with_config(descriptor, config)
|
||||
|
||||
await sync_command_configs()
|
||||
return descriptor
|
||||
|
||||
|
||||
async def list_commands() -> list[dict[str, Any]]:
|
||||
descriptors = _collect_descriptors(include_sub_commands=True)
|
||||
config_records = await db_helper.get_command_configs()
|
||||
_bind_configs_to_descriptors(descriptors, config_records)
|
||||
|
||||
conflict_groups = _group_conflicts(descriptors)
|
||||
conflict_handler_names: set[str] = {
|
||||
d.handler_full_name for group in conflict_groups.values() for d in group
|
||||
}
|
||||
|
||||
# 分类,设置冲突标志,将子指令挂载到父指令组
|
||||
group_map: dict[str, CommandDescriptor] = {}
|
||||
sub_commands: list[CommandDescriptor] = []
|
||||
root_commands: list[CommandDescriptor] = []
|
||||
|
||||
for desc in descriptors:
|
||||
desc.has_conflict = desc.handler_full_name in conflict_handler_names
|
||||
if desc.is_group:
|
||||
group_map[desc.handler_full_name] = desc
|
||||
elif desc.is_sub_command:
|
||||
sub_commands.append(desc)
|
||||
else:
|
||||
root_commands.append(desc)
|
||||
|
||||
for sub in sub_commands:
|
||||
if sub.parent_group_handler and sub.parent_group_handler in group_map:
|
||||
group_map[sub.parent_group_handler].sub_commands.append(sub)
|
||||
else:
|
||||
root_commands.append(sub)
|
||||
|
||||
# 指令组 + 普通指令,按 effective_command 字母排序
|
||||
all_commands = list(group_map.values()) + root_commands
|
||||
all_commands.sort(key=lambda d: (d.effective_command or "").lower())
|
||||
|
||||
result = [_descriptor_to_dict(desc) for desc in all_commands]
|
||||
return result
|
||||
|
||||
|
||||
async def list_command_conflicts() -> list[dict[str, Any]]:
|
||||
"""列出所有冲突的指令组。"""
|
||||
descriptors = _collect_descriptors(include_sub_commands=False)
|
||||
config_records = await db_helper.get_command_configs()
|
||||
_bind_configs_to_descriptors(descriptors, config_records)
|
||||
|
||||
conflict_groups = _group_conflicts(descriptors)
|
||||
details = [
|
||||
{
|
||||
"conflict_key": key,
|
||||
"handlers": [
|
||||
{
|
||||
"handler_full_name": item.handler_full_name,
|
||||
"plugin": item.plugin_name,
|
||||
"current_name": item.effective_command,
|
||||
}
|
||||
for item in group
|
||||
],
|
||||
}
|
||||
for key, group in conflict_groups.items()
|
||||
]
|
||||
return details
|
||||
|
||||
|
||||
# Internal helpers ----------------------------------------------------------
|
||||
|
||||
|
||||
def _collect_descriptors(include_sub_commands: bool) -> list[CommandDescriptor]:
|
||||
"""收集指令,按需包含子指令。"""
|
||||
descriptors: list[CommandDescriptor] = []
|
||||
for handler in star_handlers_registry:
|
||||
try:
|
||||
desc = _build_descriptor(handler)
|
||||
if not desc:
|
||||
continue
|
||||
if not include_sub_commands and desc.is_sub_command:
|
||||
continue
|
||||
descriptors.append(desc)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"解析指令处理函数 {handler.handler_full_name} 失败,跳过该指令。原因: {e!s}"
|
||||
)
|
||||
continue
|
||||
return descriptors
|
||||
|
||||
|
||||
def _build_descriptor(handler: StarHandlerMetadata) -> CommandDescriptor | None:
|
||||
filter_ref = _locate_primary_filter(handler)
|
||||
if filter_ref is None:
|
||||
return None
|
||||
|
||||
plugin_meta = star_map.get(handler.handler_module_path)
|
||||
plugin_name = (
|
||||
plugin_meta.name if plugin_meta else None
|
||||
) or handler.handler_module_path
|
||||
plugin_display = plugin_meta.display_name if plugin_meta else None
|
||||
|
||||
is_sub_command = bool(handler.extras_configs.get("sub_command"))
|
||||
parent_group_handler = ""
|
||||
|
||||
if isinstance(filter_ref, CommandFilter):
|
||||
raw_fragment = getattr(
|
||||
filter_ref, "_original_command_name", filter_ref.command_name
|
||||
)
|
||||
current_fragment = filter_ref.command_name
|
||||
parent_signature = (filter_ref.parent_command_names or [""])[0].strip()
|
||||
# 如果是子指令,尝试找到父指令组的 handler_full_name
|
||||
if is_sub_command and parent_signature:
|
||||
parent_group_handler = _find_parent_group_handler(
|
||||
handler.handler_module_path, parent_signature
|
||||
)
|
||||
else:
|
||||
raw_fragment = getattr(
|
||||
filter_ref, "_original_group_name", filter_ref.group_name
|
||||
)
|
||||
current_fragment = filter_ref.group_name
|
||||
parent_signature = _resolve_group_parent_signature(filter_ref)
|
||||
|
||||
original_command = _compose_command(parent_signature, raw_fragment)
|
||||
effective_command = _compose_command(parent_signature, current_fragment)
|
||||
|
||||
# 确定 command_type
|
||||
if isinstance(filter_ref, CommandGroupFilter):
|
||||
command_type = "group"
|
||||
elif is_sub_command:
|
||||
command_type = "sub_command"
|
||||
else:
|
||||
command_type = "command"
|
||||
|
||||
descriptor = CommandDescriptor(
|
||||
handler=handler,
|
||||
filter_ref=filter_ref,
|
||||
handler_full_name=handler.handler_full_name,
|
||||
handler_name=handler.handler_name,
|
||||
plugin_name=plugin_name,
|
||||
plugin_display_name=plugin_display,
|
||||
module_path=handler.handler_module_path,
|
||||
description=handler.desc or "",
|
||||
command_type=command_type,
|
||||
raw_command_name=raw_fragment,
|
||||
current_fragment=current_fragment,
|
||||
parent_signature=parent_signature,
|
||||
parent_group_handler=parent_group_handler,
|
||||
original_command=original_command,
|
||||
effective_command=effective_command,
|
||||
aliases=sorted(getattr(filter_ref, "alias", set())),
|
||||
permission=_determine_permission(handler),
|
||||
enabled=handler.enabled,
|
||||
is_group=isinstance(filter_ref, CommandGroupFilter),
|
||||
is_sub_command=is_sub_command,
|
||||
reserved=plugin_meta.reserved if plugin_meta else False,
|
||||
)
|
||||
return descriptor
|
||||
|
||||
|
||||
def _build_descriptor_by_full_name(full_name: str) -> CommandDescriptor | None:
|
||||
handler = star_handlers_registry.get_handler_by_full_name(full_name)
|
||||
if not handler:
|
||||
return None
|
||||
return _build_descriptor(handler)
|
||||
|
||||
|
||||
def _locate_primary_filter(
|
||||
handler: StarHandlerMetadata,
|
||||
) -> CommandFilter | CommandGroupFilter | None:
|
||||
for filter_ref in handler.event_filters:
|
||||
if isinstance(filter_ref, (CommandFilter, CommandGroupFilter)):
|
||||
return filter_ref
|
||||
return None
|
||||
|
||||
|
||||
def _determine_permission(handler: StarHandlerMetadata) -> str:
|
||||
for filter_ref in handler.event_filters:
|
||||
if isinstance(filter_ref, PermissionTypeFilter):
|
||||
return (
|
||||
"admin"
|
||||
if filter_ref.permission_type == PermissionType.ADMIN
|
||||
else "member"
|
||||
)
|
||||
return "everyone"
|
||||
|
||||
|
||||
def _resolve_group_parent_signature(group_filter: CommandGroupFilter) -> str:
|
||||
signatures: list[str] = []
|
||||
parent = group_filter.parent_group
|
||||
while parent:
|
||||
signatures.append(getattr(parent, "_original_group_name", parent.group_name))
|
||||
parent = parent.parent_group
|
||||
return " ".join(reversed(signatures)).strip()
|
||||
|
||||
|
||||
def _find_parent_group_handler(module_path: str, parent_signature: str) -> str:
|
||||
"""根据模块路径和父级签名,找到对应的指令组 handler_full_name。"""
|
||||
parent_sig_normalized = parent_signature.strip()
|
||||
for handler in star_handlers_registry:
|
||||
if handler.handler_module_path != module_path:
|
||||
continue
|
||||
filter_ref = _locate_primary_filter(handler)
|
||||
if not isinstance(filter_ref, CommandGroupFilter):
|
||||
continue
|
||||
# 检查该指令组的完整指令名是否匹配 parent_signature
|
||||
group_names = filter_ref.get_complete_command_names()
|
||||
if parent_sig_normalized in group_names:
|
||||
return handler.handler_full_name
|
||||
return ""
|
||||
|
||||
|
||||
def _compose_command(parent_signature: str, fragment: str | None) -> str:
|
||||
fragment = (fragment or "").strip()
|
||||
parent_signature = parent_signature.strip()
|
||||
if not parent_signature:
|
||||
return fragment
|
||||
if not fragment:
|
||||
return parent_signature
|
||||
return f"{parent_signature} {fragment}"
|
||||
|
||||
|
||||
def _bind_descriptor_with_config(
|
||||
descriptor: CommandDescriptor,
|
||||
config: CommandConfig,
|
||||
) -> None:
|
||||
_apply_config_to_descriptor(descriptor, config)
|
||||
_apply_config_to_runtime(descriptor, config)
|
||||
|
||||
|
||||
def _apply_config_to_descriptor(
|
||||
descriptor: CommandDescriptor,
|
||||
config: CommandConfig,
|
||||
) -> None:
|
||||
descriptor.config = config
|
||||
descriptor.enabled = config.enabled
|
||||
|
||||
if config.original_command:
|
||||
descriptor.original_command = config.original_command
|
||||
|
||||
new_fragment = config.resolved_command or descriptor.current_fragment
|
||||
descriptor.current_fragment = new_fragment
|
||||
descriptor.effective_command = _compose_command(
|
||||
descriptor.parent_signature,
|
||||
new_fragment,
|
||||
)
|
||||
|
||||
extra = config.extra_data or {}
|
||||
resolved_aliases = extra.get("resolved_aliases")
|
||||
if isinstance(resolved_aliases, list):
|
||||
descriptor.aliases = [str(x) for x in resolved_aliases if str(x).strip()]
|
||||
|
||||
|
||||
def _apply_config_to_runtime(
|
||||
descriptor: CommandDescriptor,
|
||||
config: CommandConfig,
|
||||
) -> None:
|
||||
descriptor.handler.enabled = config.enabled
|
||||
if descriptor.filter_ref:
|
||||
if descriptor.current_fragment:
|
||||
_set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment)
|
||||
extra = config.extra_data or {}
|
||||
resolved_aliases = extra.get("resolved_aliases")
|
||||
if isinstance(resolved_aliases, list):
|
||||
_set_filter_aliases(
|
||||
descriptor.filter_ref,
|
||||
[str(x) for x in resolved_aliases if str(x).strip()],
|
||||
)
|
||||
|
||||
|
||||
def _bind_configs_to_descriptors(
|
||||
descriptors: list[CommandDescriptor],
|
||||
config_records: list[CommandConfig],
|
||||
) -> dict[str, CommandConfig]:
|
||||
config_map = {cfg.handler_full_name: cfg for cfg in config_records}
|
||||
for desc in descriptors:
|
||||
if cfg := config_map.get(desc.handler_full_name):
|
||||
_bind_descriptor_with_config(desc, cfg)
|
||||
return config_map
|
||||
|
||||
|
||||
def _group_conflicts(
|
||||
descriptors: list[CommandDescriptor],
|
||||
) -> dict[str, list[CommandDescriptor]]:
|
||||
conflicts: dict[str, list[CommandDescriptor]] = defaultdict(list)
|
||||
for desc in descriptors:
|
||||
if desc.effective_command and desc.enabled:
|
||||
conflicts[desc.effective_command].append(desc)
|
||||
return {k: v for k, v in conflicts.items() if len(v) > 1}
|
||||
|
||||
|
||||
def _set_filter_fragment(
|
||||
filter_ref: CommandFilter | CommandGroupFilter,
|
||||
fragment: str,
|
||||
) -> None:
|
||||
attr = (
|
||||
"group_name" if isinstance(filter_ref, CommandGroupFilter) else "command_name"
|
||||
)
|
||||
current_value = getattr(filter_ref, attr)
|
||||
if fragment == current_value:
|
||||
return
|
||||
setattr(filter_ref, attr, fragment)
|
||||
if hasattr(filter_ref, "_cmpl_cmd_names"):
|
||||
filter_ref._cmpl_cmd_names = None
|
||||
|
||||
|
||||
def _set_filter_aliases(
|
||||
filter_ref: CommandFilter | CommandGroupFilter,
|
||||
aliases: list[str],
|
||||
) -> None:
|
||||
current_aliases = getattr(filter_ref, "alias", set())
|
||||
if set(aliases) == current_aliases:
|
||||
return
|
||||
setattr(filter_ref, "alias", set(aliases))
|
||||
if hasattr(filter_ref, "_cmpl_cmd_names"):
|
||||
filter_ref._cmpl_cmd_names = None
|
||||
|
||||
|
||||
def _is_command_in_use(
|
||||
target_handler_full_name: str,
|
||||
candidate_full_command: str,
|
||||
) -> bool:
|
||||
candidate = candidate_full_command.strip()
|
||||
for handler in star_handlers_registry:
|
||||
if handler.handler_full_name == target_handler_full_name:
|
||||
continue
|
||||
filter_ref = _locate_primary_filter(handler)
|
||||
if not filter_ref:
|
||||
continue
|
||||
names = {name.strip() for name in filter_ref.get_complete_command_names()}
|
||||
if candidate in names:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _descriptor_to_dict(desc: CommandDescriptor) -> dict[str, Any]:
|
||||
result = {
|
||||
"handler_full_name": desc.handler_full_name,
|
||||
"handler_name": desc.handler_name,
|
||||
"plugin": desc.plugin_name,
|
||||
"plugin_display_name": desc.plugin_display_name,
|
||||
"module_path": desc.module_path,
|
||||
"description": desc.description,
|
||||
"type": desc.command_type,
|
||||
"parent_signature": desc.parent_signature,
|
||||
"parent_group_handler": desc.parent_group_handler,
|
||||
"original_command": desc.original_command,
|
||||
"current_fragment": desc.current_fragment,
|
||||
"effective_command": desc.effective_command,
|
||||
"aliases": desc.aliases,
|
||||
"permission": desc.permission,
|
||||
"enabled": desc.enabled,
|
||||
"is_group": desc.is_group,
|
||||
"has_conflict": desc.has_conflict,
|
||||
"reserved": desc.reserved,
|
||||
}
|
||||
# 如果是指令组,包含子指令列表
|
||||
if desc.is_group and desc.sub_commands:
|
||||
result["sub_commands"] = [_descriptor_to_dict(sub) for sub in desc.sub_commands]
|
||||
else:
|
||||
result["sub_commands"] = []
|
||||
return result
|
||||
@@ -267,6 +267,10 @@ class Context:
|
||||
):
|
||||
"""通过 ID 获取对应的 LLM Provider。"""
|
||||
prov = self.provider_manager.inst_map.get(provider_id)
|
||||
if provider_id and not prov:
|
||||
logger.warning(
|
||||
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
|
||||
)
|
||||
return prov
|
||||
|
||||
def get_all_providers(self) -> list[Provider]:
|
||||
@@ -373,7 +377,7 @@ class Context:
|
||||
if not module_path:
|
||||
_parts = []
|
||||
module_part = tool.__module__.split(".")
|
||||
flags = ["packages", "plugins"]
|
||||
flags = ["builtin_stars", "plugins"]
|
||||
for i, part in enumerate(module_part):
|
||||
_parts.append(part)
|
||||
if part in flags and i + 1 < len(module_part):
|
||||
|
||||
@@ -40,6 +40,7 @@ class CommandFilter(HandlerFilter):
|
||||
):
|
||||
self.command_name = command_name
|
||||
self.alias = alias if alias else set()
|
||||
self._original_command_name = command_name
|
||||
self.parent_command_names = (
|
||||
parent_command_names if parent_command_names is not None else [""]
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
):
|
||||
self.group_name = group_name
|
||||
self.alias = alias if alias else set()
|
||||
self._original_group_name = group_name
|
||||
self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = []
|
||||
self.custom_filter_list: list[CustomFilter] = []
|
||||
self.parent_group = parent_group
|
||||
|
||||
@@ -118,6 +118,8 @@ class StarHandlerRegistry(Generic[T]):
|
||||
# 过滤事件类型
|
||||
if handler.event_type != event_type:
|
||||
continue
|
||||
if not handler.enabled:
|
||||
continue
|
||||
# 过滤启用状态
|
||||
if only_activated:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
@@ -220,6 +222,8 @@ class StarHandlerMetadata(Generic[H]):
|
||||
extras_configs: dict = field(default_factory=dict)
|
||||
"""插件注册的一些其他的信息, 如 priority 等"""
|
||||
|
||||
enabled: bool = True
|
||||
|
||||
def __lt__(self, other: StarHandlerMetadata):
|
||||
"""定义小于运算符以支持优先队列"""
|
||||
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
||||
|
||||
@@ -18,11 +18,13 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
get_astrbot_path,
|
||||
get_astrbot_plugin_path,
|
||||
)
|
||||
from astrbot.core.utils.io import remove_dir
|
||||
|
||||
from . import StarMetadata
|
||||
from .command_management import sync_command_configs
|
||||
from .context import Context
|
||||
from .filter.permission import PermissionType, PermissionTypeFilter
|
||||
from .star import star_map, star_registry
|
||||
@@ -48,13 +50,10 @@ class PluginManager:
|
||||
"""存储插件的路径。即 data/plugins"""
|
||||
self.plugin_config_path = get_astrbot_config_path()
|
||||
"""存储插件配置的路径。data/config"""
|
||||
self.reserved_plugin_path = os.path.abspath(
|
||||
os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
"../../../packages",
|
||||
),
|
||||
self.reserved_plugin_path = os.path.join(
|
||||
get_astrbot_path(), "astrbot", "builtin_stars"
|
||||
)
|
||||
"""保留插件的路径。在 packages 目录下"""
|
||||
"""保留插件的路径。在 astrbot/builtin_stars 目录下"""
|
||||
self.conf_schema_fname = "_conf_schema.json"
|
||||
self.logo_fname = "logo.png"
|
||||
"""插件配置 Schema 文件名"""
|
||||
@@ -251,7 +250,7 @@ class PluginManager:
|
||||
list[str]: 与该插件相关的模块名列表
|
||||
|
||||
"""
|
||||
prefix = "packages." if is_reserved else "data.plugins."
|
||||
prefix = "astrbot.builtin_stars." if is_reserved else "data.plugins."
|
||||
return [
|
||||
key
|
||||
for key in list(sys.modules.keys())
|
||||
@@ -269,7 +268,7 @@ class PluginManager:
|
||||
可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存
|
||||
|
||||
Args:
|
||||
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "packages"])
|
||||
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "astrbot.builtin_stars"])
|
||||
root_dir_name: 插件根目录名,用于移除与该插件相关的所有模块
|
||||
is_reserved: 插件是否为保留插件(影响模块路径前缀)
|
||||
|
||||
@@ -381,9 +380,9 @@ class PluginManager:
|
||||
reserved = plugin_module.get(
|
||||
"reserved",
|
||||
False,
|
||||
) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
|
||||
) # 是否是保留插件。目前在 astrbot/builtin_stars 目录下的都是保留插件。保留插件不可以卸载。
|
||||
|
||||
path = "data.plugins." if not reserved else "packages."
|
||||
path = "data.plugins." if not reserved else "astrbot.builtin_stars."
|
||||
path += root_dir_name + "." + module_str
|
||||
|
||||
# 检查是否需要载入指定的插件
|
||||
@@ -630,6 +629,11 @@ class PluginManager:
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
try:
|
||||
await sync_command_configs()
|
||||
except Exception as e:
|
||||
logger.error(f"同步指令配置失败: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if not fail_rec:
|
||||
return True, None
|
||||
@@ -823,7 +827,7 @@ class PluginManager:
|
||||
if (
|
||||
mp
|
||||
and mp.startswith(plugin_module_path)
|
||||
and not mp.endswith(("packages", "data.plugins"))
|
||||
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||
):
|
||||
to_remove.append(func_tool)
|
||||
for func_tool in to_remove:
|
||||
@@ -878,7 +882,7 @@ class PluginManager:
|
||||
plugin.module_path
|
||||
and mp
|
||||
and plugin.module_path.startswith(mp)
|
||||
and not mp.endswith(("packages", "data.plugins"))
|
||||
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||
):
|
||||
func_tool.active = False
|
||||
if func_tool.name not in inactivated_llm_tools:
|
||||
@@ -927,7 +931,7 @@ class PluginManager:
|
||||
plugin.module_path
|
||||
and mp
|
||||
and plugin.module_path.startswith(mp)
|
||||
and not mp.endswith(("packages", "data.plugins"))
|
||||
and not mp.endswith(("astrbot.builtin_stars", "data.plugins"))
|
||||
and func_tool.name in inactivated_llm_tools
|
||||
):
|
||||
inactivated_llm_tools.remove(func_tool.name)
|
||||
@@ -940,8 +944,49 @@ class PluginManager:
|
||||
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
|
||||
dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()
|
||||
desti_dir = os.path.join(self.plugin_store_path, dir_name)
|
||||
|
||||
# 第一步:检查是否已安装同目录名的插件,先终止旧插件
|
||||
existing_plugin = None
|
||||
for star in self.context.get_all_stars():
|
||||
if star.root_dir_name == dir_name:
|
||||
existing_plugin = star
|
||||
break
|
||||
|
||||
if existing_plugin:
|
||||
logger.info(f"检测到插件 {existing_plugin.name} 已安装,正在终止旧插件...")
|
||||
try:
|
||||
await self._terminate_plugin(existing_plugin)
|
||||
except Exception:
|
||||
logger.warning(traceback.format_exc())
|
||||
if existing_plugin.name and existing_plugin.module_path:
|
||||
await self._unbind_plugin(
|
||||
existing_plugin.name, existing_plugin.module_path
|
||||
)
|
||||
|
||||
self.updator.unzip_file(zip_file_path, desti_dir)
|
||||
|
||||
# 第二步:解压后,读取新插件的 metadata.yaml,检查是否存在同名但不同目录的插件
|
||||
try:
|
||||
new_metadata = self._load_plugin_metadata(desti_dir)
|
||||
if new_metadata and new_metadata.name:
|
||||
for star in self.context.get_all_stars():
|
||||
if (
|
||||
star.name == new_metadata.name
|
||||
and star.root_dir_name != dir_name
|
||||
):
|
||||
logger.warning(
|
||||
f"检测到同名插件 {star.name} 存在于不同目录 {star.root_dir_name},正在终止..."
|
||||
)
|
||||
try:
|
||||
await self._terminate_plugin(star)
|
||||
except Exception:
|
||||
logger.warning(traceback.format_exc())
|
||||
if star.name and star.module_path:
|
||||
await self._unbind_plugin(star.name, star.module_path)
|
||||
break # 只处理第一个匹配的
|
||||
except Exception as e:
|
||||
logger.debug(f"读取新插件 metadata.yaml 失败,跳过同名检查: {e!s}")
|
||||
|
||||
# remove the zip
|
||||
try:
|
||||
os.remove(zip_file_path)
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
数据目录路径:固定为根目录下的 data 目录
|
||||
配置文件路径:固定为数据目录下的 config 目录
|
||||
插件目录路径:固定为数据目录下的 plugins 目录
|
||||
插件数据目录路径:固定为数据目录下的 plugin_data 目录
|
||||
T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录
|
||||
WebChat 数据目录路径:固定为数据目录下的 webchat 目录
|
||||
临时文件目录路径:固定为数据目录下的 temp 目录
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -37,3 +41,33 @@ def get_astrbot_config_path() -> str:
|
||||
def get_astrbot_plugin_path() -> str:
|
||||
"""获取Astrbot插件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
||||
|
||||
|
||||
def get_astrbot_plugin_data_path() -> str:
|
||||
"""获取Astrbot插件数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugin_data"))
|
||||
|
||||
|
||||
def get_astrbot_t2i_templates_path() -> str:
|
||||
"""获取Astrbot T2I 模板目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "t2i_templates"))
|
||||
|
||||
|
||||
def get_astrbot_webchat_path() -> str:
|
||||
"""获取Astrbot WebChat 数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "webchat"))
|
||||
|
||||
|
||||
def get_astrbot_temp_path() -> str:
|
||||
"""获取Astrbot临时文件目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "temp"))
|
||||
|
||||
|
||||
def get_astrbot_knowledge_base_path() -> str:
|
||||
"""获取Astrbot知识库根目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base"))
|
||||
|
||||
|
||||
def get_astrbot_backups_path() -> str:
|
||||
"""获取Astrbot备份目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "backups"))
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
import aiohttp
|
||||
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
class LLMModalities(TypedDict):
|
||||
input: list[Literal["text", "image", "audio", "video"]]
|
||||
output: list[Literal["text", "image", "audio", "video"]]
|
||||
|
||||
|
||||
class LLMLimit(TypedDict):
|
||||
context: int
|
||||
output: int
|
||||
|
||||
|
||||
class LLMMetadata(TypedDict):
|
||||
id: str
|
||||
reasoning: bool
|
||||
tool_call: bool
|
||||
knowledge: str
|
||||
release_date: str
|
||||
modalities: LLMModalities
|
||||
open_weights: bool
|
||||
limit: LLMLimit
|
||||
|
||||
|
||||
LLM_METADATAS: dict[str, LLMMetadata] = {}
|
||||
|
||||
|
||||
async def update_llm_metadata():
|
||||
url = "https://models.dev/api.json"
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
data = await response.json()
|
||||
global LLM_METADATAS
|
||||
models = {}
|
||||
for info in data.values():
|
||||
for model in info.get("models", {}).values():
|
||||
model_id = model.get("id")
|
||||
if not model_id:
|
||||
continue
|
||||
models[model_id] = LLMMetadata(
|
||||
id=model_id,
|
||||
reasoning=model.get("reasoning", False),
|
||||
tool_call=model.get("tool_call", False),
|
||||
knowledge=model.get("knowledge", "none"),
|
||||
release_date=model.get("release_date", ""),
|
||||
modalities=model.get(
|
||||
"modalities", {"input": [], "output": []}
|
||||
),
|
||||
open_weights=model.get("open_weights", False),
|
||||
limit=model.get("limit", {"context": 0, "output": 0}),
|
||||
)
|
||||
# Replace the global cache in-place so references remain valid
|
||||
LLM_METADATAS.clear()
|
||||
LLM_METADATAS.update(models)
|
||||
logger.info(f"Successfully fetched metadata for {len(models)} LLMs.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch LLM metadata: {e}")
|
||||
return
|
||||
@@ -32,6 +32,92 @@ def _migra_agent_runner_configs(conf: AstrBotConfig, ids_map: dict) -> None:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def _migra_provider_to_source_structure(conf: AstrBotConfig) -> None:
|
||||
"""
|
||||
Migrate old provider structure to new provider-source separation.
|
||||
Provider only keeps: id, provider_source_id, model, modalities, custom_extra_body
|
||||
All other fields move to provider_sources.
|
||||
"""
|
||||
providers = conf.get("provider", [])
|
||||
provider_sources = conf.get("provider_sources", [])
|
||||
|
||||
# Track if any migration happened
|
||||
migrated = False
|
||||
|
||||
# Provider-only fields that should stay in provider
|
||||
provider_only_fields = {
|
||||
"id",
|
||||
"provider_source_id",
|
||||
"model",
|
||||
"modalities",
|
||||
"custom_extra_body",
|
||||
"enable",
|
||||
}
|
||||
|
||||
# Fields that should not go to source
|
||||
source_exclude_fields = provider_only_fields | {"model_config"}
|
||||
|
||||
for provider in providers:
|
||||
# Skip if already has provider_source_id
|
||||
if provider.get("provider_source_id"):
|
||||
continue
|
||||
|
||||
# Skip non-chat-completion types (they don't need source separation)
|
||||
provider_type = provider.get("provider_type", "")
|
||||
if provider_type != "chat_completion":
|
||||
# For old types without provider_type, check type field
|
||||
old_type = provider.get("type", "")
|
||||
if "chat_completion" not in old_type:
|
||||
continue
|
||||
|
||||
migrated = True
|
||||
logger.info(f"Migrating provider {provider.get('id')} to new structure")
|
||||
|
||||
# Extract source fields from provider
|
||||
source_fields = {}
|
||||
for key, value in list(provider.items()):
|
||||
if key not in source_exclude_fields:
|
||||
source_fields[key] = value
|
||||
|
||||
# Create new provider_source
|
||||
source_id = provider.get("id", "") + "_source"
|
||||
new_source = {"id": source_id, **source_fields}
|
||||
|
||||
# Update provider to only keep necessary fields
|
||||
provider["provider_source_id"] = source_id
|
||||
|
||||
# Extract model from model_config if exists
|
||||
if "model_config" in provider and isinstance(provider["model_config"], dict):
|
||||
model_config = provider["model_config"]
|
||||
provider["model"] = model_config.get("model", "")
|
||||
|
||||
# Put other model_config fields into custom_extra_body
|
||||
extra_body_fields = {k: v for k, v in model_config.items() if k != "model"}
|
||||
if extra_body_fields:
|
||||
if "custom_extra_body" not in provider:
|
||||
provider["custom_extra_body"] = {}
|
||||
provider["custom_extra_body"].update(extra_body_fields)
|
||||
|
||||
# Initialize new fields if not present
|
||||
if "modalities" not in provider:
|
||||
provider["modalities"] = []
|
||||
if "custom_extra_body" not in provider:
|
||||
provider["custom_extra_body"] = {}
|
||||
|
||||
# Remove fields that should be in source
|
||||
keys_to_remove = [k for k in provider.keys() if k not in provider_only_fields]
|
||||
for key in keys_to_remove:
|
||||
del provider[key]
|
||||
|
||||
# Add source to provider_sources
|
||||
provider_sources.append(new_source)
|
||||
|
||||
if migrated:
|
||||
conf["provider_sources"] = provider_sources
|
||||
conf.save_config()
|
||||
logger.info("Provider-source structure migration completed")
|
||||
|
||||
|
||||
async def migra(
|
||||
db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager
|
||||
) -> None:
|
||||
@@ -71,3 +157,10 @@ async def migra(
|
||||
|
||||
for conf in acm.confs.values():
|
||||
_migra_agent_runner_configs(conf, ids_map)
|
||||
|
||||
# Migrate providers to new structure: extract source fields to provider_sources
|
||||
try:
|
||||
_migra_provider_to_source_structure(astrbot_config)
|
||||
except Exception as e:
|
||||
logger.error(f"Migration for provider-source structure failed: {e!s}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -1,10 +1,29 @@
|
||||
import asyncio
|
||||
import locale
|
||||
import logging
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
def _robust_decode(line: bytes) -> str:
|
||||
"""解码字节流,兼容不同平台的编码"""
|
||||
try:
|
||||
return line.decode("utf-8").strip()
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
try:
|
||||
return line.decode(locale.getpreferredencoding(False)).strip()
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
if sys.platform.startswith("win"):
|
||||
try:
|
||||
return line.decode("gbk").strip()
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
return line.decode("utf-8", errors="replace").strip()
|
||||
|
||||
|
||||
class PipInstaller:
|
||||
def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None):
|
||||
self.pip_install_arg = pip_install_arg
|
||||
@@ -42,7 +61,7 @@ class PipInstaller:
|
||||
|
||||
assert process.stdout is not None
|
||||
async for line in process.stdout:
|
||||
logger.info(line.decode().strip())
|
||||
logger.info(_robust_decode(line))
|
||||
|
||||
await process.wait()
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from .auth import AuthRoute
|
||||
from .backup import BackupRoute
|
||||
from .chat import ChatRoute
|
||||
from .command import CommandRoute
|
||||
from .config import ConfigRoute
|
||||
from .conversation import ConversationRoute
|
||||
from .file import FileRoute
|
||||
@@ -16,7 +18,9 @@ from .update import UpdateRoute
|
||||
|
||||
__all__ = [
|
||||
"AuthRoute",
|
||||
"BackupRoute",
|
||||
"ChatRoute",
|
||||
"CommandRoute",
|
||||
"ConfigRoute",
|
||||
"ConversationRoute",
|
||||
"FileRoute",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user