From b1e3018b6b971e6edd036ee63660236bb623c140 Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Mon, 4 Aug 2025 00:56:26 +0800 Subject: [PATCH] =?UTF-8?q?Improve:=20=E5=BC=95=E5=85=A5=E5=85=A8=E6=96=B0?= =?UTF-8?q?=E7=9A=84=E4=BA=BA=E6=A0=BC=E7=AE=A1=E7=90=86=E6=A8=A1=E5=BC=8F?= =?UTF-8?q?=E4=BB=A5=E5=8F=8A=E9=87=8D=E6=9E=84=E5=87=BD=E6=95=B0=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E7=AE=A1=E7=90=86=E5=99=A8=20(#2305)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add persona management * refactor: 重构函数工具管理器,引入 ToolSet,并让 Persona 支持绑定 Tools * feat: 更新 Persona 工具选择逻辑,支持全选和指定工具的切换 * feat: 更新 BaseDatabase 中的 persona 方法返回类型,支持返回 None --- astrbot/core/config/default.py | 40 +- astrbot/core/core_lifecycle.py | 23 +- astrbot/core/db/__init__.py | 17 + astrbot/core/db/migration/helper.py | 9 +- astrbot/core/db/migration/migra_3_to_4.py | 90 +- astrbot/core/db/po.py | 39 +- astrbot/core/db/sqlite.py | 42 +- astrbot/core/persona_mgr.py | 162 ++++ astrbot/core/pipeline/context.py | 10 +- .../agent_runner/tool_loop_agent.py | 133 ++- .../process_stage/method/llm_request.py | 19 +- astrbot/core/provider/entities.py | 4 +- astrbot/core/provider/func_tool_manager.py | 585 ++++++++----- astrbot/core/provider/manager.py | 97 +-- astrbot/core/provider/provider.py | 20 +- astrbot/core/star/context.py | 54 +- astrbot/dashboard/routes/__init__.py | 5 +- astrbot/dashboard/routes/config.py | 7 - astrbot/dashboard/routes/persona.py | 199 +++++ astrbot/dashboard/routes/tools.py | 40 + astrbot/dashboard/server.py | 3 + dashboard/src/i18n/loader.ts | 1 + .../i18n/locales/en-US/core/navigation.json | 1 + .../i18n/locales/en-US/features/persona.json | 67 ++ .../i18n/locales/en-US/features/tool-use.json | 5 +- .../i18n/locales/zh-CN/core/navigation.json | 1 + .../i18n/locales/zh-CN/features/persona.json | 67 ++ .../i18n/locales/zh-CN/features/tool-use.json | 7 +- dashboard/src/i18n/translations.ts | 8 +- .../full/vertical-sidebar/sidebarItem.ts | 6 +- dashboard/src/router/MainRoutes.ts | 5 + dashboard/src/views/PersonaPage.vue | 808 ++++++++++++++++++ dashboard/src/views/ToolUsePage.vue | 56 +- packages/astrbot/main.py | 62 +- 34 files changed, 2112 insertions(+), 580 deletions(-) create mode 100644 astrbot/core/persona_mgr.py create mode 100644 astrbot/dashboard/routes/persona.py create mode 100644 dashboard/src/i18n/locales/en-US/features/persona.json create mode 100644 dashboard/src/i18n/locales/zh-CN/features/persona.json create mode 100644 dashboard/src/views/PersonaPage.vue diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 41225e0ee..960ef107a 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -115,8 +115,7 @@ DEFAULT_CONFIG = { "log_level": "INFO", "pip_install_arg": "", "pypi_index_url": "https://mirrors.aliyun.com/pypi/simple/", - "knowledge_db": {}, - "persona": [], + "persona": [], # deprecated "timezone": "", "callback_api_base": "", } @@ -1701,43 +1700,6 @@ CONFIG_METADATA_2 = { }, }, }, - "persona": { - "description": "人格情景设置", - "type": "list", - "config_template": { - "新人格情景": { - "name": "", - "prompt": "", - "begin_dialogs": [], - "mood_imitation_dialogs": [], - } - }, - "tmpl_display_title": "name", - "items": { - "name": { - "description": "人格名称", - "type": "string", - "hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。", - }, - "prompt": { - "description": "设定(系统提示词)", - "type": "text", - "hint": "填写人格的身份背景、性格特征、兴趣爱好、个人经历、口头禅等。", - }, - "begin_dialogs": { - "description": "预设对话", - "type": "list", - "items": {"type": "string"}, - "hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话", - }, - "mood_imitation_dialogs": { - "description": "对话风格模仿", - "type": "list", - "items": {"type": "string"}, - "hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话", - }, - }, - }, "provider_stt_settings": { "description": "语音转文本(STT)", "type": "object", diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index 025bf03e5..28545a238 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -22,6 +22,7 @@ from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext from astrbot.core.star import PluginManager from astrbot.core.platform.manager import PlatformManager from astrbot.core.star.context import Context +from astrbot.core.persona_mgr import PersonaManager from astrbot.core.provider.manager import ProviderManager from astrbot.core import LogBroker from astrbot.core.db import BaseDatabase @@ -69,13 +70,23 @@ class AstrBotCoreLifecycle: logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别 await self.db.initialize() - await do_migration_v4(self.db, {}) + + try: + await do_migration_v4(self.db, {}, self.astrbot_config) + except Exception as e: + logger.error(f"迁移到 v4.0.0 新版本数据格式失败: {e}") # 初始化事件队列 self.event_queue = Queue() + # 初始化人格管理器 + self.persona_mgr = PersonaManager(self.db, self.astrbot_config) + await self.persona_mgr.initialize() + # 初始化供应商管理器 - self.provider_manager = ProviderManager(self.astrbot_config, self.db) + self.provider_manager = ProviderManager( + self.astrbot_config, self.db, self.persona_mgr + ) # 初始化平台管理器 self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue) @@ -94,6 +105,8 @@ class AstrBotCoreLifecycle: self.provider_manager, self.platform_manager, self.conversation_manager, + self.platform_message_history_manager, + self.persona_mgr, ) # 初始化插件管理器 @@ -110,6 +123,7 @@ class AstrBotCoreLifecycle: PipelineContext(self.astrbot_config, self.plugin_manager) ) await self.pipeline_scheduler.initialize() + self.star_context.pipeline_ctx = self.pipeline_scheduler.ctx # 初始化更新器 self.astrbot_updator = AstrBotUpdator() @@ -232,6 +246,9 @@ class AstrBotCoreLifecycle: platform_insts = self.platform_manager.get_insts() for platform_inst in platform_insts: tasks.append( - asyncio.create_task(platform_inst.run(), name=f"{platform_inst.meta().id}({platform_inst.meta().name})") + asyncio.create_task( + platform_inst.run(), + name=f"{platform_inst.meta().id}({platform_inst.meta().name})", + ) ) return tasks diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 53fccacfc..00c42505f 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -205,6 +205,7 @@ class BaseDatabase(abc.ABC): persona_id: str, system_prompt: str, begin_dialogs: list[str] = None, + tools: list[str] = None, ) -> Persona: """Insert a new persona record.""" ... @@ -219,6 +220,22 @@ class BaseDatabase(abc.ABC): """Get all personas for a specific bot.""" ... + @abc.abstractmethod + async def update_persona( + self, + persona_id: str, + system_prompt: str = None, + begin_dialogs: list[str] = None, + tools: list[str] = None, + ) -> Persona | None: + """Update a persona's system prompt or begin dialogs.""" + ... + + @abc.abstractmethod + async def delete_persona(self, persona_id: str) -> None: + """Delete a persona by its ID.""" + ... + @abc.abstractmethod async def insert_preference_or_update(self, key: str, value: str) -> Preference: """Insert a new preference record.""" diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py index d4b9c99f9..4eb428153 100644 --- a/astrbot/core/db/migration/helper.py +++ b/astrbot/core/db/migration/helper.py @@ -1,11 +1,13 @@ import os from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.db import BaseDatabase +from astrbot.core.config import AstrBotConfig from astrbot.api import logger from .migra_3_to_4 import ( migration_conversation_table, migration_platform_table, migration_webchat_data, + migration_persona_data, ) @@ -24,7 +26,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: async def do_migration_v4( - db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]] + db_helper: BaseDatabase, + platform_id_map: dict[str, dict[str, str]], + astrbot_config: AstrBotConfig, ): """ 执行数据库迁移 @@ -45,6 +49,9 @@ async def do_migration_v4( # 执行 WebChat 数据迁移 await migration_webchat_data(db_helper, platform_id_map) + # 执行人格数据迁移 + await migration_persona_data(db_helper, astrbot_config) + # 标记迁移完成 await db_helper.insert_preference_or_update("migration_done_v4", "true") diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index b6de3214e..e3eccd292 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -4,12 +4,10 @@ from .. import BaseDatabase from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3 from astrbot.core.config.default import DB_PATH from astrbot.api import logger +from astrbot.core.config import AstrBotConfig from astrbot.core.platform.astr_message_event import MessageSesion from sqlalchemy.ext.asyncio import AsyncSession -from astrbot.core.db.po import ( - ConversationV2, - PlatformMessageHistory, -) +from astrbot.core.db.po import ConversationV2, PlatformMessageHistory from sqlalchemy import text """ @@ -50,20 +48,21 @@ async def migration_conversation_table( async with db_helper.get_db() as dbsession: dbsession: AsyncSession async with dbsession.begin(): - for conversation in conversations: + for idx, conversation in enumerate(conversations): + if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0: + progress = int((idx + 1) / total_cnt * 100) + if progress % 10 == 0: + logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})") try: conv = db_helper_v3.get_conversation_by_user_id( user_id=conversation.get("user_id", "unknown"), cid=conversation.get("cid", "unknown"), ) if not conv: - logger.warning( + logger.info( f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。" ) if ":" not in conv.user_id: - logger.warning( - f"跳过 user_id 为 {conv.user_id} 的会话,它可能是 WebChat 的消息历史记录。" - ) continue session = MessageSesion.from_str(session_str=conv.user_id) platform_id = get_platform_id( @@ -81,13 +80,12 @@ async def migration_conversation_table( updated_at=datetime.datetime.fromtimestamp(conv.updated_at), ) dbsession.add(conv_v2) - if conv_v2: - logger.info(f"迁移旧会话 {conv.cid} 到新表成功。") except Exception as e: logger.error( f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}", exc_info=True, ) + logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。") async def migration_platform_table( @@ -107,7 +105,7 @@ async def migration_platform_table( platform_stats_v3 = stats.platform if not platform_stats_v3: - logger.warning("没有找到旧平台数据,跳过迁移。") + logger.info("没有找到旧平台数据,跳过迁移。") return first_time_stamp = platform_stats_v3[0].timestamp @@ -120,7 +118,13 @@ async def migration_platform_table( async with db_helper.get_db() as dbsession: dbsession: AsyncSession async with dbsession.begin(): - for bucket_end in range(start_time, end_time, 3600): + total_buckets = (end_time - start_time) // 3600 + for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)): + if bucket_idx % 500 == 0: + progress = int((bucket_idx + 1) / total_buckets * 100) + logger.info( + f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})" + ) cnt = 0 while ( idx < len(platform_stats_v3) @@ -136,9 +140,6 @@ async def migration_platform_table( platform_type = get_platform_type( platform_id_map, platform_stats_v3[idx].name ) - logger.info( - f"迁移平台统计数据: {platform_id}, {platform_type}, 时间戳: {bucket_end}, 计数: {cnt}" - ) try: await dbsession.execute( text(""" @@ -161,6 +162,7 @@ async def migration_platform_table( f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}", exc_info=True, ) + logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。") async def migration_webchat_data( @@ -178,20 +180,21 @@ async def migration_webchat_data( async with db_helper.get_db() as dbsession: dbsession: AsyncSession async with dbsession.begin(): - for conversation in conversations: + for idx, conversation in enumerate(conversations): + if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0: + progress = int((idx + 1) / total_cnt * 100) + if progress % 10 == 0: + logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})") try: conv = db_helper_v3.get_conversation_by_user_id( user_id=conversation.get("user_id", "unknown"), cid=conversation.get("cid", "unknown"), ) if not conv: - logger.warning( + logger.info( f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。" ) if ":" in conv.user_id: - logger.warning( - f"跳过 user_id 为 {conv.user_id} 的会话,它不是 WebChat 的消息历史记录。" - ) continue platform_id = "webchat" history = json.loads(conv.history) if conv.history else [] @@ -206,9 +209,52 @@ async def migration_webchat_data( ) dbsession.add(new_history) - logger.info(f"迁移旧 WebChat 会话 {conv.cid} 到新表成功。") except Exception: logger.error( f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败", exc_info=True, ) + + logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。") + + +async def migration_persona_data( + db_helper: BaseDatabase, astrbot_config: AstrBotConfig +): + """ + 迁移 Persona 数据到新的表中。 + 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。 + """ + v3_persona_config: list[dict] = astrbot_config.get("persona", []) + total_personas = len(v3_persona_config) + logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...") + + for idx, persona in enumerate(v3_persona_config): + if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0: + progress = int((idx + 1) / total_personas * 100) + if progress % 10 == 0: + logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})") + try: + begin_dialogs = persona.get("begin_dialogs", []) + mood_imitation_dialogs = persona.get("mood_imitation_dialogs", []) + mood_prompt = "" + user_turn = True + for mood_dialog in mood_imitation_dialogs: + if user_turn: + mood_prompt += f"A: {mood_dialog}\n" + else: + mood_prompt += f"B: {mood_dialog}\n" + user_turn = not user_turn + system_prompt = persona.get("prompt", "") + if mood_prompt: + system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}" + persona_new = await db_helper.insert_persona( + persona_id=persona["name"], + system_prompt=system_prompt, + begin_dialogs=begin_dialogs, + ) + logger.info( + f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。" + ) + except Exception as e: + logger.error(f"解析 Persona 配置失败:{e}") diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 30f7188a1..cbc8c8d4e 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -9,7 +9,7 @@ from sqlmodel import ( UniqueConstraint, Field, ) -from typing import Optional +from typing import Optional, TypedDict class PlatformStat(SQLModel, table=True): @@ -39,12 +39,14 @@ class PlatformStat(SQLModel, table=True): class ConversationV2(SQLModel, table=True): __tablename__ = "conversations" - inner_conversation_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + inner_conversation_id: int = Field( + primary_key=True, sa_column_kwargs={"autoincrement": True} + ) conversation_id: str = Field( max_length=36, nullable=False, unique=True, - default_factory=lambda: str(uuid.uuid4()) + default_factory=lambda: str(uuid.uuid4()), ) platform_id: str = Field(nullable=False) user_id: str = Field(nullable=False) @@ -78,6 +80,8 @@ class Persona(SQLModel, table=True): system_prompt: str = Field(sa_type=Text, nullable=False) begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON) """a list of strings, each representing a dialog to start with""" + tools: Optional[list] = Field(default=None, sa_type=JSON) + """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names.""" created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( default_factory=lambda: datetime.now(timezone.utc), @@ -119,7 +123,9 @@ class PlatformMessageHistory(SQLModel, table=True): platform_id: str = Field(nullable=False) user_id: str = Field(nullable=False) # An id of group, user in platform sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform - sender_name: Optional[str] = Field(default=None) # Name of the sender in the platform + sender_name: Optional[str] = Field( + default=None + ) # Name of the sender in the platform content: dict = Field(sa_type=JSON, nullable=False) # a message chain list created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) updated_at: datetime = Field( @@ -136,12 +142,14 @@ class Attachment(SQLModel, table=True): __tablename__ = "attachments" - inner_attachment_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) + inner_attachment_id: int = Field( + primary_key=True, sa_column_kwargs={"autoincrement": True} + ) attachment_id: str = Field( max_length=36, nullable=False, unique=True, - default_factory=lambda: str(uuid.uuid4()) + default_factory=lambda: str(uuid.uuid4()), ) path: str = Field(nullable=False) # Path to the file on disk type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file') @@ -182,6 +190,25 @@ class Conversation: updated_at: int = 0 +class Personality(TypedDict): + """LLM 人格类。 + + 在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。 + """ + + prompt: str = "" + name: str = "" + begin_dialogs: list[str] = [] + mood_imitation_dialogs: list[str] = [] + """情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。""" + tools: list[str] | None = None + """工具列表。None 表示使用所有工具,空列表表示不使用任何工具""" + + # cache + _begin_dialogs_processed: list[dict] = [] + _mood_imitation_dialogs_processed: str = "" + + # ==== # Deprecated, and will be removed in future versions. # ==== diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index 0308807da..0ecf787e5 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -19,6 +19,7 @@ from sqlalchemy import select, update, delete, text from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql import func +NOT_GIVEN = T.TypeVar("NOT_GIVEN") class SQLiteDatabase(BaseDatabase): def __init__(self, db_path: str) -> None: @@ -99,9 +100,7 @@ class SQLiteDatabase(BaseDatabase): # Conversation Management # ==== - async def get_conversations( - self, user_id=None, platform_id=None - ): + async def get_conversations(self, user_id=None, platform_id=None): async with self.get_db() as session: session: AsyncSession query = select(ConversationV2) @@ -224,7 +223,7 @@ class SQLiteDatabase(BaseDatabase): return query = query.values(**values) await session.execute(query) - return await self.get_conversation_by_id(cid) + return await self.get_conversation_by_id(cid) async def delete_conversation(self, cid): async with self.get_db() as session: @@ -312,7 +311,9 @@ class SQLiteDatabase(BaseDatabase): result = await session.execute(query) return result.scalar_one_or_none() - async def insert_persona(self, persona_id, system_prompt, begin_dialogs=None): + async def insert_persona( + self, persona_id, system_prompt, begin_dialogs=None, tools=None + ): """Insert a new persona record.""" async with self.get_db() as session: session: AsyncSession @@ -321,6 +322,7 @@ class SQLiteDatabase(BaseDatabase): persona_id=persona_id, system_prompt=system_prompt, begin_dialogs=begin_dialogs or [], + tools=tools, ) session.add(new_persona) return new_persona @@ -341,6 +343,36 @@ class SQLiteDatabase(BaseDatabase): result = await session.execute(query) return result.scalars().all() + async def update_persona( + self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN + ): + """Update a persona's system prompt or begin dialogs.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = update(Persona).where(Persona.persona_id == persona_id) + values = {} + if system_prompt is not None: + values["system_prompt"] = system_prompt + if begin_dialogs is not None: + values["begin_dialogs"] = begin_dialogs + if tools is not NOT_GIVEN: + values["tools"] = tools + if not values: + return + query = query.values(**values) + await session.execute(query) + return await self.get_persona_by_id(persona_id) + + async def delete_persona(self, persona_id): + """Delete a persona by its ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(Persona).where(Persona.persona_id == persona_id) + ) + async def insert_preference_or_update(self, key, value): """Insert a new preference record or update if it exists.""" async with self.get_db() as session: diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py new file mode 100644 index 000000000..dc322c789 --- /dev/null +++ b/astrbot/core/persona_mgr.py @@ -0,0 +1,162 @@ +from astrbot.core.db import BaseDatabase +from astrbot.core.db.po import Persona, Personality +from astrbot.core.config import AstrBotConfig +from astrbot import logger + + +class PersonaManager: + def __init__(self, db_helper: BaseDatabase, astrbot_config: AstrBotConfig): + self.db = db_helper + self.config = astrbot_config + _ps: dict = astrbot_config["provider_settings"] + self.default_persona: str = _ps.get("default_personality", "default") + self.personas: list[Persona] = [] + self.selected_default_persona: Persona | None = None + + self.personas_v3: list[Personality] = [] + self.selected_default_persona_v3: Personality | None = None + self.persona_v3_config: list[dict] = [] + + async def initialize(self): + self.personas = await self.get_all_personas() + self.get_v3_persona_data() + logger.info(f"已加载 {len(self.personas)} 个人格。") + + async def get_persona(self, persona_id: str): + """获取指定 persona 的信息""" + persona = await self.db.get_persona_by_id(persona_id) + if not persona: + raise ValueError(f"Persona with ID {persona_id} does not exist.") + return persona + + async def delete_persona(self, persona_id: str): + """删除指定 persona""" + if not await self.db.get_persona_by_id(persona_id): + raise ValueError(f"Persona with ID {persona_id} does not exist.") + await self.db.delete_persona(persona_id) + self.personas = [p for p in self.personas if p.persona_id != persona_id] + self.get_v3_persona_data() + + async def update_persona( + self, + persona_id: str, + system_prompt: str = None, + begin_dialogs: list[str] = None, + tools: list[str] = None, + ): + """更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" + existing_persona = await self.db.get_persona_by_id(persona_id) + if not existing_persona: + raise ValueError(f"Persona with ID {persona_id} does not exist.") + persona = await self.db.update_persona( + persona_id, system_prompt, begin_dialogs, tools=tools + ) + if persona: + for i, p in enumerate(self.personas): + if p.persona_id == persona_id: + self.personas[i] = persona + break + self.get_v3_persona_data() + return persona + + async def get_all_personas(self) -> list[Persona]: + """获取所有 personas""" + return await self.db.get_personas() + + async def create_persona( + self, + persona_id: str, + system_prompt: str, + begin_dialogs: list[str] = None, + tools: list[str] = None, + ) -> Persona: + """创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" + if await self.db.get_persona_by_id(persona_id): + raise ValueError(f"Persona with ID {persona_id} already exists.") + new_persona = await self.db.insert_persona( + persona_id, system_prompt, begin_dialogs, tools=tools + ) + self.personas.append(new_persona) + self.get_v3_persona_data() + return new_persona + + def get_v3_persona_data( + self, + ) -> tuple[list[dict], list[Personality], Personality]: + """获取 AstrBot <4.0.0 版本的 persona 数据。 + + Returns: + - list[dict]: 包含 persona 配置的字典列表。 + - list[Personality]: 包含 Personality 对象的列表。 + - Personality: 默认选择的 Personality 对象。 + """ + v3_persona_config = [ + { + "prompt": persona.system_prompt, + "name": persona.persona_id, + "begin_dialogs": persona.begin_dialogs or [], + "mood_imitation_dialogs": [], # deprecated + "tools": persona.tools, + } + for persona in self.personas + ] + + personas_v3: list[Personality] = [] + selected_default_persona: Personality | None = None + + for persona_cfg in v3_persona_config: + begin_dialogs = persona_cfg.get("begin_dialogs", []) + bd_processed = [] + if begin_dialogs: + if len(begin_dialogs) % 2 != 0: + logger.error( + f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。" + ) + begin_dialogs = [] + user_turn = True + for dialog in begin_dialogs: + bd_processed.append( + { + "role": "user" if user_turn else "assistant", + "content": dialog, + "_no_save": None, # 不持久化到 db + } + ) + user_turn = not user_turn + + try: + persona = Personality( + **persona_cfg, + _begin_dialogs_processed=bd_processed, + _mood_imitation_dialogs_processed="", # deprecated + ) + if persona["name"] == self.default_persona: + selected_default_persona = persona + personas_v3.append(persona) + except Exception as e: + logger.error(f"解析 Persona 配置失败:{e}") + + if not selected_default_persona and len(personas_v3) > 0: + # 默认选择第一个 + selected_default_persona = personas_v3[0] + + if not selected_default_persona: + selected_default_persona = Personality( + prompt="You are a helpful and friendly assistant.", + name="default", + tools=None, + _begin_dialogs_processed=[], + ) + personas_v3.append(selected_default_persona) + + self.personas_v3 = personas_v3 + self.selected_default_persona_v3 = selected_default_persona + self.persona_v3_config = v3_persona_config + self.selected_default_persona = Persona( + persona_id=selected_default_persona["name"], + system_prompt=selected_default_persona["prompt"], + begin_dialogs=selected_default_persona["begin_dialogs"], + tools=selected_default_persona["tools"] or None, + ) + + return v3_persona_config, personas_v3, selected_default_persona diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index 0b9d9e533..932c5d5c2 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -55,7 +55,7 @@ class PipelineContext: handler: T.Awaitable, *args, **kwargs, - ) -> T.AsyncGenerator[None, None]: + ) -> T.AsyncGenerator[T.Any, None]: """执行事件处理函数并处理其返回结果 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数: @@ -77,10 +77,16 @@ class PipelineContext: try: ready_to_call = handler(event, *args, **kwargs) except TypeError as _: + if self.plugin_manager: + context = self.plugin_manager.context + else: + raise ValueError( + "Cannot call handler without a valid context or plugin manager." + ) # 向下兼容 trace_ = traceback.format_exc() # 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份 - ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs) + ready_to_call = handler(event, context, *args, **kwargs) if inspect.isasyncgen(ready_to_call): _has_yielded = False diff --git a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py index c2961ded5..cd705275e 100644 --- a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py +++ b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py @@ -21,6 +21,7 @@ from mcp.types import ( EmbeddedResource, TextResourceContents, BlobResourceContents, + CallToolResult, ) from astrbot.core.star.star_handler import EventType from astrbot import logger @@ -193,50 +194,25 @@ class ToolLoopAgent(BaseAgentRunner): if not req.func_tool: return func_tool = req.func_tool.get_func(func_tool_name) - if func_tool.origin == "mcp": - logger.info( - f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}" - ) - client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name] - res = await client.session.call_tool(func_tool.name, func_tool_args) - if not res: - continue - if isinstance(res.content[0], TextContent): - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=res.content[0].text, - ) - ) - yield MessageChain().message(res.content[0].text) - elif isinstance(res.content[0], ImageContent): - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="返回了图片(已直接发送给用户)", - ) - ) - yield MessageChain(type="tool_direct_result").base64_image( - res.content[0].data - ) - elif isinstance(res.content[0], EmbeddedResource): - resource = res.content[0].resource - if isinstance(resource, TextResourceContents): + logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}") + executor = func_tool.execute( + event=self.event, + pipeline_context=self.pipeline_ctx, + **func_tool_args, + ) + async for resp in executor: + if isinstance(resp, CallToolResult): + res = resp + if isinstance(res.content[0], TextContent): tool_call_result_blocks.append( ToolCallMessageSegment( role="tool", tool_call_id=func_tool_id, - content=resource.text, + content=res.content[0].text, ) ) - yield MessageChain().message(resource.text) - elif ( - isinstance(resource, BlobResourceContents) - and resource.mimeType - and resource.mimeType.startswith("image/") - ): + yield MessageChain().message(res.content[0].text) + elif isinstance(res.content[0], ImageContent): tool_call_result_blocks.append( ToolCallMessageSegment( role="tool", @@ -247,41 +223,54 @@ class ToolLoopAgent(BaseAgentRunner): yield MessageChain(type="tool_direct_result").base64_image( res.content[0].data ) - else: - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content="返回的数据类型不受支持", - ) - ) - yield MessageChain().message("返回的数据类型不受支持。") - else: - logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}") - # 尝试调用工具函数 - wrapper = self.pipeline_ctx.call_handler( - self.event, func_tool.handler, **func_tool_args - ) - async for resp in wrapper: - if resp is not None: - # Tool 返回结果 - tool_call_result_blocks.append( - ToolCallMessageSegment( - role="tool", - tool_call_id=func_tool_id, - content=resp, - ) - ) - yield MessageChain().message(resp) - else: - # Tool 直接请求发送消息给用户 - # 这里我们将直接结束 Agent Loop。 - self._transition_state(AgentState.DONE) - if res := self.event.get_result(): - if res.chain: - yield MessageChain( - chain=res.chain, type="tool_direct_result" + elif isinstance(res.content[0], EmbeddedResource): + resource = res.content[0].resource + if isinstance(resource, TextResourceContents): + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content=resource.text, ) + ) + yield MessageChain().message(resource.text) + elif ( + isinstance(resource, BlobResourceContents) + and resource.mimeType + and resource.mimeType.startswith("image/") + ): + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content="返回了图片(已直接发送给用户)", + ) + ) + yield MessageChain( + type="tool_direct_result" + ).base64_image(res.content[0].data) + else: + tool_call_result_blocks.append( + ToolCallMessageSegment( + role="tool", + tool_call_id=func_tool_id, + content="返回的数据类型不受支持", + ) + ) + yield MessageChain().message("返回的数据类型不受支持。") + elif resp is None: + # Tool 直接请求发送消息给用户 + # 这里我们将直接结束 Agent Loop。 + self._transition_state(AgentState.DONE) + if res := self.event.get_result(): + if res.chain: + yield MessageChain( + chain=res.chain, type="tool_direct_result" + ) + else: + logger.warning( + f"Tool 返回了不支持的类型: {type(resp)},将忽略。" + ) self.event.clear_result() except Exception as e: diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index ae44cf36d..505b18e8e 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -100,7 +100,8 @@ class LLMRequestSubStage(Stage): if not event.message_str.startswith(self.provider_wake_prefix): return req.prompt = event.message_str[len(self.provider_wake_prefix) :] - req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() + # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。 + # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() for comp in event.message_obj.message: if isinstance(comp, Image): image_path = await comp.convert_to_file_path() @@ -274,7 +275,6 @@ class LLMRequestSubStage(Stage): if event.get_platform_name() == "webchat": asyncio.create_task(self._handle_webchat(event, req, provider)) - async def _handle_webchat( self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider ): @@ -307,19 +307,10 @@ class LLMRequestSubStage(Stage): if not title or "" in title: return await self.conv_manager.update_conversation_title( - event.unified_msg_origin, title=title + unified_msg_origin=event.unified_msg_origin, + title=title, + conversation_id=req.conversation.cid, ) - # 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题 - # webchat adapter 中,session_id 的格式是 f"webchat!{username}!{cid}" - # TODO: 优化 WebChat 适配器的对话管理 - if event.session_id: - username, cid = event.session_id.split("!")[1:3] - db_helper = self.ctx.plugin_manager.context._db - db_helper.update_conversation_title( - user_id=username, - cid=cid, - title=title, - ) async def _save_to_history( self, diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 2d120d7f6..d0ba920c1 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -5,7 +5,7 @@ from astrbot.core.utils.io import download_image_by_url from astrbot import logger from dataclasses import dataclass, field from typing import List, Dict, Type -from .func_tool_manager import FuncCall +from .func_tool_manager import FunctionToolManager, ToolSet from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall, @@ -97,7 +97,7 @@ class ProviderRequest: """会话 ID""" image_urls: list[str] = field(default_factory=list) """图片 URL 列表""" - func_tool: FuncCall | None = None + func_tool: FunctionToolManager | ToolSet | None = None """可用的函数工具""" contexts: list[dict] = field(default_factory=list) """上下文。格式与 openai 的上下文格式一致: diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 07a0fbd8f..117a03800 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -1,16 +1,17 @@ from __future__ import annotations import json -import textwrap import os import asyncio import logging from datetime import timedelta +from deprecated import deprecated -from typing import Dict, List, Awaitable, Literal, Any +from typing import Dict, List, Awaitable, Literal, Any, AsyncGenerator from dataclasses import dataclass from typing import Optional from contextlib import AsyncExitStack from astrbot import logger +from astrbot.core import sp from astrbot.core.utils.log_pipe import LogPipe from astrbot.core.utils.astrbot_path import get_astrbot_data_path @@ -28,6 +29,13 @@ except (ModuleNotFoundError, ImportError): "警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。" ) +from typing_extensions import TYPE_CHECKING + +if TYPE_CHECKING: + from astrbot.core.platform.astr_message_event import AstrMessageEvent + from astrbot.core.pipeline.context import PipelineContext + + DEFAULT_MCP_CONFIG = {"mcpServers": {}} SUPPORTED_TYPES = [ @@ -39,6 +47,302 @@ SUPPORTED_TYPES = [ ] # json schema 支持的数据类型 +@dataclass +class FunctionTool: + """A class representing a function tool that can be used in function calling.""" + + name: str + parameters: Dict + description: str + handler: Awaitable = None + """处理函数, 当 origin 为 mcp 时,这个为空""" + handler_module_path: str = None + """处理函数的模块路径,当 origin 为 mcp 时,这个为空 + + 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools + """ + active: bool = True + """是否激活""" + + origin: Literal["local", "mcp"] = "local" + """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务""" + + # MCP 相关字段 + mcp_server_name: str = None + """MCP 服务名称,当 origin 为 mcp 时有效""" + mcp_client: MCPClient = None + """MCP 客户端,当 origin 为 mcp 时有效""" + + def __repr__(self): + return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" + + async def execute( + self, + event: AstrMessageEvent = None, + pipeline_context: "PipelineContext" = None, + **tool_args, + ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]: + """执行函数调用。 + + Args: + event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。 + pipeline_context (PipelineContext): 流水线调度器上下文, 当 origin 为 local 时必须提供。 + **kwargs: 函数调用的参数。 + + Returns: + AsyncGenerator[None | mcp.types.CallToolResult, None] + """ + if self.origin == "local": + if not event: + raise ValueError("Event must be provided for local function tools.") + wrapper = pipeline_context.call_handler( + event=event, + handler=self.handler, + **tool_args, + ) + async for resp in wrapper: + if resp is not None: + text_content = mcp.types.TextContent( + type="text", + text=str(resp), + ) + yield mcp.types.CallToolResult(content=[text_content]) + else: + # NOTE: Tool 在这里直接请求发送消息给用户 + # TODO: 是否需要判断 event.get_result() 是否为空? + # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容" + yield None + + elif self.origin == "mcp": + if not self.mcp_client: + raise ValueError("MCP client is not available for MCP function tools.") + res = await self.mcp_client.session.call_tool( + name=self.name, + arguments=tool_args, + ) + if not res: + return + yield res + + else: + raise Exception(f"Unknown function origin: {self.origin}") + + def __dict__(self) -> dict[str, Any]: + """将 FunctionTool 转换为字典格式""" + return { + "name": self.name, + "parameters": self.parameters, + "description": self.description, + "active": self.active, + "origin": self.origin, + "mcp_server_name": self.mcp_server_name, + } + + +# alias for FunctionTool +FuncTool = FunctionTool + + +class ToolSet: + """A set of function tools that can be used in function calling. + + This class provides methods to add, remove, and retrieve tools, as well as + convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).""" + + def __init__(self, tools: List[FunctionTool] = None): + self.tools: List[FunctionTool] = tools or [] + + def empty(self) -> bool: + """Check if the tool set is empty.""" + return len(self.tools) == 0 + + def add_tool(self, tool: FunctionTool): + """Add a tool to the set.""" + # 检查是否已存在同名工具 + for i, existing_tool in enumerate(self.tools): + if existing_tool.name == tool.name: + self.tools[i] = tool + return + self.tools.append(tool) + + def remove_tool(self, name: str): + """Remove a tool by its name.""" + self.tools = [tool for tool in self.tools if tool.name != name] + + def get_tool(self, name: str) -> Optional[FunctionTool]: + """Get a tool by its name.""" + for tool in self.tools: + if tool.name == name: + return tool + return None + + @deprecated(reason="Use add_tool() instead", version="4.0.0") + def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable): + """Add a function tool to the set.""" + params = { + "type": "object", # hard-coded here + "properties": {}, + } + for param in func_args: + params["properties"][param["name"]] = { + "type": param["type"], + "description": param["description"], + } + _func = FunctionTool( + name=name, + parameters=params, + description=desc, + handler=handler, + ) + self.add_tool(_func) + + @deprecated(reason="Use remove_tool() instead", version="4.0.0") + def remove_func(self, name: str): + """Remove a function tool by its name.""" + self.remove_tool(name) + + @deprecated(reason="Use get_tool() instead", version="4.0.0") + def get_func(self, name: str) -> List[FunctionTool]: + """Get all function tools.""" + return self.get_tool(name) + + def openai_schema(self, omit_empty_parameters: bool = False) -> List[Dict]: + """Convert tools to OpenAI API function calling schema format.""" + result = [] + for tool in self.tools: + func_def = { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + }, + } + + if tool.parameters.get("properties") or not omit_empty_parameters: + func_def["function"]["parameters"] = tool.parameters + + result.append(func_def) + return result + + def anthropic_schema(self) -> List[Dict]: + """Convert tools to Anthropic API format.""" + result = [] + for tool in self.tools: + tool_def = { + "name": tool.name, + "description": tool.description, + "input_schema": { + "type": "object", + "properties": tool.parameters.get("properties", {}), + "required": tool.parameters.get("required", []), + }, + } + result.append(tool_def) + return result + + def google_schema(self) -> Dict: + """Convert tools to Google GenAI API format.""" + + def convert_schema(schema: dict) -> dict: + """Convert schema to Gemini API format.""" + supported_types = { + "string", + "number", + "integer", + "boolean", + "array", + "object", + "null", + } + supported_formats = { + "string": {"enum", "date-time"}, + "integer": {"int32", "int64"}, + "number": {"float", "double"}, + } + + if "anyOf" in schema: + return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]} + + result = {} + + if "type" in schema and schema["type"] in supported_types: + result["type"] = schema["type"] + if "format" in schema and schema["format"] in supported_formats.get( + result["type"], set() + ): + result["format"] = schema["format"] + else: + result["type"] = "null" + + support_fields = { + "title", + "description", + "enum", + "minimum", + "maximum", + "maxItems", + "minItems", + "nullable", + "required", + } + result.update({k: schema[k] for k in support_fields if k in schema}) + + if "properties" in schema: + properties = {} + for key, value in schema["properties"].items(): + prop_value = convert_schema(value) + if "default" in prop_value: + del prop_value["default"] + properties[key] = prop_value + + if properties: + result["properties"] = properties + + if "items" in schema: + result["items"] = convert_schema(schema["items"]) + + return result + + tools = [ + { + "name": tool.name, + "description": tool.description, + "parameters": convert_schema(tool.parameters), + } + for tool in self.tools + ] + + declarations = {} + if tools: + declarations["function_declarations"] = tools + return declarations + + @deprecated(reason="Use openai_schema() instead", version="4.0.0") + def get_func_desc_openai_style(self, omit_empty_parameters: bool = False): + return self.openai_schema(omit_empty_parameters) + + @deprecated(reason="Use anthropic_schema() instead", version="4.0.0") + def get_func_desc_anthropic_style(self): + return self.anthropic_schema() + + @deprecated(reason="Use google_schema() instead", version="4.0.0") + def get_func_desc_google_genai_style(self): + return self.google_schema() + + def names(self) -> List[str]: + """获取所有工具的名称列表""" + return [tool.name for tool in self.tools] + + def __len__(self): + return len(self.tools) + + def __bool__(self): + return len(self.tools) > 0 + + def __iter__(self): + return iter(self.tools) + + def _prepare_config(config: dict) -> dict: """准备配置,处理嵌套格式""" if "mcpServers" in config and config["mcpServers"]: @@ -105,55 +409,6 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: return False, f"{e!s}" -@dataclass -class FuncTool: - """ - 用于描述一个函数调用工具。 - """ - - name: str - parameters: Dict - description: str - handler: Awaitable = None - """处理函数, 当 origin 为 mcp 时,这个为空""" - handler_module_path: str = None - """处理函数的模块路径,当 origin 为 mcp 时,这个为空 - - 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools - """ - active: bool = True - """是否激活""" - - origin: Literal["local", "mcp"] = "local" - """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务""" - - # MCP 相关字段 - mcp_server_name: str = None - """MCP 服务名称,当 origin 为 mcp 时有效""" - mcp_client: MCPClient = None - """MCP 客户端,当 origin 为 mcp 时有效""" - - def __repr__(self): - return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})" - - async def execute(self, **args) -> Any: - """执行函数调用""" - if self.origin == "local": - if not self.handler: - raise Exception(f"Local function {self.name} has no handler") - return await self.handler(**args) - elif self.origin == "mcp": - if not self.mcp_client or not self.mcp_client.session: - raise Exception(f"MCP client for {self.name} is not available") - # 使用name属性而不是额外的mcp_tool_name - actual_tool_name = ( - self.name.split(":")[-1] if ":" in self.name else self.name - ) - return await self.mcp_client.session.call_tool(actual_tool_name, args) - else: - raise Exception(f"Unknown function origin: {self.origin}") - - class MCPClient: def __init__(self): # Initialize session and client objects @@ -276,10 +531,9 @@ class MCPClient: self.running_event.set() # Set the running event to indicate cleanup is done -class FuncCall: +class FunctionToolManager: def __init__(self) -> None: self.func_list: List[FuncTool] = [] - """内部加载的 func tools""" self.mcp_client_dict: Dict[str, MCPClient] = {} """MCP 服务列表""" self.mcp_client_event: Dict[str, asyncio.Event] = {} @@ -331,11 +585,15 @@ class FuncCall: self.func_list.pop(i) break - def get_func(self, name) -> FuncTool: + def get_func(self, name) -> FuncTool | None: for f in self.func_list: if f.name == name: return f - return None + + def get_full_tool_set(self) -> ToolSet: + """获取完整工具集""" + tool_set = ToolSet(self.func_list.copy()) + return tool_set async def init_mcp_clients(self) -> None: """从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下: @@ -556,203 +814,68 @@ class FuncCall: """ 获得 OpenAI API 风格的**已经激活**的工具描述 """ - _l = [] - # 处理所有工具(包括本地和MCP工具) - for f in self.func_list: - if not f.active: - continue - func_ = { - "type": "function", - "function": { - "name": f.name, - # "parameters": f.parameters, - "description": f.description, - }, - } - func_["function"]["parameters"] = f.parameters - if not f.parameters.get("properties") and omit_empty_parameter_field: - # 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段 - del func_["function"]["parameters"] - _l.append(func_) - return _l + tools = [f for f in self.func_list if f.active] + toolset = ToolSet(tools) + return toolset.openai_schema(omit_empty_parameters=omit_empty_parameter_field) def get_func_desc_anthropic_style(self) -> list: """ 获得 Anthropic API 风格的**已经激活**的工具描述 """ - tools = [] - for f in self.func_list: - if not f.active: - continue - - # Convert internal format to Anthropic style - tool = { - "name": f.name, - "description": f.description, - "input_schema": { - "type": "object", - "properties": f.parameters.get("properties", {}), - # Keep the required field from the original parameters if it exists - "required": f.parameters.get("required", []), - }, - } - tools.append(tool) - return tools + tools = [f for f in self.func_list if f.active] + toolset = ToolSet(tools) + return toolset.anthropic_schema() def get_func_desc_google_genai_style(self) -> dict: """ 获得 Google GenAI API 风格的**已经激活**的工具描述 """ + tools = [f for f in self.func_list if f.active] + toolset = ToolSet(tools) + return toolset.google_schema() - # Gemini API 支持的数据类型和格式 - supported_types = { - "string", - "number", - "integer", - "boolean", - "array", - "object", - "null", - } - supported_formats = { - "string": {"enum", "date-time"}, - "integer": {"int32", "int64"}, - "number": {"float", "double"}, - } + def deactivate_llm_tool(self, name: str) -> bool: + """停用一个已经注册的函数调用工具。 - def convert_schema(schema: dict) -> dict: - """转换 schema 为 Gemini API 格式""" + Returns: + 如果没找到,会返回 False""" + func_tool = self.get_func(name) + if func_tool is not None: + func_tool.active = False - # 如果 schema 包含 anyOf,则只返回 anyOf 字段 - if "anyOf" in schema: - return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]} + inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + if name not in inactivated_llm_tools: + inactivated_llm_tools.append(name) + sp.put("inactivated_llm_tools", inactivated_llm_tools) - result = {} + return True + return False - if "type" in schema and schema["type"] in supported_types: - result["type"] = schema["type"] - if "format" in schema and schema["format"] in supported_formats.get( - result["type"], set() - ): - result["format"] = schema["format"] - else: - # 暂时指定默认为null - result["type"] = "null" + # 因为不想解决循环引用,所以这里直接传入 star_map 先了... + def activate_llm_tool(self, name: str, star_map: dict) -> bool: + func_tool = self.get_func(name) + if func_tool is not None: + if func_tool.handler_module_path in star_map: + if not star_map[func_tool.handler_module_path].activated: + raise ValueError( + f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。" + ) - support_fields = { - "title", - "description", - "enum", - "minimum", - "maximum", - "maxItems", - "minItems", - "nullable", - "required", - } - result.update({k: schema[k] for k in support_fields if k in schema}) + func_tool.active = True - if "properties" in schema: - properties = {} - for key, value in schema["properties"].items(): - prop_value = convert_schema(value) - if "default" in prop_value: - del prop_value["default"] - properties[key] = prop_value + inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) + if name in inactivated_llm_tools: + inactivated_llm_tools.remove(name) + sp.put("inactivated_llm_tools", inactivated_llm_tools) - if properties: # 只在有非空属性时添加 - result["properties"] = properties - - if "items" in schema: - result["items"] = convert_schema(schema["items"]) - - return result - - tools = [ - { - "name": f.name, - "description": f.description, - **({"parameters": convert_schema(f.parameters)}), - } - for f in self.func_list - if f.active - ] - - declarations = {} - if tools: - declarations["function_declarations"] = tools - return declarations - - async def func_call(self, question: str, session_id: str, provider) -> tuple: - _l = [] - for f in self.func_list: - if not f.active: - continue - _l.append( - { - "name": f.name, - "parameters": f.parameters, - "description": f.description, - } - ) - func_definition = json.dumps(_l, ensure_ascii=False) - - prompt = textwrap.dedent(f""" - ROLE: - 你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。 - - TOOLS: - 可用的函数列表: - - {func_definition} - - LIMIT: - 1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。 - 2. 你的 Json 返回的格式如下:`[{{"name": "", "args": }}, ...]`。参数根据上面提供的函数列表中的参数来填写。 - 3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。 - 4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。 - - EXAMPLE: - 1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}] - - 用户的提问是:{question} - """) - - _c = 0 - while _c < 3: - try: - res = await provider.text_chat(prompt, session_id) - if res.find("```") != -1: - res = res[res.find("```json") + 7 : res.rfind("```")] - res = json.loads(res) - break - except Exception as e: - _c += 1 - if _c == 3: - raise e - if "The message you submitted was too long" in str(e): - raise e - - if "res" in res and not res["res"]: - return "", False - - tool_call_result = [] - for tool in res: - # 说明有函数调用 - func_name = tool["name"] - args = tool["args"] - # 调用函数 - func_tool = self.get_func(func_name) - if not func_tool: - raise Exception(f"Request function {func_name} not found.") - - ret = await func_tool.execute(**args) - if ret: - tool_call_result.append(str(ret)) - return tool_call_result, True + return True + return False def __str__(self): return str(self.func_list) def __repr__(self): return str(self.func_list) + + +FuncCall = FunctionToolManager diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 370c5322b..7d7779e81 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -7,85 +7,27 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.db import BaseDatabase from .entities import ProviderType -from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider +from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider from .register import llm_tools, provider_cls_map +from ..persona_mgr import PersonaManager class ProviderManager: - def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase): + def __init__( + self, + config: AstrBotConfig, + db_helper: BaseDatabase, + persona_mgr: PersonaManager, + ): + self.persona_mgr = persona_mgr + self.astrbot_config = config self.providers_config: List = config["provider"] self.provider_settings: dict = config["provider_settings"] self.provider_stt_settings: dict = config.get("provider_stt_settings", {}) self.provider_tts_settings: dict = config.get("provider_tts_settings", {}) - self.persona_configs: list = config.get("persona", []) - self.astrbot_config = config - # 人格情景管理 - # 目前没有拆成独立的模块 - self.default_persona_name = self.provider_settings.get( - "default_personality", "default" - ) - self.personas: List[Personality] = [] - self.selected_default_persona = None - for persona in self.persona_configs: - begin_dialogs = persona.get("begin_dialogs", []) - mood_imitation_dialogs = persona.get("mood_imitation_dialogs", []) - bd_processed = [] - mid_processed = "" - if begin_dialogs: - if len(begin_dialogs) % 2 != 0: - logger.error( - f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。" - ) - begin_dialogs = [] - user_turn = True - for dialog in begin_dialogs: - bd_processed.append( - { - "role": "user" if user_turn else "assistant", - "content": dialog, - "_no_save": None, # 不持久化到 db - } - ) - user_turn = not user_turn - if mood_imitation_dialogs: - if len(mood_imitation_dialogs) % 2 != 0: - logger.error( - f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。" - ) - mood_imitation_dialogs = [] - user_turn = True - for dialog in mood_imitation_dialogs: - role = "A" if user_turn else "B" - mid_processed += f"{role}: {dialog}\n" - if not user_turn: - mid_processed += "\n" - user_turn = not user_turn - - try: - persona = Personality( - **persona, - _begin_dialogs_processed=bd_processed, - _mood_imitation_dialogs_processed=mid_processed, - ) - if persona["name"] == self.default_persona_name: - self.selected_default_persona = persona - self.personas.append(persona) - except Exception as e: - logger.error(f"解析 Persona 配置失败:{e}") - - if not self.selected_default_persona and len(self.personas) > 0: - # 默认选择第一个 - self.selected_default_persona = self.personas[0] - - if not self.selected_default_persona: - self.selected_default_persona = Personality( - prompt="You are a helpful and friendly assistant.", - name="default", - _begin_dialogs_processed=[], - _mood_imitation_dialogs_processed="", - ) - self.personas.append(self.selected_default_persona) + # 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager + self.default_persona_name = persona_mgr.default_persona self.provider_insts: List[Provider] = [] """加载的 Provider 的实例""" @@ -113,6 +55,21 @@ class ProviderManager: if kdb_cfg and len(kdb_cfg): self.curr_kdb_name = list(kdb_cfg.keys())[0] + @property + def persona_configs(self) -> list: + """动态获取最新的 persona 配置""" + return self.persona_mgr.persona_v3_config + + @property + def personas(self) -> list: + """动态获取最新的 personas 列表""" + return self.persona_mgr.personas_v3 + + @property + def selected_default_persona(self): + """动态获取最新的默认选中 persona""" + return self.persona_mgr.selected_default_persona_v3 + async def set_provider( self, provider_id: str, provider_type: ProviderType, umo: str = None ): diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 36401b089..2b4f81bb4 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -1,23 +1,13 @@ import abc from typing import List -from typing import TypedDict, AsyncGenerator -from astrbot.core.provider.func_tool_manager import FuncCall +from typing import AsyncGenerator +from astrbot.core.provider.func_tool_manager import FunctionToolManager, ToolSet from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType from astrbot.core.provider.register import provider_cls_map +from astrbot.core.db.po import Personality from dataclasses import dataclass -class Personality(TypedDict): - prompt: str = "" - name: str = "" - begin_dialogs: List[str] = [] - mood_imitation_dialogs: List[str] = [] - - # cache - _begin_dialogs_processed: List[dict] = [] - _mood_imitation_dialogs_processed: str = "" - - @dataclass class ProviderMeta: id: str @@ -90,7 +80,7 @@ class Provider(AbstractProvider): prompt: str, session_id: str = None, image_urls: list[str] = None, - func_tool: FuncCall = None, + func_tool: FunctionToolManager | ToolSet = None, contexts: list = None, system_prompt: str = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, @@ -119,7 +109,7 @@ class Provider(AbstractProvider): prompt: str, session_id: str = None, image_urls: list[str] = None, - func_tool: FuncCall = None, + func_tool: FunctionToolManager | ToolSet = None, contexts: list = None, system_prompt: str = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 6d89da57b..76cdea062 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -11,12 +11,14 @@ from astrbot.core.provider.provider import ( from astrbot.core.provider.entities import ProviderType from astrbot.core.db import BaseDatabase from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.provider.func_tool_manager import FuncCall +from astrbot.core.provider.func_tool_manager import FunctionToolManager from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.manager import ProviderManager from astrbot.core.platform import Platform from astrbot.core.platform.manager import PlatformManager +from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager +from astrbot.core.persona_mgr import PersonaManager from .star import star_registry, StarMetadata, star_map from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType from .filter.command import CommandFilter @@ -29,6 +31,11 @@ from astrbot.core.star.filter.platform_adapter_type import ( ) from deprecated import deprecated +from typing_extensions import TYPE_CHECKING + +if TYPE_CHECKING: + from astrbot.core.pipeline.context import PipelineContext + class Context: """ @@ -50,6 +57,8 @@ class Context: registered_web_apis: list = [] + pipeline_ctx: "PipelineContext" = None + # back compatibility _register_tasks: List[Awaitable] = [] _star_manager = None @@ -62,6 +71,8 @@ class Context: provider_manager: ProviderManager = None, platform_manager: PlatformManager = None, conversation_manager: ConversationManager = None, + message_history_manager: PlatformMessageHistoryManager = None, + persona_manager: PersonaManager = None, ): self._event_queue = event_queue self._config = config @@ -69,6 +80,8 @@ class Context: self.provider_manager = provider_manager self.platform_manager = platform_manager self.conversation_manager = conversation_manager + self.message_history_manager = message_history_manager + self.persona_manager = persona_manager def get_registered_star(self, star_name: str) -> StarMetadata: """根据插件名获取插件的 Metadata""" @@ -80,7 +93,7 @@ class Context: """获取当前载入的所有插件 Metadata 的列表""" return star_registry - def get_llm_tool_manager(self) -> FuncCall: + def get_llm_tool_manager(self) -> FunctionToolManager: """获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools""" return self.provider_manager.llm_tools @@ -90,40 +103,14 @@ class Context: Returns: 如果没找到,会返回 False """ - func_tool = self.provider_manager.llm_tools.get_func(name) - if func_tool is not None: - if func_tool.handler_module_path in star_map: - if not star_map[func_tool.handler_module_path].activated: - raise ValueError( - f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。" - ) - - func_tool.active = True - - inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) - if name in inactivated_llm_tools: - inactivated_llm_tools.remove(name) - sp.put("inactivated_llm_tools", inactivated_llm_tools) - - return True - return False + return self.provider_manager.llm_tools.activate_llm_tool(name, star_map) def deactivate_llm_tool(self, name: str) -> bool: """停用一个已经注册的函数调用工具。 Returns: 如果没找到,会返回 False""" - func_tool = self.provider_manager.llm_tools.get_func(name) - if func_tool is not None: - func_tool.active = False - - inactivated_llm_tools: list = sp.get("inactivated_llm_tools", []) - if name not in inactivated_llm_tools: - inactivated_llm_tools.append(name) - sp.put("inactivated_llm_tools", inactivated_llm_tools) - - return True - return False + return self.provider_manager.llm_tools.deactivate_llm_tool(name) def register_provider(self, provider: Provider): """ @@ -208,7 +195,9 @@ class Context: return self._event_queue @deprecated(version="4.0.0", reason="Use get_platform_inst instead") - def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform | None: + def get_platform( + self, platform_type: Union[PlatformAdapterType, str] + ) -> Platform | None: """ 获取指定类型的平台适配器。 @@ -268,6 +257,9 @@ class Context: return True return False + def get_pipeline_context(self) -> "PipelineContext": + return self.pipeline_ctx + """ 以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。 """ diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index 8d08b9d53..ef2fa3e86 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -6,11 +6,11 @@ from .stat import StatRoute from .log import LogRoute from .static_file import StaticFileRoute from .chat import ChatRoute -from .tools import ToolsRoute # 导入新的ToolsRoute +from .tools import ToolsRoute from .conversation import ConversationRoute from .file import FileRoute from .session_management import SessionManagementRoute - +from .persona import PersonaRoute __all__ = [ "AuthRoute", @@ -25,4 +25,5 @@ __all__ = [ "ConversationRoute", "FileRoute", "SessionManagementRoute", + "PersonaRoute", ] diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 7de720a38..11b474860 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -169,7 +169,6 @@ class ConfigRoute(Route): "/config/provider/new": ("POST", self.post_new_provider), "/config/provider/update": ("POST", self.post_update_provider), "/config/provider/delete": ("POST", self.post_delete_provider), - "/config/llmtools": ("GET", self.get_llm_tools), "/config/provider/check_one": ("GET", self.check_one_provider_status), "/config/provider/list": ("GET", self.get_provider_config_list), "/config/provider/model_list": ("GET", self.get_provider_model_list), @@ -509,12 +508,6 @@ class ConfigRoute(Route): return Response().error(str(e)).__dict__ return Response().ok(None, "删除成功,已经实时生效~").__dict__ - async def get_llm_tools(self): - """获取函数调用工具。包含了本地加载的以及 MCP 服务的工具""" - tool_mgr = self.core_lifecycle.provider_manager.llm_tools - tools = tool_mgr.get_func_desc_openai_style() - return Response().ok(tools).__dict__ - async def _get_astrbot_config(self): config = self.config diff --git a/astrbot/dashboard/routes/persona.py b/astrbot/dashboard/routes/persona.py new file mode 100644 index 000000000..032471ee4 --- /dev/null +++ b/astrbot/dashboard/routes/persona.py @@ -0,0 +1,199 @@ +import traceback +from .route import Route, Response, RouteContext +from astrbot.core import logger +from quart import request +from astrbot.core.db import BaseDatabase +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + + +class PersonaRoute(Route): + def __init__( + self, + context: RouteContext, + db_helper: BaseDatabase, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: + super().__init__(context) + self.routes = { + "/persona/list": ("GET", self.list_personas), + "/persona/detail": ("POST", self.get_persona_detail), + "/persona/create": ("POST", self.create_persona), + "/persona/update": ("POST", self.update_persona), + "/persona/delete": ("POST", self.delete_persona), + } + self.db_helper = db_helper + self.persona_mgr = core_lifecycle.persona_mgr + self.register_routes() + + async def list_personas(self): + """获取所有人格列表""" + try: + personas = await self.persona_mgr.get_all_personas() + return ( + Response() + .ok( + [ + { + "persona_id": persona.persona_id, + "system_prompt": persona.system_prompt, + "begin_dialogs": persona.begin_dialogs or [], + "tools": persona.tools, + "created_at": persona.created_at.isoformat() + if persona.created_at + else None, + "updated_at": persona.updated_at.isoformat() + if persona.updated_at + else None, + } + for persona in personas + ] + ) + .__dict__ + ) + except Exception as e: + logger.error(f"获取人格列表失败: {str(e)}\n{traceback.format_exc()}") + return Response().error(f"获取人格列表失败: {str(e)}").__dict__ + + async def get_persona_detail(self): + """获取指定人格的详细信息""" + try: + data = await request.get_json() + persona_id = data.get("persona_id") + + if not persona_id: + return Response().error("缺少必要参数: persona_id").__dict__ + + persona = await self.persona_mgr.get_persona(persona_id) + if not persona: + return Response().error("人格不存在").__dict__ + + return ( + Response() + .ok( + { + "persona_id": persona.persona_id, + "system_prompt": persona.system_prompt, + "begin_dialogs": persona.begin_dialogs or [], + "tools": persona.tools, + "created_at": persona.created_at.isoformat() + if persona.created_at + else None, + "updated_at": persona.updated_at.isoformat() + if persona.updated_at + else None, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(f"获取人格详情失败: {str(e)}\n{traceback.format_exc()}") + return Response().error(f"获取人格详情失败: {str(e)}").__dict__ + + async def create_persona(self): + """创建新人格""" + try: + data = await request.get_json() + persona_id = data.get("persona_id", "").strip() + system_prompt = data.get("system_prompt", "").strip() + begin_dialogs = data.get("begin_dialogs", []) + tools = data.get("tools") + + if not persona_id: + return Response().error("人格ID不能为空").__dict__ + + if not system_prompt: + return Response().error("系统提示词不能为空").__dict__ + + # 验证 begin_dialogs 格式 + if begin_dialogs and len(begin_dialogs) % 2 != 0: + return ( + Response() + .error("预设对话数量必须为偶数(用户和助手轮流对话)") + .__dict__ + ) + + persona = await self.persona_mgr.create_persona( + persona_id=persona_id, + system_prompt=system_prompt, + begin_dialogs=begin_dialogs if begin_dialogs else None, + tools=tools if tools else None, + ) + + return ( + Response() + .ok( + { + "message": "人格创建成功", + "persona": { + "persona_id": persona.persona_id, + "system_prompt": persona.system_prompt, + "begin_dialogs": persona.begin_dialogs or [], + "tools": persona.tools or [], + "created_at": persona.created_at.isoformat() + if persona.created_at + else None, + "updated_at": persona.updated_at.isoformat() + if persona.updated_at + else None, + }, + } + ) + .__dict__ + ) + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"创建人格失败: {str(e)}\n{traceback.format_exc()}") + return Response().error(f"创建人格失败: {str(e)}").__dict__ + + async def update_persona(self): + """更新人格信息""" + try: + data = await request.get_json() + persona_id = data.get("persona_id") + system_prompt = data.get("system_prompt") + begin_dialogs = data.get("begin_dialogs") + tools = data.get("tools") + + if not persona_id: + return Response().error("缺少必要参数: persona_id").__dict__ + + # 验证 begin_dialogs 格式 + if begin_dialogs is not None and len(begin_dialogs) % 2 != 0: + return ( + Response() + .error("预设对话数量必须为偶数(用户和助手轮流对话)") + .__dict__ + ) + + await self.persona_mgr.update_persona( + persona_id=persona_id, + system_prompt=system_prompt, + begin_dialogs=begin_dialogs, + tools=tools, + ) + + return Response().ok({"message": "人格更新成功"}).__dict__ + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"更新人格失败: {str(e)}\n{traceback.format_exc()}") + return Response().error(f"更新人格失败: {str(e)}").__dict__ + + async def delete_persona(self): + """删除人格""" + try: + data = await request.get_json() + persona_id = data.get("persona_id") + + if not persona_id: + return Response().error("缺少必要参数: persona_id").__dict__ + + await self.persona_mgr.delete_persona(persona_id) + + return Response().ok({"message": "人格删除成功"}).__dict__ + except ValueError as e: + return Response().error(str(e)).__dict__ + except Exception as e: + logger.error(f"删除人格失败: {str(e)}\n{traceback.format_exc()}") + return Response().error(f"删除人格失败: {str(e)}").__dict__ diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 5dad2576b..324fc62d3 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -8,6 +8,7 @@ from quart import request from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.star import star_map from .route import Response, Route, RouteContext @@ -27,6 +28,8 @@ class ToolsRoute(Route): "/tools/mcp/delete": ("POST", self.delete_mcp_server), "/tools/mcp/market": ("GET", self.get_mcp_markets), "/tools/mcp/test": ("POST", self.test_mcp_connection), + "/tools/list": ("GET", self.get_tool_list), + "/tools/toggle-tool": ("POST", self.toggle_tool), } self.register_routes() self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools @@ -336,3 +339,40 @@ class ToolsRoute(Route): except Exception as e: logger.error(traceback.format_exc()) return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__ + + async def get_tool_list(self): + """获取所有注册的工具列表""" + try: + tools = self.tool_mgr.func_list + tools_dict = [tool.__dict__() for tool in tools] + return Response().ok(data=tools_dict).__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"获取工具列表失败: {str(e)}").__dict__ + + async def toggle_tool(self): + """启用或停用指定的工具""" + try: + data = await request.json + tool_name = data.get("name") + action = data.get("activate") # True or False + + if not tool_name or action is None: + return Response().error("缺少必要参数: name 或 action").__dict__ + + if action: + try: + ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map) + except ValueError as e: + return Response().error(f"启用工具失败: {str(e)}").__dict__ + else: + ok = self.tool_mgr.deactivate_llm_tool(tool_name) + + if ok: + return Response().ok(None, "操作成功。").__dict__ + else: + return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__ + + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"操作工具失败: {str(e)}").__dict__ diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 06f6f8e60..e22b20524 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -60,6 +60,9 @@ class AstrBotDashboard: self.session_management_route = SessionManagementRoute( self.context, db, core_lifecycle ) + self.persona_route = PersonaRoute( + self.context, db, core_lifecycle + ) self.app.add_url_rule( "/api/plug/", diff --git a/dashboard/src/i18n/loader.ts b/dashboard/src/i18n/loader.ts index 275110ddc..ec76b2c0c 100644 --- a/dashboard/src/i18n/loader.ts +++ b/dashboard/src/i18n/loader.ts @@ -52,6 +52,7 @@ export class I18nLoader { { name: 'features/alkaid/index', path: 'features/alkaid/index.json' }, { name: 'features/alkaid/knowledge-base', path: 'features/alkaid/knowledge-base.json' }, { name: 'features/alkaid/memory', path: 'features/alkaid/memory.json' }, + { name: 'features/persona', path: 'features/persona.json' }, // 消息模块 { name: 'messages/errors', path: 'messages/errors.json' }, diff --git a/dashboard/src/i18n/locales/en-US/core/navigation.json b/dashboard/src/i18n/locales/en-US/core/navigation.json index 501383020..23276163a 100644 --- a/dashboard/src/i18n/locales/en-US/core/navigation.json +++ b/dashboard/src/i18n/locales/en-US/core/navigation.json @@ -2,6 +2,7 @@ "dashboard": "Dashboard", "platforms": "Platforms", "providers": "Providers", + "persona": "Persona", "toolUse": "MCP Tools", "config": "Config", "extension": "Extensions", diff --git a/dashboard/src/i18n/locales/en-US/features/persona.json b/dashboard/src/i18n/locales/en-US/features/persona.json new file mode 100644 index 000000000..94708ee56 --- /dev/null +++ b/dashboard/src/i18n/locales/en-US/features/persona.json @@ -0,0 +1,67 @@ +{ + "page": { + "description": "Manage and configure chat bot personality settings" + }, + "buttons": { + "create": "Create Persona", + "createFirst": "Create First Persona", + "edit": "Edit", + "delete": "Delete", + "cancel": "Cancel", + "save": "Save", + "addDialogPair": "Add Dialog Pair" + }, + "labels": { + "presetDialogs": "Preset Dialogs ({count} pairs)", + "createdAt": "Created At", + "updatedAt": "Updated At" + }, + "form": { + "personaId": "Persona ID", + "systemPrompt": "System Prompt", + "presetDialogs": "Preset Dialogs", + "presetDialogsHelp": "Add some preset dialogs to help the bot better understand the role settings. The number of dialogs must be even (users and assistants take turns).", + "userMessage": "User Message", + "assistantMessage": "Assistant Message", + "tools": "Tool Selection", + "toolsHelp": "Select available tools for this persona. Tools allow the bot to perform specific functions such as searching, calculating, getting information, etc.", + "toolsSelection": "Tool Selection Actions", + "selectAllTools": "Select All Tools", + "clearAllTools": "Clear Selection", + "allSelected": "All Selected", + "mcpServersQuickSelect": "MCP Servers Quick Select", + "searchTools": "Search Tools", + "selectedTools": "Selected Tools", + "noToolsAvailable": "No tools available", + "noToolsFound": "No matching tools found", + "loadingTools": "Loading tools...", + "allToolsAvailable": "Use all available tools", + "noToolsSelected": "No tools selected" + }, + "dialog": { + "create": { + "title": "Create New Persona" + }, + "edit": { + "title": "Edit Persona" + } + }, + "empty": { + "title": "No Persona Configured", + "description": "Create your first persona to start using personalized chatbots" + }, + "validation": { + "required": "This field is required", + "minLength": "Minimum {min} characters required", + "alphanumeric": "Only letters, numbers, underscores and hyphens are allowed", + "dialogRequired": "{type} cannot be empty" + }, + "messages": { + "loadError": "Failed to load persona list", + "saveSuccess": "Saved successfully", + "saveError": "Save failed", + "deleteConfirm": "Are you sure you want to delete persona \"{id}\"? This action cannot be undone.", + "deleteSuccess": "Deleted successfully", + "deleteError": "Delete failed" + } +} diff --git a/dashboard/src/i18n/locales/en-US/features/tool-use.json b/dashboard/src/i18n/locales/en-US/features/tool-use.json index 96c4760e8..1c4c54e97 100644 --- a/dashboard/src/i18n/locales/en-US/features/tool-use.json +++ b/dashboard/src/i18n/locales/en-US/features/tool-use.json @@ -107,6 +107,9 @@ "failed": "Import configuration failed: {error}" }, "configParseError": "Configuration parse error: {error}", - "noAvailableConfig": "No available configuration" + "noAvailableConfig": "No available configuration", + "toggleToolSuccess": "Tool status toggled successfully!", + "toggleToolError": "Failed to toggle tool status: {error}", + "testError": "Test connection failed: {error}" } } \ No newline at end of file diff --git a/dashboard/src/i18n/locales/zh-CN/core/navigation.json b/dashboard/src/i18n/locales/zh-CN/core/navigation.json index d0ed8453a..ccb4be5a1 100644 --- a/dashboard/src/i18n/locales/zh-CN/core/navigation.json +++ b/dashboard/src/i18n/locales/zh-CN/core/navigation.json @@ -2,6 +2,7 @@ "dashboard": "统计", "platforms": "消息平台", "providers": "服务提供商", + "persona": "人格管理", "toolUse": "MCP", "config": "配置文件", "extension": "插件管理", diff --git a/dashboard/src/i18n/locales/zh-CN/features/persona.json b/dashboard/src/i18n/locales/zh-CN/features/persona.json new file mode 100644 index 000000000..619f1b607 --- /dev/null +++ b/dashboard/src/i18n/locales/zh-CN/features/persona.json @@ -0,0 +1,67 @@ +{ + "page": { + "description": "管理和配置聊天机器人的人格角色设定" + }, + "buttons": { + "create": "创建人格", + "createFirst": "创建第一个人格", + "edit": "编辑", + "delete": "删除", + "cancel": "取消", + "save": "保存", + "addDialogPair": "添加对话对" + }, + "labels": { + "presetDialogs": "预设对话 ({count} 对)", + "createdAt": "创建时间", + "updatedAt": "更新时间" + }, + "form": { + "personaId": "人格 ID", + "systemPrompt": "系统提示词", + "presetDialogs": "预设对话", + "presetDialogsHelp": "添加一些预设的对话来帮助机器人更好地理解角色设定。对话数量必须为偶数(用户和助手轮流对话)。", + "userMessage": "用户消息", + "assistantMessage": "助手消息", + "tools": "工具选择", + "toolsHelp": "为这个人格选择可用的工具。工具可以让机器人执行特定的功能,如搜索、计算、获取信息等。", + "toolsSelection": "工具选择操作", + "selectAllTools": "选择所有工具", + "clearAllTools": "清空选择", + "allSelected": "全选", + "mcpServersQuickSelect": "MCP 服务器快速选择", + "searchTools": "搜索工具", + "selectedTools": "已选择的工具", + "noToolsAvailable": "暂无可用工具", + "noToolsFound": "未找到匹配的工具", + "loadingTools": "正在加载工具...", + "allToolsAvailable": "使用所有可用工具", + "noToolsSelected": "未选择任何工具" + }, + "dialog": { + "create": { + "title": "创建新人格" + }, + "edit": { + "title": "编辑人格" + } + }, + "empty": { + "title": "暂无人格配置", + "description": "创建您的第一个人格来开始使用个性化的聊天机器人" + }, + "validation": { + "required": "此字段为必填项", + "minLength": "最少需要 {min} 个字符", + "alphanumeric": "只能包含字母、数字、下划线和连字符", + "dialogRequired": "{type}不能为空" + }, + "messages": { + "loadError": "加载人格列表失败", + "saveSuccess": "保存成功", + "saveError": "保存失败", + "deleteConfirm": "确定要删除人格 \"{id}\" 吗?此操作不可撤销。", + "deleteSuccess": "删除成功", + "deleteError": "删除失败" + } +} diff --git a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json index 61b8691bc..663488497 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json +++ b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json @@ -107,6 +107,9 @@ "failed": "导入配置失败: {error}" }, "configParseError": "配置解析错误: {error}", - "noAvailableConfig": "无可用配置" + "noAvailableConfig": "无可用配置", + "toggleToolSuccess": "工具状态切换成功!", + "toggleToolError": "工具状态切换失败: {error}", + "testError": "测试连接失败: {error}" } -} \ No newline at end of file +} \ No newline at end of file diff --git a/dashboard/src/i18n/translations.ts b/dashboard/src/i18n/translations.ts index c46ca4d0d..f5ea8c97e 100644 --- a/dashboard/src/i18n/translations.ts +++ b/dashboard/src/i18n/translations.ts @@ -25,6 +25,7 @@ import zhCNDashboard from './locales/zh-CN/features/dashboard.json'; import zhCNAlkaidIndex from './locales/zh-CN/features/alkaid/index.json'; import zhCNAlkaidKnowledgeBase from './locales/zh-CN/features/alkaid/knowledge-base.json'; import zhCNAlkaidMemory from './locales/zh-CN/features/alkaid/memory.json'; +import zhCNPersona from './locales/zh-CN/features/persona.json'; import zhCNErrors from './locales/zh-CN/messages/errors.json'; import zhCNSuccess from './locales/zh-CN/messages/success.json'; @@ -54,6 +55,7 @@ import enUSDashboard from './locales/en-US/features/dashboard.json'; import enUSAlkaidIndex from './locales/en-US/features/alkaid/index.json'; import enUSAlkaidKnowledgeBase from './locales/en-US/features/alkaid/knowledge-base.json'; import enUSAlkaidMemory from './locales/en-US/features/alkaid/memory.json'; +import enUSPersona from './locales/en-US/features/persona.json'; import enUSErrors from './locales/en-US/messages/errors.json'; import enUSSuccess from './locales/en-US/messages/success.json'; @@ -88,7 +90,8 @@ export const translations = { index: zhCNAlkaidIndex, 'knowledge-base': zhCNAlkaidKnowledgeBase, memory: zhCNAlkaidMemory - } + }, + persona: zhCNPersona }, messages: { errors: zhCNErrors, @@ -123,7 +126,8 @@ export const translations = { index: enUSAlkaidIndex, 'knowledge-base': enUSAlkaidKnowledgeBase, memory: enUSAlkaidMemory - } + }, + persona: enUSPersona }, messages: { errors: enUSErrors, diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts index 43f062d1c..09f0e33e0 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts +++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts @@ -63,9 +63,13 @@ const sidebarItem: menu[] = [ icon: 'mdi-account-group', to: '/session-management' }, + { + title: 'core.navigation.persona', + icon: 'mdi-heart', + to: '/persona' + }, { title: 'core.navigation.console', - icon: 'mdi-console', to: '/console' }, diff --git a/dashboard/src/router/MainRoutes.ts b/dashboard/src/router/MainRoutes.ts index 8beff1f26..9706f2b73 100644 --- a/dashboard/src/router/MainRoutes.ts +++ b/dashboard/src/router/MainRoutes.ts @@ -56,6 +56,11 @@ const MainRoutes = { path: '/session-management', component: () => import('@/views/SessionManagementPage.vue') }, + { + name: 'Persona', + path: '/persona', + component: () => import('@/views/PersonaPage.vue') + }, { name: 'Console', path: '/console', diff --git a/dashboard/src/views/PersonaPage.vue b/dashboard/src/views/PersonaPage.vue new file mode 100644 index 000000000..3094a5b09 --- /dev/null +++ b/dashboard/src/views/PersonaPage.vue @@ -0,0 +1,808 @@ + + + + + diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue index 6060088a4..c89634153 100644 --- a/dashboard/src/views/ToolUsePage.vue +++ b/dashboard/src/views/ToolUsePage.vue @@ -405,24 +405,36 @@ + 复选框代表该工具是否被启用。 + + + +
- {{ tool.function.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }} + {{ tool.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }} - {{ formatToolName(tool.function.name) }} + :title="tool.name"> + {{ formatToolName(tool.name) }}
- - {{ tool.function.description }} + + {{ tool.description }}
@@ -434,9 +446,9 @@ mdi-information {{ tm('functionTools.description') }}

-

{{ tool.function.description }}

+

{{ tool.description }}

-