From b1e3018b6b971e6edd036ee63660236bb623c140 Mon Sep 17 00:00:00 2001
From: Soulter <37870767+Soulter@users.noreply.github.com>
Date: Mon, 4 Aug 2025 00:56:26 +0800
Subject: [PATCH] =?UTF-8?q?Improve:=20=E5=BC=95=E5=85=A5=E5=85=A8=E6=96=B0?=
=?UTF-8?q?=E7=9A=84=E4=BA=BA=E6=A0=BC=E7=AE=A1=E7=90=86=E6=A8=A1=E5=BC=8F?=
=?UTF-8?q?=E4=BB=A5=E5=8F=8A=E9=87=8D=E6=9E=84=E5=87=BD=E6=95=B0=E5=B7=A5?=
=?UTF-8?q?=E5=85=B7=E7=AE=A1=E7=90=86=E5=99=A8=20(#2305)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* feat: add persona management
* refactor: 重构函数工具管理器,引入 ToolSet,并让 Persona 支持绑定 Tools
* feat: 更新 Persona 工具选择逻辑,支持全选和指定工具的切换
* feat: 更新 BaseDatabase 中的 persona 方法返回类型,支持返回 None
---
astrbot/core/config/default.py | 40 +-
astrbot/core/core_lifecycle.py | 23 +-
astrbot/core/db/__init__.py | 17 +
astrbot/core/db/migration/helper.py | 9 +-
astrbot/core/db/migration/migra_3_to_4.py | 90 +-
astrbot/core/db/po.py | 39 +-
astrbot/core/db/sqlite.py | 42 +-
astrbot/core/persona_mgr.py | 162 ++++
astrbot/core/pipeline/context.py | 10 +-
.../agent_runner/tool_loop_agent.py | 133 ++-
.../process_stage/method/llm_request.py | 19 +-
astrbot/core/provider/entities.py | 4 +-
astrbot/core/provider/func_tool_manager.py | 585 ++++++++-----
astrbot/core/provider/manager.py | 97 +--
astrbot/core/provider/provider.py | 20 +-
astrbot/core/star/context.py | 54 +-
astrbot/dashboard/routes/__init__.py | 5 +-
astrbot/dashboard/routes/config.py | 7 -
astrbot/dashboard/routes/persona.py | 199 +++++
astrbot/dashboard/routes/tools.py | 40 +
astrbot/dashboard/server.py | 3 +
dashboard/src/i18n/loader.ts | 1 +
.../i18n/locales/en-US/core/navigation.json | 1 +
.../i18n/locales/en-US/features/persona.json | 67 ++
.../i18n/locales/en-US/features/tool-use.json | 5 +-
.../i18n/locales/zh-CN/core/navigation.json | 1 +
.../i18n/locales/zh-CN/features/persona.json | 67 ++
.../i18n/locales/zh-CN/features/tool-use.json | 7 +-
dashboard/src/i18n/translations.ts | 8 +-
.../full/vertical-sidebar/sidebarItem.ts | 6 +-
dashboard/src/router/MainRoutes.ts | 5 +
dashboard/src/views/PersonaPage.vue | 808 ++++++++++++++++++
dashboard/src/views/ToolUsePage.vue | 56 +-
packages/astrbot/main.py | 62 +-
34 files changed, 2112 insertions(+), 580 deletions(-)
create mode 100644 astrbot/core/persona_mgr.py
create mode 100644 astrbot/dashboard/routes/persona.py
create mode 100644 dashboard/src/i18n/locales/en-US/features/persona.json
create mode 100644 dashboard/src/i18n/locales/zh-CN/features/persona.json
create mode 100644 dashboard/src/views/PersonaPage.vue
diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 41225e0ee..960ef107a 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -115,8 +115,7 @@ DEFAULT_CONFIG = {
"log_level": "INFO",
"pip_install_arg": "",
"pypi_index_url": "https://mirrors.aliyun.com/pypi/simple/",
- "knowledge_db": {},
- "persona": [],
+ "persona": [], # deprecated
"timezone": "",
"callback_api_base": "",
}
@@ -1701,43 +1700,6 @@ CONFIG_METADATA_2 = {
},
},
},
- "persona": {
- "description": "人格情景设置",
- "type": "list",
- "config_template": {
- "新人格情景": {
- "name": "",
- "prompt": "",
- "begin_dialogs": [],
- "mood_imitation_dialogs": [],
- }
- },
- "tmpl_display_title": "name",
- "items": {
- "name": {
- "description": "人格名称",
- "type": "string",
- "hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
- },
- "prompt": {
- "description": "设定(系统提示词)",
- "type": "text",
- "hint": "填写人格的身份背景、性格特征、兴趣爱好、个人经历、口头禅等。",
- },
- "begin_dialogs": {
- "description": "预设对话",
- "type": "list",
- "items": {"type": "string"},
- "hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
- },
- "mood_imitation_dialogs": {
- "description": "对话风格模仿",
- "type": "list",
- "items": {"type": "string"},
- "hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
- },
- },
- },
"provider_stt_settings": {
"description": "语音转文本(STT)",
"type": "object",
diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py
index 025bf03e5..28545a238 100644
--- a/astrbot/core/core_lifecycle.py
+++ b/astrbot/core/core_lifecycle.py
@@ -22,6 +22,7 @@ from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.star.context import Context
+from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
@@ -69,13 +70,23 @@ class AstrBotCoreLifecycle:
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
await self.db.initialize()
- await do_migration_v4(self.db, {})
+
+ try:
+ await do_migration_v4(self.db, {}, self.astrbot_config)
+ except Exception as e:
+ logger.error(f"迁移到 v4.0.0 新版本数据格式失败: {e}")
# 初始化事件队列
self.event_queue = Queue()
+ # 初始化人格管理器
+ self.persona_mgr = PersonaManager(self.db, self.astrbot_config)
+ await self.persona_mgr.initialize()
+
# 初始化供应商管理器
- self.provider_manager = ProviderManager(self.astrbot_config, self.db)
+ self.provider_manager = ProviderManager(
+ self.astrbot_config, self.db, self.persona_mgr
+ )
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
@@ -94,6 +105,8 @@ class AstrBotCoreLifecycle:
self.provider_manager,
self.platform_manager,
self.conversation_manager,
+ self.platform_message_history_manager,
+ self.persona_mgr,
)
# 初始化插件管理器
@@ -110,6 +123,7 @@ class AstrBotCoreLifecycle:
PipelineContext(self.astrbot_config, self.plugin_manager)
)
await self.pipeline_scheduler.initialize()
+ self.star_context.pipeline_ctx = self.pipeline_scheduler.ctx
# 初始化更新器
self.astrbot_updator = AstrBotUpdator()
@@ -232,6 +246,9 @@ class AstrBotCoreLifecycle:
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
tasks.append(
- asyncio.create_task(platform_inst.run(), name=f"{platform_inst.meta().id}({platform_inst.meta().name})")
+ asyncio.create_task(
+ platform_inst.run(),
+ name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
+ )
)
return tasks
diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py
index 53fccacfc..00c42505f 100644
--- a/astrbot/core/db/__init__.py
+++ b/astrbot/core/db/__init__.py
@@ -205,6 +205,7 @@ class BaseDatabase(abc.ABC):
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
+ tools: list[str] = None,
) -> Persona:
"""Insert a new persona record."""
...
@@ -219,6 +220,22 @@ class BaseDatabase(abc.ABC):
"""Get all personas for a specific bot."""
...
+ @abc.abstractmethod
+ async def update_persona(
+ self,
+ persona_id: str,
+ system_prompt: str = None,
+ begin_dialogs: list[str] = None,
+ tools: list[str] = None,
+ ) -> Persona | None:
+ """Update a persona's system prompt or begin dialogs."""
+ ...
+
+ @abc.abstractmethod
+ async def delete_persona(self, persona_id: str) -> None:
+ """Delete a persona by its ID."""
+ ...
+
@abc.abstractmethod
async def insert_preference_or_update(self, key: str, value: str) -> Preference:
"""Insert a new preference record."""
diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py
index d4b9c99f9..4eb428153 100644
--- a/astrbot/core/db/migration/helper.py
+++ b/astrbot/core/db/migration/helper.py
@@ -1,11 +1,13 @@
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
+from astrbot.core.config import AstrBotConfig
from astrbot.api import logger
from .migra_3_to_4 import (
migration_conversation_table,
migration_platform_table,
migration_webchat_data,
+ migration_persona_data,
)
@@ -24,7 +26,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
async def do_migration_v4(
- db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
+ db_helper: BaseDatabase,
+ platform_id_map: dict[str, dict[str, str]],
+ astrbot_config: AstrBotConfig,
):
"""
执行数据库迁移
@@ -45,6 +49,9 @@ async def do_migration_v4(
# 执行 WebChat 数据迁移
await migration_webchat_data(db_helper, platform_id_map)
+ # 执行人格数据迁移
+ await migration_persona_data(db_helper, astrbot_config)
+
# 标记迁移完成
await db_helper.insert_preference_or_update("migration_done_v4", "true")
diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py
index b6de3214e..e3eccd292 100644
--- a/astrbot/core/db/migration/migra_3_to_4.py
+++ b/astrbot/core/db/migration/migra_3_to_4.py
@@ -4,12 +4,10 @@ from .. import BaseDatabase
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from astrbot.core.config.default import DB_PATH
from astrbot.api import logger
+from astrbot.core.config import AstrBotConfig
from astrbot.core.platform.astr_message_event import MessageSesion
from sqlalchemy.ext.asyncio import AsyncSession
-from astrbot.core.db.po import (
- ConversationV2,
- PlatformMessageHistory,
-)
+from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from sqlalchemy import text
"""
@@ -50,20 +48,21 @@ async def migration_conversation_table(
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
- for conversation in conversations:
+ for idx, conversation in enumerate(conversations):
+ if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
+ progress = int((idx + 1) / total_cnt * 100)
+ if progress % 10 == 0:
+ logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
- logger.warning(
+ logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" not in conv.user_id:
- logger.warning(
- f"跳过 user_id 为 {conv.user_id} 的会话,它可能是 WebChat 的消息历史记录。"
- )
continue
session = MessageSesion.from_str(session_str=conv.user_id)
platform_id = get_platform_id(
@@ -81,13 +80,12 @@ async def migration_conversation_table(
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
)
dbsession.add(conv_v2)
- if conv_v2:
- logger.info(f"迁移旧会话 {conv.cid} 到新表成功。")
except Exception as e:
logger.error(
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
exc_info=True,
)
+ logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
async def migration_platform_table(
@@ -107,7 +105,7 @@ async def migration_platform_table(
platform_stats_v3 = stats.platform
if not platform_stats_v3:
- logger.warning("没有找到旧平台数据,跳过迁移。")
+ logger.info("没有找到旧平台数据,跳过迁移。")
return
first_time_stamp = platform_stats_v3[0].timestamp
@@ -120,7 +118,13 @@ async def migration_platform_table(
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
- for bucket_end in range(start_time, end_time, 3600):
+ total_buckets = (end_time - start_time) // 3600
+ for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
+ if bucket_idx % 500 == 0:
+ progress = int((bucket_idx + 1) / total_buckets * 100)
+ logger.info(
+ f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})"
+ )
cnt = 0
while (
idx < len(platform_stats_v3)
@@ -136,9 +140,6 @@ async def migration_platform_table(
platform_type = get_platform_type(
platform_id_map, platform_stats_v3[idx].name
)
- logger.info(
- f"迁移平台统计数据: {platform_id}, {platform_type}, 时间戳: {bucket_end}, 计数: {cnt}"
- )
try:
await dbsession.execute(
text("""
@@ -161,6 +162,7 @@ async def migration_platform_table(
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
exc_info=True,
)
+ logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
async def migration_webchat_data(
@@ -178,20 +180,21 @@ async def migration_webchat_data(
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
- for conversation in conversations:
+ for idx, conversation in enumerate(conversations):
+ if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
+ progress = int((idx + 1) / total_cnt * 100)
+ if progress % 10 == 0:
+ logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
- logger.warning(
+ logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" in conv.user_id:
- logger.warning(
- f"跳过 user_id 为 {conv.user_id} 的会话,它不是 WebChat 的消息历史记录。"
- )
continue
platform_id = "webchat"
history = json.loads(conv.history) if conv.history else []
@@ -206,9 +209,52 @@ async def migration_webchat_data(
)
dbsession.add(new_history)
- logger.info(f"迁移旧 WebChat 会话 {conv.cid} 到新表成功。")
except Exception:
logger.error(
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
exc_info=True,
)
+
+ logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
+
+
+async def migration_persona_data(
+ db_helper: BaseDatabase, astrbot_config: AstrBotConfig
+):
+ """
+ 迁移 Persona 数据到新的表中。
+ 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
+ """
+ v3_persona_config: list[dict] = astrbot_config.get("persona", [])
+ total_personas = len(v3_persona_config)
+ logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
+
+ for idx, persona in enumerate(v3_persona_config):
+ if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
+ progress = int((idx + 1) / total_personas * 100)
+ if progress % 10 == 0:
+ logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
+ try:
+ begin_dialogs = persona.get("begin_dialogs", [])
+ mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
+ mood_prompt = ""
+ user_turn = True
+ for mood_dialog in mood_imitation_dialogs:
+ if user_turn:
+ mood_prompt += f"A: {mood_dialog}\n"
+ else:
+ mood_prompt += f"B: {mood_dialog}\n"
+ user_turn = not user_turn
+ system_prompt = persona.get("prompt", "")
+ if mood_prompt:
+ system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
+ persona_new = await db_helper.insert_persona(
+ persona_id=persona["name"],
+ system_prompt=system_prompt,
+ begin_dialogs=begin_dialogs,
+ )
+ logger.info(
+ f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
+ )
+ except Exception as e:
+ logger.error(f"解析 Persona 配置失败:{e}")
diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py
index 30f7188a1..cbc8c8d4e 100644
--- a/astrbot/core/db/po.py
+++ b/astrbot/core/db/po.py
@@ -9,7 +9,7 @@ from sqlmodel import (
UniqueConstraint,
Field,
)
-from typing import Optional
+from typing import Optional, TypedDict
class PlatformStat(SQLModel, table=True):
@@ -39,12 +39,14 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations"
- inner_conversation_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
+ inner_conversation_id: int = Field(
+ primary_key=True, sa_column_kwargs={"autoincrement": True}
+ )
conversation_id: str = Field(
max_length=36,
nullable=False,
unique=True,
- default_factory=lambda: str(uuid.uuid4())
+ default_factory=lambda: str(uuid.uuid4()),
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
@@ -78,6 +80,8 @@ class Persona(SQLModel, table=True):
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
"""a list of strings, each representing a dialog to start with"""
+ tools: Optional[list] = Field(default=None, sa_type=JSON)
+ """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
@@ -119,7 +123,9 @@ class PlatformMessageHistory(SQLModel, table=True):
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
- sender_name: Optional[str] = Field(default=None) # Name of the sender in the platform
+ sender_name: Optional[str] = Field(
+ default=None
+ ) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
@@ -136,12 +142,14 @@ class Attachment(SQLModel, table=True):
__tablename__ = "attachments"
- inner_attachment_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
+ inner_attachment_id: int = Field(
+ primary_key=True, sa_column_kwargs={"autoincrement": True}
+ )
attachment_id: str = Field(
max_length=36,
nullable=False,
unique=True,
- default_factory=lambda: str(uuid.uuid4())
+ default_factory=lambda: str(uuid.uuid4()),
)
path: str = Field(nullable=False) # Path to the file on disk
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
@@ -182,6 +190,25 @@ class Conversation:
updated_at: int = 0
+class Personality(TypedDict):
+ """LLM 人格类。
+
+ 在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
+ """
+
+ prompt: str = ""
+ name: str = ""
+ begin_dialogs: list[str] = []
+ mood_imitation_dialogs: list[str] = []
+ """情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
+ tools: list[str] | None = None
+ """工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
+
+ # cache
+ _begin_dialogs_processed: list[dict] = []
+ _mood_imitation_dialogs_processed: str = ""
+
+
# ====
# Deprecated, and will be removed in future versions.
# ====
diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py
index 0308807da..0ecf787e5 100644
--- a/astrbot/core/db/sqlite.py
+++ b/astrbot/core/db/sqlite.py
@@ -19,6 +19,7 @@ from sqlalchemy import select, update, delete, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func
+NOT_GIVEN = T.TypeVar("NOT_GIVEN")
class SQLiteDatabase(BaseDatabase):
def __init__(self, db_path: str) -> None:
@@ -99,9 +100,7 @@ class SQLiteDatabase(BaseDatabase):
# Conversation Management
# ====
- async def get_conversations(
- self, user_id=None, platform_id=None
- ):
+ async def get_conversations(self, user_id=None, platform_id=None):
async with self.get_db() as session:
session: AsyncSession
query = select(ConversationV2)
@@ -224,7 +223,7 @@ class SQLiteDatabase(BaseDatabase):
return
query = query.values(**values)
await session.execute(query)
- return await self.get_conversation_by_id(cid)
+ return await self.get_conversation_by_id(cid)
async def delete_conversation(self, cid):
async with self.get_db() as session:
@@ -312,7 +311,9 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalar_one_or_none()
- async def insert_persona(self, persona_id, system_prompt, begin_dialogs=None):
+ async def insert_persona(
+ self, persona_id, system_prompt, begin_dialogs=None, tools=None
+ ):
"""Insert a new persona record."""
async with self.get_db() as session:
session: AsyncSession
@@ -321,6 +322,7 @@ class SQLiteDatabase(BaseDatabase):
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs or [],
+ tools=tools,
)
session.add(new_persona)
return new_persona
@@ -341,6 +343,36 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalars().all()
+ async def update_persona(
+ self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN
+ ):
+ """Update a persona's system prompt or begin dialogs."""
+ async with self.get_db() as session:
+ session: AsyncSession
+ async with session.begin():
+ query = update(Persona).where(Persona.persona_id == persona_id)
+ values = {}
+ if system_prompt is not None:
+ values["system_prompt"] = system_prompt
+ if begin_dialogs is not None:
+ values["begin_dialogs"] = begin_dialogs
+ if tools is not NOT_GIVEN:
+ values["tools"] = tools
+ if not values:
+ return
+ query = query.values(**values)
+ await session.execute(query)
+ return await self.get_persona_by_id(persona_id)
+
+ async def delete_persona(self, persona_id):
+ """Delete a persona by its ID."""
+ async with self.get_db() as session:
+ session: AsyncSession
+ async with session.begin():
+ await session.execute(
+ delete(Persona).where(Persona.persona_id == persona_id)
+ )
+
async def insert_preference_or_update(self, key, value):
"""Insert a new preference record or update if it exists."""
async with self.get_db() as session:
diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py
new file mode 100644
index 000000000..dc322c789
--- /dev/null
+++ b/astrbot/core/persona_mgr.py
@@ -0,0 +1,162 @@
+from astrbot.core.db import BaseDatabase
+from astrbot.core.db.po import Persona, Personality
+from astrbot.core.config import AstrBotConfig
+from astrbot import logger
+
+
+class PersonaManager:
+ def __init__(self, db_helper: BaseDatabase, astrbot_config: AstrBotConfig):
+ self.db = db_helper
+ self.config = astrbot_config
+ _ps: dict = astrbot_config["provider_settings"]
+ self.default_persona: str = _ps.get("default_personality", "default")
+ self.personas: list[Persona] = []
+ self.selected_default_persona: Persona | None = None
+
+ self.personas_v3: list[Personality] = []
+ self.selected_default_persona_v3: Personality | None = None
+ self.persona_v3_config: list[dict] = []
+
+ async def initialize(self):
+ self.personas = await self.get_all_personas()
+ self.get_v3_persona_data()
+ logger.info(f"已加载 {len(self.personas)} 个人格。")
+
+ async def get_persona(self, persona_id: str):
+ """获取指定 persona 的信息"""
+ persona = await self.db.get_persona_by_id(persona_id)
+ if not persona:
+ raise ValueError(f"Persona with ID {persona_id} does not exist.")
+ return persona
+
+ async def delete_persona(self, persona_id: str):
+ """删除指定 persona"""
+ if not await self.db.get_persona_by_id(persona_id):
+ raise ValueError(f"Persona with ID {persona_id} does not exist.")
+ await self.db.delete_persona(persona_id)
+ self.personas = [p for p in self.personas if p.persona_id != persona_id]
+ self.get_v3_persona_data()
+
+ async def update_persona(
+ self,
+ persona_id: str,
+ system_prompt: str = None,
+ begin_dialogs: list[str] = None,
+ tools: list[str] = None,
+ ):
+ """更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
+ existing_persona = await self.db.get_persona_by_id(persona_id)
+ if not existing_persona:
+ raise ValueError(f"Persona with ID {persona_id} does not exist.")
+ persona = await self.db.update_persona(
+ persona_id, system_prompt, begin_dialogs, tools=tools
+ )
+ if persona:
+ for i, p in enumerate(self.personas):
+ if p.persona_id == persona_id:
+ self.personas[i] = persona
+ break
+ self.get_v3_persona_data()
+ return persona
+
+ async def get_all_personas(self) -> list[Persona]:
+ """获取所有 personas"""
+ return await self.db.get_personas()
+
+ async def create_persona(
+ self,
+ persona_id: str,
+ system_prompt: str,
+ begin_dialogs: list[str] = None,
+ tools: list[str] = None,
+ ) -> Persona:
+ """创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
+ if await self.db.get_persona_by_id(persona_id):
+ raise ValueError(f"Persona with ID {persona_id} already exists.")
+ new_persona = await self.db.insert_persona(
+ persona_id, system_prompt, begin_dialogs, tools=tools
+ )
+ self.personas.append(new_persona)
+ self.get_v3_persona_data()
+ return new_persona
+
+ def get_v3_persona_data(
+ self,
+ ) -> tuple[list[dict], list[Personality], Personality]:
+ """获取 AstrBot <4.0.0 版本的 persona 数据。
+
+ Returns:
+ - list[dict]: 包含 persona 配置的字典列表。
+ - list[Personality]: 包含 Personality 对象的列表。
+ - Personality: 默认选择的 Personality 对象。
+ """
+ v3_persona_config = [
+ {
+ "prompt": persona.system_prompt,
+ "name": persona.persona_id,
+ "begin_dialogs": persona.begin_dialogs or [],
+ "mood_imitation_dialogs": [], # deprecated
+ "tools": persona.tools,
+ }
+ for persona in self.personas
+ ]
+
+ personas_v3: list[Personality] = []
+ selected_default_persona: Personality | None = None
+
+ for persona_cfg in v3_persona_config:
+ begin_dialogs = persona_cfg.get("begin_dialogs", [])
+ bd_processed = []
+ if begin_dialogs:
+ if len(begin_dialogs) % 2 != 0:
+ logger.error(
+ f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
+ )
+ begin_dialogs = []
+ user_turn = True
+ for dialog in begin_dialogs:
+ bd_processed.append(
+ {
+ "role": "user" if user_turn else "assistant",
+ "content": dialog,
+ "_no_save": None, # 不持久化到 db
+ }
+ )
+ user_turn = not user_turn
+
+ try:
+ persona = Personality(
+ **persona_cfg,
+ _begin_dialogs_processed=bd_processed,
+ _mood_imitation_dialogs_processed="", # deprecated
+ )
+ if persona["name"] == self.default_persona:
+ selected_default_persona = persona
+ personas_v3.append(persona)
+ except Exception as e:
+ logger.error(f"解析 Persona 配置失败:{e}")
+
+ if not selected_default_persona and len(personas_v3) > 0:
+ # 默认选择第一个
+ selected_default_persona = personas_v3[0]
+
+ if not selected_default_persona:
+ selected_default_persona = Personality(
+ prompt="You are a helpful and friendly assistant.",
+ name="default",
+ tools=None,
+ _begin_dialogs_processed=[],
+ )
+ personas_v3.append(selected_default_persona)
+
+ self.personas_v3 = personas_v3
+ self.selected_default_persona_v3 = selected_default_persona
+ self.persona_v3_config = v3_persona_config
+ self.selected_default_persona = Persona(
+ persona_id=selected_default_persona["name"],
+ system_prompt=selected_default_persona["prompt"],
+ begin_dialogs=selected_default_persona["begin_dialogs"],
+ tools=selected_default_persona["tools"] or None,
+ )
+
+ return v3_persona_config, personas_v3, selected_default_persona
diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py
index 0b9d9e533..932c5d5c2 100644
--- a/astrbot/core/pipeline/context.py
+++ b/astrbot/core/pipeline/context.py
@@ -55,7 +55,7 @@ class PipelineContext:
handler: T.Awaitable,
*args,
**kwargs,
- ) -> T.AsyncGenerator[None, None]:
+ ) -> T.AsyncGenerator[T.Any, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
@@ -77,10 +77,16 @@ class PipelineContext:
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
+ if self.plugin_manager:
+ context = self.plugin_manager.context
+ else:
+ raise ValueError(
+ "Cannot call handler without a valid context or plugin manager."
+ )
# 向下兼容
trace_ = traceback.format_exc()
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
- ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
+ ready_to_call = handler(event, context, *args, **kwargs)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
diff --git a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py
index c2961ded5..cd705275e 100644
--- a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py
+++ b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py
@@ -21,6 +21,7 @@ from mcp.types import (
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
+ CallToolResult,
)
from astrbot.core.star.star_handler import EventType
from astrbot import logger
@@ -193,50 +194,25 @@ class ToolLoopAgent(BaseAgentRunner):
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
- if func_tool.origin == "mcp":
- logger.info(
- f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
- )
- client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
- res = await client.session.call_tool(func_tool.name, func_tool_args)
- if not res:
- continue
- if isinstance(res.content[0], TextContent):
- tool_call_result_blocks.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content=res.content[0].text,
- )
- )
- yield MessageChain().message(res.content[0].text)
- elif isinstance(res.content[0], ImageContent):
- tool_call_result_blocks.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content="返回了图片(已直接发送给用户)",
- )
- )
- yield MessageChain(type="tool_direct_result").base64_image(
- res.content[0].data
- )
- elif isinstance(res.content[0], EmbeddedResource):
- resource = res.content[0].resource
- if isinstance(resource, TextResourceContents):
+ logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
+ executor = func_tool.execute(
+ event=self.event,
+ pipeline_context=self.pipeline_ctx,
+ **func_tool_args,
+ )
+ async for resp in executor:
+ if isinstance(resp, CallToolResult):
+ res = resp
+ if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
- content=resource.text,
+ content=res.content[0].text,
)
)
- yield MessageChain().message(resource.text)
- elif (
- isinstance(resource, BlobResourceContents)
- and resource.mimeType
- and resource.mimeType.startswith("image/")
- ):
+ yield MessageChain().message(res.content[0].text)
+ elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
@@ -247,41 +223,54 @@ class ToolLoopAgent(BaseAgentRunner):
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
- else:
- tool_call_result_blocks.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content="返回的数据类型不受支持",
- )
- )
- yield MessageChain().message("返回的数据类型不受支持。")
- else:
- logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
- # 尝试调用工具函数
- wrapper = self.pipeline_ctx.call_handler(
- self.event, func_tool.handler, **func_tool_args
- )
- async for resp in wrapper:
- if resp is not None:
- # Tool 返回结果
- tool_call_result_blocks.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content=resp,
- )
- )
- yield MessageChain().message(resp)
- else:
- # Tool 直接请求发送消息给用户
- # 这里我们将直接结束 Agent Loop。
- self._transition_state(AgentState.DONE)
- if res := self.event.get_result():
- if res.chain:
- yield MessageChain(
- chain=res.chain, type="tool_direct_result"
+ elif isinstance(res.content[0], EmbeddedResource):
+ resource = res.content[0].resource
+ if isinstance(resource, TextResourceContents):
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content=resource.text,
)
+ )
+ yield MessageChain().message(resource.text)
+ elif (
+ isinstance(resource, BlobResourceContents)
+ and resource.mimeType
+ and resource.mimeType.startswith("image/")
+ ):
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content="返回了图片(已直接发送给用户)",
+ )
+ )
+ yield MessageChain(
+ type="tool_direct_result"
+ ).base64_image(res.content[0].data)
+ else:
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content="返回的数据类型不受支持",
+ )
+ )
+ yield MessageChain().message("返回的数据类型不受支持。")
+ elif resp is None:
+ # Tool 直接请求发送消息给用户
+ # 这里我们将直接结束 Agent Loop。
+ self._transition_state(AgentState.DONE)
+ if res := self.event.get_result():
+ if res.chain:
+ yield MessageChain(
+ chain=res.chain, type="tool_direct_result"
+ )
+ else:
+ logger.warning(
+ f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
+ )
self.event.clear_result()
except Exception as e:
diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py
index ae44cf36d..505b18e8e 100644
--- a/astrbot/core/pipeline/process_stage/method/llm_request.py
+++ b/astrbot/core/pipeline/process_stage/method/llm_request.py
@@ -100,7 +100,8 @@ class LLMRequestSubStage(Stage):
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
- req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
+ # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
+ # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
@@ -274,7 +275,6 @@ class LLMRequestSubStage(Stage):
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
-
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
@@ -307,19 +307,10 @@ class LLMRequestSubStage(Stage):
if not title or "
+ {{ tm('page.description') }}
+ {{ tm('empty.description') }}
+ {{ tm('form.toolsHelp') }}
+ {{ tm('form.noToolsAvailable')
+ }}
+ {{ tm('form.noToolsFound') }}
+ {{ tm('form.loadingTools')
+ }}
+
+ {{ tm('form.presetDialogsHelp') }}
+
+
+ {{ tm('empty.title') }}
+ {{ tm('form.mcpServersQuickSelect') }}
+
+ {{ tm('form.selectedTools') }}
+
+ ({{ tm('form.allSelected') }})
+
+
+ ({{ personaForm.tools.length }})
+
+
+ {{ tm('form.systemPrompt') }}
+ {{ tm('form.presetDialogs') }}
+ {{ tm('form.tools') }}
+
{{ tool.function.description }}
+{{ tool.description }}
- +