diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 41225e0ee..960ef107a 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -115,8 +115,7 @@ DEFAULT_CONFIG = {
"log_level": "INFO",
"pip_install_arg": "",
"pypi_index_url": "https://mirrors.aliyun.com/pypi/simple/",
- "knowledge_db": {},
- "persona": [],
+ "persona": [], # deprecated
"timezone": "",
"callback_api_base": "",
}
@@ -1701,43 +1700,6 @@ CONFIG_METADATA_2 = {
},
},
},
- "persona": {
- "description": "人格情景设置",
- "type": "list",
- "config_template": {
- "新人格情景": {
- "name": "",
- "prompt": "",
- "begin_dialogs": [],
- "mood_imitation_dialogs": [],
- }
- },
- "tmpl_display_title": "name",
- "items": {
- "name": {
- "description": "人格名称",
- "type": "string",
- "hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
- },
- "prompt": {
- "description": "设定(系统提示词)",
- "type": "text",
- "hint": "填写人格的身份背景、性格特征、兴趣爱好、个人经历、口头禅等。",
- },
- "begin_dialogs": {
- "description": "预设对话",
- "type": "list",
- "items": {"type": "string"},
- "hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
- },
- "mood_imitation_dialogs": {
- "description": "对话风格模仿",
- "type": "list",
- "items": {"type": "string"},
- "hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
- },
- },
- },
"provider_stt_settings": {
"description": "语音转文本(STT)",
"type": "object",
diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py
index 025bf03e5..28545a238 100644
--- a/astrbot/core/core_lifecycle.py
+++ b/astrbot/core/core_lifecycle.py
@@ -22,6 +22,7 @@ from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.star.context import Context
+from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
@@ -69,13 +70,23 @@ class AstrBotCoreLifecycle:
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
await self.db.initialize()
- await do_migration_v4(self.db, {})
+
+ try:
+ await do_migration_v4(self.db, {}, self.astrbot_config)
+ except Exception as e:
+ logger.error(f"迁移到 v4.0.0 新版本数据格式失败: {e}")
# 初始化事件队列
self.event_queue = Queue()
+ # 初始化人格管理器
+ self.persona_mgr = PersonaManager(self.db, self.astrbot_config)
+ await self.persona_mgr.initialize()
+
# 初始化供应商管理器
- self.provider_manager = ProviderManager(self.astrbot_config, self.db)
+ self.provider_manager = ProviderManager(
+ self.astrbot_config, self.db, self.persona_mgr
+ )
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
@@ -94,6 +105,8 @@ class AstrBotCoreLifecycle:
self.provider_manager,
self.platform_manager,
self.conversation_manager,
+ self.platform_message_history_manager,
+ self.persona_mgr,
)
# 初始化插件管理器
@@ -110,6 +123,7 @@ class AstrBotCoreLifecycle:
PipelineContext(self.astrbot_config, self.plugin_manager)
)
await self.pipeline_scheduler.initialize()
+ self.star_context.pipeline_ctx = self.pipeline_scheduler.ctx
# 初始化更新器
self.astrbot_updator = AstrBotUpdator()
@@ -232,6 +246,9 @@ class AstrBotCoreLifecycle:
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
tasks.append(
- asyncio.create_task(platform_inst.run(), name=f"{platform_inst.meta().id}({platform_inst.meta().name})")
+ asyncio.create_task(
+ platform_inst.run(),
+ name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
+ )
)
return tasks
diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py
index 53fccacfc..00c42505f 100644
--- a/astrbot/core/db/__init__.py
+++ b/astrbot/core/db/__init__.py
@@ -205,6 +205,7 @@ class BaseDatabase(abc.ABC):
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
+ tools: list[str] = None,
) -> Persona:
"""Insert a new persona record."""
...
@@ -219,6 +220,22 @@ class BaseDatabase(abc.ABC):
"""Get all personas for a specific bot."""
...
+ @abc.abstractmethod
+ async def update_persona(
+ self,
+ persona_id: str,
+ system_prompt: str = None,
+ begin_dialogs: list[str] = None,
+ tools: list[str] = None,
+ ) -> Persona | None:
+ """Update a persona's system prompt or begin dialogs."""
+ ...
+
+ @abc.abstractmethod
+ async def delete_persona(self, persona_id: str) -> None:
+ """Delete a persona by its ID."""
+ ...
+
@abc.abstractmethod
async def insert_preference_or_update(self, key: str, value: str) -> Preference:
"""Insert a new preference record."""
diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py
index d4b9c99f9..4eb428153 100644
--- a/astrbot/core/db/migration/helper.py
+++ b/astrbot/core/db/migration/helper.py
@@ -1,11 +1,13 @@
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
+from astrbot.core.config import AstrBotConfig
from astrbot.api import logger
from .migra_3_to_4 import (
migration_conversation_table,
migration_platform_table,
migration_webchat_data,
+ migration_persona_data,
)
@@ -24,7 +26,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
async def do_migration_v4(
- db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
+ db_helper: BaseDatabase,
+ platform_id_map: dict[str, dict[str, str]],
+ astrbot_config: AstrBotConfig,
):
"""
执行数据库迁移
@@ -45,6 +49,9 @@ async def do_migration_v4(
# 执行 WebChat 数据迁移
await migration_webchat_data(db_helper, platform_id_map)
+ # 执行人格数据迁移
+ await migration_persona_data(db_helper, astrbot_config)
+
# 标记迁移完成
await db_helper.insert_preference_or_update("migration_done_v4", "true")
diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py
index b6de3214e..e3eccd292 100644
--- a/astrbot/core/db/migration/migra_3_to_4.py
+++ b/astrbot/core/db/migration/migra_3_to_4.py
@@ -4,12 +4,10 @@ from .. import BaseDatabase
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from astrbot.core.config.default import DB_PATH
from astrbot.api import logger
+from astrbot.core.config import AstrBotConfig
from astrbot.core.platform.astr_message_event import MessageSesion
from sqlalchemy.ext.asyncio import AsyncSession
-from astrbot.core.db.po import (
- ConversationV2,
- PlatformMessageHistory,
-)
+from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from sqlalchemy import text
"""
@@ -50,20 +48,21 @@ async def migration_conversation_table(
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
- for conversation in conversations:
+ for idx, conversation in enumerate(conversations):
+ if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
+ progress = int((idx + 1) / total_cnt * 100)
+ if progress % 10 == 0:
+ logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
- logger.warning(
+ logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" not in conv.user_id:
- logger.warning(
- f"跳过 user_id 为 {conv.user_id} 的会话,它可能是 WebChat 的消息历史记录。"
- )
continue
session = MessageSesion.from_str(session_str=conv.user_id)
platform_id = get_platform_id(
@@ -81,13 +80,12 @@ async def migration_conversation_table(
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
)
dbsession.add(conv_v2)
- if conv_v2:
- logger.info(f"迁移旧会话 {conv.cid} 到新表成功。")
except Exception as e:
logger.error(
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
exc_info=True,
)
+ logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
async def migration_platform_table(
@@ -107,7 +105,7 @@ async def migration_platform_table(
platform_stats_v3 = stats.platform
if not platform_stats_v3:
- logger.warning("没有找到旧平台数据,跳过迁移。")
+ logger.info("没有找到旧平台数据,跳过迁移。")
return
first_time_stamp = platform_stats_v3[0].timestamp
@@ -120,7 +118,13 @@ async def migration_platform_table(
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
- for bucket_end in range(start_time, end_time, 3600):
+ total_buckets = (end_time - start_time) // 3600
+ for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
+ if bucket_idx % 500 == 0:
+ progress = int((bucket_idx + 1) / total_buckets * 100)
+ logger.info(
+ f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})"
+ )
cnt = 0
while (
idx < len(platform_stats_v3)
@@ -136,9 +140,6 @@ async def migration_platform_table(
platform_type = get_platform_type(
platform_id_map, platform_stats_v3[idx].name
)
- logger.info(
- f"迁移平台统计数据: {platform_id}, {platform_type}, 时间戳: {bucket_end}, 计数: {cnt}"
- )
try:
await dbsession.execute(
text("""
@@ -161,6 +162,7 @@ async def migration_platform_table(
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
exc_info=True,
)
+ logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
async def migration_webchat_data(
@@ -178,20 +180,21 @@ async def migration_webchat_data(
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
- for conversation in conversations:
+ for idx, conversation in enumerate(conversations):
+ if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
+ progress = int((idx + 1) / total_cnt * 100)
+ if progress % 10 == 0:
+ logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
- logger.warning(
+ logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" in conv.user_id:
- logger.warning(
- f"跳过 user_id 为 {conv.user_id} 的会话,它不是 WebChat 的消息历史记录。"
- )
continue
platform_id = "webchat"
history = json.loads(conv.history) if conv.history else []
@@ -206,9 +209,52 @@ async def migration_webchat_data(
)
dbsession.add(new_history)
- logger.info(f"迁移旧 WebChat 会话 {conv.cid} 到新表成功。")
except Exception:
logger.error(
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
exc_info=True,
)
+
+ logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
+
+
+async def migration_persona_data(
+ db_helper: BaseDatabase, astrbot_config: AstrBotConfig
+):
+ """
+ 迁移 Persona 数据到新的表中。
+ 旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
+ """
+ v3_persona_config: list[dict] = astrbot_config.get("persona", [])
+ total_personas = len(v3_persona_config)
+ logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
+
+ for idx, persona in enumerate(v3_persona_config):
+ if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
+ progress = int((idx + 1) / total_personas * 100)
+ if progress % 10 == 0:
+ logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
+ try:
+ begin_dialogs = persona.get("begin_dialogs", [])
+ mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
+ mood_prompt = ""
+ user_turn = True
+ for mood_dialog in mood_imitation_dialogs:
+ if user_turn:
+ mood_prompt += f"A: {mood_dialog}\n"
+ else:
+ mood_prompt += f"B: {mood_dialog}\n"
+ user_turn = not user_turn
+ system_prompt = persona.get("prompt", "")
+ if mood_prompt:
+ system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
+ persona_new = await db_helper.insert_persona(
+ persona_id=persona["name"],
+ system_prompt=system_prompt,
+ begin_dialogs=begin_dialogs,
+ )
+ logger.info(
+ f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
+ )
+ except Exception as e:
+ logger.error(f"解析 Persona 配置失败:{e}")
diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py
index 30f7188a1..cbc8c8d4e 100644
--- a/astrbot/core/db/po.py
+++ b/astrbot/core/db/po.py
@@ -9,7 +9,7 @@ from sqlmodel import (
UniqueConstraint,
Field,
)
-from typing import Optional
+from typing import Optional, TypedDict
class PlatformStat(SQLModel, table=True):
@@ -39,12 +39,14 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations"
- inner_conversation_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
+ inner_conversation_id: int = Field(
+ primary_key=True, sa_column_kwargs={"autoincrement": True}
+ )
conversation_id: str = Field(
max_length=36,
nullable=False,
unique=True,
- default_factory=lambda: str(uuid.uuid4())
+ default_factory=lambda: str(uuid.uuid4()),
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
@@ -78,6 +80,8 @@ class Persona(SQLModel, table=True):
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
"""a list of strings, each representing a dialog to start with"""
+ tools: Optional[list] = Field(default=None, sa_type=JSON)
+ """None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
@@ -119,7 +123,9 @@ class PlatformMessageHistory(SQLModel, table=True):
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
- sender_name: Optional[str] = Field(default=None) # Name of the sender in the platform
+ sender_name: Optional[str] = Field(
+ default=None
+ ) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
@@ -136,12 +142,14 @@ class Attachment(SQLModel, table=True):
__tablename__ = "attachments"
- inner_attachment_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
+ inner_attachment_id: int = Field(
+ primary_key=True, sa_column_kwargs={"autoincrement": True}
+ )
attachment_id: str = Field(
max_length=36,
nullable=False,
unique=True,
- default_factory=lambda: str(uuid.uuid4())
+ default_factory=lambda: str(uuid.uuid4()),
)
path: str = Field(nullable=False) # Path to the file on disk
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
@@ -182,6 +190,25 @@ class Conversation:
updated_at: int = 0
+class Personality(TypedDict):
+ """LLM 人格类。
+
+ 在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
+ """
+
+ prompt: str = ""
+ name: str = ""
+ begin_dialogs: list[str] = []
+ mood_imitation_dialogs: list[str] = []
+ """情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
+ tools: list[str] | None = None
+ """工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
+
+ # cache
+ _begin_dialogs_processed: list[dict] = []
+ _mood_imitation_dialogs_processed: str = ""
+
+
# ====
# Deprecated, and will be removed in future versions.
# ====
diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py
index 0308807da..0ecf787e5 100644
--- a/astrbot/core/db/sqlite.py
+++ b/astrbot/core/db/sqlite.py
@@ -19,6 +19,7 @@ from sqlalchemy import select, update, delete, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func
+NOT_GIVEN = T.TypeVar("NOT_GIVEN")
class SQLiteDatabase(BaseDatabase):
def __init__(self, db_path: str) -> None:
@@ -99,9 +100,7 @@ class SQLiteDatabase(BaseDatabase):
# Conversation Management
# ====
- async def get_conversations(
- self, user_id=None, platform_id=None
- ):
+ async def get_conversations(self, user_id=None, platform_id=None):
async with self.get_db() as session:
session: AsyncSession
query = select(ConversationV2)
@@ -224,7 +223,7 @@ class SQLiteDatabase(BaseDatabase):
return
query = query.values(**values)
await session.execute(query)
- return await self.get_conversation_by_id(cid)
+ return await self.get_conversation_by_id(cid)
async def delete_conversation(self, cid):
async with self.get_db() as session:
@@ -312,7 +311,9 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalar_one_or_none()
- async def insert_persona(self, persona_id, system_prompt, begin_dialogs=None):
+ async def insert_persona(
+ self, persona_id, system_prompt, begin_dialogs=None, tools=None
+ ):
"""Insert a new persona record."""
async with self.get_db() as session:
session: AsyncSession
@@ -321,6 +322,7 @@ class SQLiteDatabase(BaseDatabase):
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs or [],
+ tools=tools,
)
session.add(new_persona)
return new_persona
@@ -341,6 +343,36 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalars().all()
+ async def update_persona(
+ self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN
+ ):
+ """Update a persona's system prompt or begin dialogs."""
+ async with self.get_db() as session:
+ session: AsyncSession
+ async with session.begin():
+ query = update(Persona).where(Persona.persona_id == persona_id)
+ values = {}
+ if system_prompt is not None:
+ values["system_prompt"] = system_prompt
+ if begin_dialogs is not None:
+ values["begin_dialogs"] = begin_dialogs
+ if tools is not NOT_GIVEN:
+ values["tools"] = tools
+ if not values:
+ return
+ query = query.values(**values)
+ await session.execute(query)
+ return await self.get_persona_by_id(persona_id)
+
+ async def delete_persona(self, persona_id):
+ """Delete a persona by its ID."""
+ async with self.get_db() as session:
+ session: AsyncSession
+ async with session.begin():
+ await session.execute(
+ delete(Persona).where(Persona.persona_id == persona_id)
+ )
+
async def insert_preference_or_update(self, key, value):
"""Insert a new preference record or update if it exists."""
async with self.get_db() as session:
diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py
new file mode 100644
index 000000000..dc322c789
--- /dev/null
+++ b/astrbot/core/persona_mgr.py
@@ -0,0 +1,162 @@
+from astrbot.core.db import BaseDatabase
+from astrbot.core.db.po import Persona, Personality
+from astrbot.core.config import AstrBotConfig
+from astrbot import logger
+
+
+class PersonaManager:
+ def __init__(self, db_helper: BaseDatabase, astrbot_config: AstrBotConfig):
+ self.db = db_helper
+ self.config = astrbot_config
+ _ps: dict = astrbot_config["provider_settings"]
+ self.default_persona: str = _ps.get("default_personality", "default")
+ self.personas: list[Persona] = []
+ self.selected_default_persona: Persona | None = None
+
+ self.personas_v3: list[Personality] = []
+ self.selected_default_persona_v3: Personality | None = None
+ self.persona_v3_config: list[dict] = []
+
+ async def initialize(self):
+ self.personas = await self.get_all_personas()
+ self.get_v3_persona_data()
+ logger.info(f"已加载 {len(self.personas)} 个人格。")
+
+ async def get_persona(self, persona_id: str):
+ """获取指定 persona 的信息"""
+ persona = await self.db.get_persona_by_id(persona_id)
+ if not persona:
+ raise ValueError(f"Persona with ID {persona_id} does not exist.")
+ return persona
+
+ async def delete_persona(self, persona_id: str):
+ """删除指定 persona"""
+ if not await self.db.get_persona_by_id(persona_id):
+ raise ValueError(f"Persona with ID {persona_id} does not exist.")
+ await self.db.delete_persona(persona_id)
+ self.personas = [p for p in self.personas if p.persona_id != persona_id]
+ self.get_v3_persona_data()
+
+ async def update_persona(
+ self,
+ persona_id: str,
+ system_prompt: str = None,
+ begin_dialogs: list[str] = None,
+ tools: list[str] = None,
+ ):
+ """更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
+ existing_persona = await self.db.get_persona_by_id(persona_id)
+ if not existing_persona:
+ raise ValueError(f"Persona with ID {persona_id} does not exist.")
+ persona = await self.db.update_persona(
+ persona_id, system_prompt, begin_dialogs, tools=tools
+ )
+ if persona:
+ for i, p in enumerate(self.personas):
+ if p.persona_id == persona_id:
+ self.personas[i] = persona
+ break
+ self.get_v3_persona_data()
+ return persona
+
+ async def get_all_personas(self) -> list[Persona]:
+ """获取所有 personas"""
+ return await self.db.get_personas()
+
+ async def create_persona(
+ self,
+ persona_id: str,
+ system_prompt: str,
+ begin_dialogs: list[str] = None,
+ tools: list[str] = None,
+ ) -> Persona:
+ """创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
+ if await self.db.get_persona_by_id(persona_id):
+ raise ValueError(f"Persona with ID {persona_id} already exists.")
+ new_persona = await self.db.insert_persona(
+ persona_id, system_prompt, begin_dialogs, tools=tools
+ )
+ self.personas.append(new_persona)
+ self.get_v3_persona_data()
+ return new_persona
+
+ def get_v3_persona_data(
+ self,
+ ) -> tuple[list[dict], list[Personality], Personality]:
+ """获取 AstrBot <4.0.0 版本的 persona 数据。
+
+ Returns:
+ - list[dict]: 包含 persona 配置的字典列表。
+ - list[Personality]: 包含 Personality 对象的列表。
+ - Personality: 默认选择的 Personality 对象。
+ """
+ v3_persona_config = [
+ {
+ "prompt": persona.system_prompt,
+ "name": persona.persona_id,
+ "begin_dialogs": persona.begin_dialogs or [],
+ "mood_imitation_dialogs": [], # deprecated
+ "tools": persona.tools,
+ }
+ for persona in self.personas
+ ]
+
+ personas_v3: list[Personality] = []
+ selected_default_persona: Personality | None = None
+
+ for persona_cfg in v3_persona_config:
+ begin_dialogs = persona_cfg.get("begin_dialogs", [])
+ bd_processed = []
+ if begin_dialogs:
+ if len(begin_dialogs) % 2 != 0:
+ logger.error(
+ f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
+ )
+ begin_dialogs = []
+ user_turn = True
+ for dialog in begin_dialogs:
+ bd_processed.append(
+ {
+ "role": "user" if user_turn else "assistant",
+ "content": dialog,
+ "_no_save": None, # 不持久化到 db
+ }
+ )
+ user_turn = not user_turn
+
+ try:
+ persona = Personality(
+ **persona_cfg,
+ _begin_dialogs_processed=bd_processed,
+ _mood_imitation_dialogs_processed="", # deprecated
+ )
+ if persona["name"] == self.default_persona:
+ selected_default_persona = persona
+ personas_v3.append(persona)
+ except Exception as e:
+ logger.error(f"解析 Persona 配置失败:{e}")
+
+ if not selected_default_persona and len(personas_v3) > 0:
+ # 默认选择第一个
+ selected_default_persona = personas_v3[0]
+
+ if not selected_default_persona:
+ selected_default_persona = Personality(
+ prompt="You are a helpful and friendly assistant.",
+ name="default",
+ tools=None,
+ _begin_dialogs_processed=[],
+ )
+ personas_v3.append(selected_default_persona)
+
+ self.personas_v3 = personas_v3
+ self.selected_default_persona_v3 = selected_default_persona
+ self.persona_v3_config = v3_persona_config
+ self.selected_default_persona = Persona(
+ persona_id=selected_default_persona["name"],
+ system_prompt=selected_default_persona["prompt"],
+ begin_dialogs=selected_default_persona["begin_dialogs"],
+ tools=selected_default_persona["tools"] or None,
+ )
+
+ return v3_persona_config, personas_v3, selected_default_persona
diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py
index 0b9d9e533..932c5d5c2 100644
--- a/astrbot/core/pipeline/context.py
+++ b/astrbot/core/pipeline/context.py
@@ -55,7 +55,7 @@ class PipelineContext:
handler: T.Awaitable,
*args,
**kwargs,
- ) -> T.AsyncGenerator[None, None]:
+ ) -> T.AsyncGenerator[T.Any, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
@@ -77,10 +77,16 @@ class PipelineContext:
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
+ if self.plugin_manager:
+ context = self.plugin_manager.context
+ else:
+ raise ValueError(
+ "Cannot call handler without a valid context or plugin manager."
+ )
# 向下兼容
trace_ = traceback.format_exc()
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
- ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
+ ready_to_call = handler(event, context, *args, **kwargs)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
diff --git a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py
index c2961ded5..cd705275e 100644
--- a/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py
+++ b/astrbot/core/pipeline/process_stage/agent_runner/tool_loop_agent.py
@@ -21,6 +21,7 @@ from mcp.types import (
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
+ CallToolResult,
)
from astrbot.core.star.star_handler import EventType
from astrbot import logger
@@ -193,50 +194,25 @@ class ToolLoopAgent(BaseAgentRunner):
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
- if func_tool.origin == "mcp":
- logger.info(
- f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
- )
- client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
- res = await client.session.call_tool(func_tool.name, func_tool_args)
- if not res:
- continue
- if isinstance(res.content[0], TextContent):
- tool_call_result_blocks.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content=res.content[0].text,
- )
- )
- yield MessageChain().message(res.content[0].text)
- elif isinstance(res.content[0], ImageContent):
- tool_call_result_blocks.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content="返回了图片(已直接发送给用户)",
- )
- )
- yield MessageChain(type="tool_direct_result").base64_image(
- res.content[0].data
- )
- elif isinstance(res.content[0], EmbeddedResource):
- resource = res.content[0].resource
- if isinstance(resource, TextResourceContents):
+ logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
+ executor = func_tool.execute(
+ event=self.event,
+ pipeline_context=self.pipeline_ctx,
+ **func_tool_args,
+ )
+ async for resp in executor:
+ if isinstance(resp, CallToolResult):
+ res = resp
+ if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
- content=resource.text,
+ content=res.content[0].text,
)
)
- yield MessageChain().message(resource.text)
- elif (
- isinstance(resource, BlobResourceContents)
- and resource.mimeType
- and resource.mimeType.startswith("image/")
- ):
+ yield MessageChain().message(res.content[0].text)
+ elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
@@ -247,41 +223,54 @@ class ToolLoopAgent(BaseAgentRunner):
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
- else:
- tool_call_result_blocks.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content="返回的数据类型不受支持",
- )
- )
- yield MessageChain().message("返回的数据类型不受支持。")
- else:
- logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
- # 尝试调用工具函数
- wrapper = self.pipeline_ctx.call_handler(
- self.event, func_tool.handler, **func_tool_args
- )
- async for resp in wrapper:
- if resp is not None:
- # Tool 返回结果
- tool_call_result_blocks.append(
- ToolCallMessageSegment(
- role="tool",
- tool_call_id=func_tool_id,
- content=resp,
- )
- )
- yield MessageChain().message(resp)
- else:
- # Tool 直接请求发送消息给用户
- # 这里我们将直接结束 Agent Loop。
- self._transition_state(AgentState.DONE)
- if res := self.event.get_result():
- if res.chain:
- yield MessageChain(
- chain=res.chain, type="tool_direct_result"
+ elif isinstance(res.content[0], EmbeddedResource):
+ resource = res.content[0].resource
+ if isinstance(resource, TextResourceContents):
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content=resource.text,
)
+ )
+ yield MessageChain().message(resource.text)
+ elif (
+ isinstance(resource, BlobResourceContents)
+ and resource.mimeType
+ and resource.mimeType.startswith("image/")
+ ):
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content="返回了图片(已直接发送给用户)",
+ )
+ )
+ yield MessageChain(
+ type="tool_direct_result"
+ ).base64_image(res.content[0].data)
+ else:
+ tool_call_result_blocks.append(
+ ToolCallMessageSegment(
+ role="tool",
+ tool_call_id=func_tool_id,
+ content="返回的数据类型不受支持",
+ )
+ )
+ yield MessageChain().message("返回的数据类型不受支持。")
+ elif resp is None:
+ # Tool 直接请求发送消息给用户
+ # 这里我们将直接结束 Agent Loop。
+ self._transition_state(AgentState.DONE)
+ if res := self.event.get_result():
+ if res.chain:
+ yield MessageChain(
+ chain=res.chain, type="tool_direct_result"
+ )
+ else:
+ logger.warning(
+ f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
+ )
self.event.clear_result()
except Exception as e:
diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py
index ae44cf36d..505b18e8e 100644
--- a/astrbot/core/pipeline/process_stage/method/llm_request.py
+++ b/astrbot/core/pipeline/process_stage/method/llm_request.py
@@ -100,7 +100,8 @@ class LLMRequestSubStage(Stage):
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
- req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
+ # func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
+ # req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
@@ -274,7 +275,6 @@ class LLMRequestSubStage(Stage):
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
-
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
@@ -307,19 +307,10 @@ class LLMRequestSubStage(Stage):
if not title or "" in title:
return
await self.conv_manager.update_conversation_title(
- event.unified_msg_origin, title=title
+ unified_msg_origin=event.unified_msg_origin,
+ title=title,
+ conversation_id=req.conversation.cid,
)
- # 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
- # webchat adapter 中,session_id 的格式是 f"webchat!{username}!{cid}"
- # TODO: 优化 WebChat 适配器的对话管理
- if event.session_id:
- username, cid = event.session_id.split("!")[1:3]
- db_helper = self.ctx.plugin_manager.context._db
- db_helper.update_conversation_title(
- user_id=username,
- cid=cid,
- title=title,
- )
async def _save_to_history(
self,
diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py
index 2d120d7f6..d0ba920c1 100644
--- a/astrbot/core/provider/entities.py
+++ b/astrbot/core/provider/entities.py
@@ -5,7 +5,7 @@ from astrbot.core.utils.io import download_image_by_url
from astrbot import logger
from dataclasses import dataclass, field
from typing import List, Dict, Type
-from .func_tool_manager import FuncCall
+from .func_tool_manager import FunctionToolManager, ToolSet
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
@@ -97,7 +97,7 @@ class ProviderRequest:
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
- func_tool: FuncCall | None = None
+ func_tool: FunctionToolManager | ToolSet | None = None
"""可用的函数工具"""
contexts: list[dict] = field(default_factory=list)
"""上下文。格式与 openai 的上下文格式一致:
diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py
index 07a0fbd8f..117a03800 100644
--- a/astrbot/core/provider/func_tool_manager.py
+++ b/astrbot/core/provider/func_tool_manager.py
@@ -1,16 +1,17 @@
from __future__ import annotations
import json
-import textwrap
import os
import asyncio
import logging
from datetime import timedelta
+from deprecated import deprecated
-from typing import Dict, List, Awaitable, Literal, Any
+from typing import Dict, List, Awaitable, Literal, Any, AsyncGenerator
from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
+from astrbot.core import sp
from astrbot.core.utils.log_pipe import LogPipe
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -28,6 +29,13 @@ except (ModuleNotFoundError, ImportError):
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
+from typing_extensions import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from astrbot.core.platform.astr_message_event import AstrMessageEvent
+ from astrbot.core.pipeline.context import PipelineContext
+
+
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
SUPPORTED_TYPES = [
@@ -39,6 +47,302 @@ SUPPORTED_TYPES = [
] # json schema 支持的数据类型
+@dataclass
+class FunctionTool:
+ """A class representing a function tool that can be used in function calling."""
+
+ name: str
+ parameters: Dict
+ description: str
+ handler: Awaitable = None
+ """处理函数, 当 origin 为 mcp 时,这个为空"""
+ handler_module_path: str = None
+ """处理函数的模块路径,当 origin 为 mcp 时,这个为空
+
+ 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
+ """
+ active: bool = True
+ """是否激活"""
+
+ origin: Literal["local", "mcp"] = "local"
+ """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
+
+ # MCP 相关字段
+ mcp_server_name: str = None
+ """MCP 服务名称,当 origin 为 mcp 时有效"""
+ mcp_client: MCPClient = None
+ """MCP 客户端,当 origin 为 mcp 时有效"""
+
+ def __repr__(self):
+ return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
+
+ async def execute(
+ self,
+ event: AstrMessageEvent = None,
+ pipeline_context: "PipelineContext" = None,
+ **tool_args,
+ ) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]:
+ """执行函数调用。
+
+ Args:
+ event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
+ pipeline_context (PipelineContext): 流水线调度器上下文, 当 origin 为 local 时必须提供。
+ **kwargs: 函数调用的参数。
+
+ Returns:
+ AsyncGenerator[None | mcp.types.CallToolResult, None]
+ """
+ if self.origin == "local":
+ if not event:
+ raise ValueError("Event must be provided for local function tools.")
+ wrapper = pipeline_context.call_handler(
+ event=event,
+ handler=self.handler,
+ **tool_args,
+ )
+ async for resp in wrapper:
+ if resp is not None:
+ text_content = mcp.types.TextContent(
+ type="text",
+ text=str(resp),
+ )
+ yield mcp.types.CallToolResult(content=[text_content])
+ else:
+ # NOTE: Tool 在这里直接请求发送消息给用户
+ # TODO: 是否需要判断 event.get_result() 是否为空?
+ # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
+ yield None
+
+ elif self.origin == "mcp":
+ if not self.mcp_client:
+ raise ValueError("MCP client is not available for MCP function tools.")
+ res = await self.mcp_client.session.call_tool(
+ name=self.name,
+ arguments=tool_args,
+ )
+ if not res:
+ return
+ yield res
+
+ else:
+ raise Exception(f"Unknown function origin: {self.origin}")
+
+ def __dict__(self) -> dict[str, Any]:
+ """将 FunctionTool 转换为字典格式"""
+ return {
+ "name": self.name,
+ "parameters": self.parameters,
+ "description": self.description,
+ "active": self.active,
+ "origin": self.origin,
+ "mcp_server_name": self.mcp_server_name,
+ }
+
+
+# alias for FunctionTool
+FuncTool = FunctionTool
+
+
+class ToolSet:
+ """A set of function tools that can be used in function calling.
+
+ This class provides methods to add, remove, and retrieve tools, as well as
+ convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
+
+ def __init__(self, tools: List[FunctionTool] = None):
+ self.tools: List[FunctionTool] = tools or []
+
+ def empty(self) -> bool:
+ """Check if the tool set is empty."""
+ return len(self.tools) == 0
+
+ def add_tool(self, tool: FunctionTool):
+ """Add a tool to the set."""
+ # 检查是否已存在同名工具
+ for i, existing_tool in enumerate(self.tools):
+ if existing_tool.name == tool.name:
+ self.tools[i] = tool
+ return
+ self.tools.append(tool)
+
+ def remove_tool(self, name: str):
+ """Remove a tool by its name."""
+ self.tools = [tool for tool in self.tools if tool.name != name]
+
+ def get_tool(self, name: str) -> Optional[FunctionTool]:
+ """Get a tool by its name."""
+ for tool in self.tools:
+ if tool.name == name:
+ return tool
+ return None
+
+ @deprecated(reason="Use add_tool() instead", version="4.0.0")
+ def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
+ """Add a function tool to the set."""
+ params = {
+ "type": "object", # hard-coded here
+ "properties": {},
+ }
+ for param in func_args:
+ params["properties"][param["name"]] = {
+ "type": param["type"],
+ "description": param["description"],
+ }
+ _func = FunctionTool(
+ name=name,
+ parameters=params,
+ description=desc,
+ handler=handler,
+ )
+ self.add_tool(_func)
+
+ @deprecated(reason="Use remove_tool() instead", version="4.0.0")
+ def remove_func(self, name: str):
+ """Remove a function tool by its name."""
+ self.remove_tool(name)
+
+ @deprecated(reason="Use get_tool() instead", version="4.0.0")
+ def get_func(self, name: str) -> List[FunctionTool]:
+ """Get all function tools."""
+ return self.get_tool(name)
+
+ def openai_schema(self, omit_empty_parameters: bool = False) -> List[Dict]:
+ """Convert tools to OpenAI API function calling schema format."""
+ result = []
+ for tool in self.tools:
+ func_def = {
+ "type": "function",
+ "function": {
+ "name": tool.name,
+ "description": tool.description,
+ },
+ }
+
+ if tool.parameters.get("properties") or not omit_empty_parameters:
+ func_def["function"]["parameters"] = tool.parameters
+
+ result.append(func_def)
+ return result
+
+ def anthropic_schema(self) -> List[Dict]:
+ """Convert tools to Anthropic API format."""
+ result = []
+ for tool in self.tools:
+ tool_def = {
+ "name": tool.name,
+ "description": tool.description,
+ "input_schema": {
+ "type": "object",
+ "properties": tool.parameters.get("properties", {}),
+ "required": tool.parameters.get("required", []),
+ },
+ }
+ result.append(tool_def)
+ return result
+
+ def google_schema(self) -> Dict:
+ """Convert tools to Google GenAI API format."""
+
+ def convert_schema(schema: dict) -> dict:
+ """Convert schema to Gemini API format."""
+ supported_types = {
+ "string",
+ "number",
+ "integer",
+ "boolean",
+ "array",
+ "object",
+ "null",
+ }
+ supported_formats = {
+ "string": {"enum", "date-time"},
+ "integer": {"int32", "int64"},
+ "number": {"float", "double"},
+ }
+
+ if "anyOf" in schema:
+ return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
+
+ result = {}
+
+ if "type" in schema and schema["type"] in supported_types:
+ result["type"] = schema["type"]
+ if "format" in schema and schema["format"] in supported_formats.get(
+ result["type"], set()
+ ):
+ result["format"] = schema["format"]
+ else:
+ result["type"] = "null"
+
+ support_fields = {
+ "title",
+ "description",
+ "enum",
+ "minimum",
+ "maximum",
+ "maxItems",
+ "minItems",
+ "nullable",
+ "required",
+ }
+ result.update({k: schema[k] for k in support_fields if k in schema})
+
+ if "properties" in schema:
+ properties = {}
+ for key, value in schema["properties"].items():
+ prop_value = convert_schema(value)
+ if "default" in prop_value:
+ del prop_value["default"]
+ properties[key] = prop_value
+
+ if properties:
+ result["properties"] = properties
+
+ if "items" in schema:
+ result["items"] = convert_schema(schema["items"])
+
+ return result
+
+ tools = [
+ {
+ "name": tool.name,
+ "description": tool.description,
+ "parameters": convert_schema(tool.parameters),
+ }
+ for tool in self.tools
+ ]
+
+ declarations = {}
+ if tools:
+ declarations["function_declarations"] = tools
+ return declarations
+
+ @deprecated(reason="Use openai_schema() instead", version="4.0.0")
+ def get_func_desc_openai_style(self, omit_empty_parameters: bool = False):
+ return self.openai_schema(omit_empty_parameters)
+
+ @deprecated(reason="Use anthropic_schema() instead", version="4.0.0")
+ def get_func_desc_anthropic_style(self):
+ return self.anthropic_schema()
+
+ @deprecated(reason="Use google_schema() instead", version="4.0.0")
+ def get_func_desc_google_genai_style(self):
+ return self.google_schema()
+
+ def names(self) -> List[str]:
+ """获取所有工具的名称列表"""
+ return [tool.name for tool in self.tools]
+
+ def __len__(self):
+ return len(self.tools)
+
+ def __bool__(self):
+ return len(self.tools) > 0
+
+ def __iter__(self):
+ return iter(self.tools)
+
+
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if "mcpServers" in config and config["mcpServers"]:
@@ -105,55 +409,6 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return False, f"{e!s}"
-@dataclass
-class FuncTool:
- """
- 用于描述一个函数调用工具。
- """
-
- name: str
- parameters: Dict
- description: str
- handler: Awaitable = None
- """处理函数, 当 origin 为 mcp 时,这个为空"""
- handler_module_path: str = None
- """处理函数的模块路径,当 origin 为 mcp 时,这个为空
-
- 必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
- """
- active: bool = True
- """是否激活"""
-
- origin: Literal["local", "mcp"] = "local"
- """函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
-
- # MCP 相关字段
- mcp_server_name: str = None
- """MCP 服务名称,当 origin 为 mcp 时有效"""
- mcp_client: MCPClient = None
- """MCP 客户端,当 origin 为 mcp 时有效"""
-
- def __repr__(self):
- return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
-
- async def execute(self, **args) -> Any:
- """执行函数调用"""
- if self.origin == "local":
- if not self.handler:
- raise Exception(f"Local function {self.name} has no handler")
- return await self.handler(**args)
- elif self.origin == "mcp":
- if not self.mcp_client or not self.mcp_client.session:
- raise Exception(f"MCP client for {self.name} is not available")
- # 使用name属性而不是额外的mcp_tool_name
- actual_tool_name = (
- self.name.split(":")[-1] if ":" in self.name else self.name
- )
- return await self.mcp_client.session.call_tool(actual_tool_name, args)
- else:
- raise Exception(f"Unknown function origin: {self.origin}")
-
-
class MCPClient:
def __init__(self):
# Initialize session and client objects
@@ -276,10 +531,9 @@ class MCPClient:
self.running_event.set() # Set the running event to indicate cleanup is done
-class FuncCall:
+class FunctionToolManager:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
- """内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
@@ -331,11 +585,15 @@ class FuncCall:
self.func_list.pop(i)
break
- def get_func(self, name) -> FuncTool:
+ def get_func(self, name) -> FuncTool | None:
for f in self.func_list:
if f.name == name:
return f
- return None
+
+ def get_full_tool_set(self) -> ToolSet:
+ """获取完整工具集"""
+ tool_set = ToolSet(self.func_list.copy())
+ return tool_set
async def init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
@@ -556,203 +814,68 @@ class FuncCall:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
"""
- _l = []
- # 处理所有工具(包括本地和MCP工具)
- for f in self.func_list:
- if not f.active:
- continue
- func_ = {
- "type": "function",
- "function": {
- "name": f.name,
- # "parameters": f.parameters,
- "description": f.description,
- },
- }
- func_["function"]["parameters"] = f.parameters
- if not f.parameters.get("properties") and omit_empty_parameter_field:
- # 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段
- del func_["function"]["parameters"]
- _l.append(func_)
- return _l
+ tools = [f for f in self.func_list if f.active]
+ toolset = ToolSet(tools)
+ return toolset.openai_schema(omit_empty_parameters=omit_empty_parameter_field)
def get_func_desc_anthropic_style(self) -> list:
"""
获得 Anthropic API 风格的**已经激活**的工具描述
"""
- tools = []
- for f in self.func_list:
- if not f.active:
- continue
-
- # Convert internal format to Anthropic style
- tool = {
- "name": f.name,
- "description": f.description,
- "input_schema": {
- "type": "object",
- "properties": f.parameters.get("properties", {}),
- # Keep the required field from the original parameters if it exists
- "required": f.parameters.get("required", []),
- },
- }
- tools.append(tool)
- return tools
+ tools = [f for f in self.func_list if f.active]
+ toolset = ToolSet(tools)
+ return toolset.anthropic_schema()
def get_func_desc_google_genai_style(self) -> dict:
"""
获得 Google GenAI API 风格的**已经激活**的工具描述
"""
+ tools = [f for f in self.func_list if f.active]
+ toolset = ToolSet(tools)
+ return toolset.google_schema()
- # Gemini API 支持的数据类型和格式
- supported_types = {
- "string",
- "number",
- "integer",
- "boolean",
- "array",
- "object",
- "null",
- }
- supported_formats = {
- "string": {"enum", "date-time"},
- "integer": {"int32", "int64"},
- "number": {"float", "double"},
- }
+ def deactivate_llm_tool(self, name: str) -> bool:
+ """停用一个已经注册的函数调用工具。
- def convert_schema(schema: dict) -> dict:
- """转换 schema 为 Gemini API 格式"""
+ Returns:
+ 如果没找到,会返回 False"""
+ func_tool = self.get_func(name)
+ if func_tool is not None:
+ func_tool.active = False
- # 如果 schema 包含 anyOf,则只返回 anyOf 字段
- if "anyOf" in schema:
- return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
+ inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
+ if name not in inactivated_llm_tools:
+ inactivated_llm_tools.append(name)
+ sp.put("inactivated_llm_tools", inactivated_llm_tools)
- result = {}
+ return True
+ return False
- if "type" in schema and schema["type"] in supported_types:
- result["type"] = schema["type"]
- if "format" in schema and schema["format"] in supported_formats.get(
- result["type"], set()
- ):
- result["format"] = schema["format"]
- else:
- # 暂时指定默认为null
- result["type"] = "null"
+ # 因为不想解决循环引用,所以这里直接传入 star_map 先了...
+ def activate_llm_tool(self, name: str, star_map: dict) -> bool:
+ func_tool = self.get_func(name)
+ if func_tool is not None:
+ if func_tool.handler_module_path in star_map:
+ if not star_map[func_tool.handler_module_path].activated:
+ raise ValueError(
+ f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
+ )
- support_fields = {
- "title",
- "description",
- "enum",
- "minimum",
- "maximum",
- "maxItems",
- "minItems",
- "nullable",
- "required",
- }
- result.update({k: schema[k] for k in support_fields if k in schema})
+ func_tool.active = True
- if "properties" in schema:
- properties = {}
- for key, value in schema["properties"].items():
- prop_value = convert_schema(value)
- if "default" in prop_value:
- del prop_value["default"]
- properties[key] = prop_value
+ inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
+ if name in inactivated_llm_tools:
+ inactivated_llm_tools.remove(name)
+ sp.put("inactivated_llm_tools", inactivated_llm_tools)
- if properties: # 只在有非空属性时添加
- result["properties"] = properties
-
- if "items" in schema:
- result["items"] = convert_schema(schema["items"])
-
- return result
-
- tools = [
- {
- "name": f.name,
- "description": f.description,
- **({"parameters": convert_schema(f.parameters)}),
- }
- for f in self.func_list
- if f.active
- ]
-
- declarations = {}
- if tools:
- declarations["function_declarations"] = tools
- return declarations
-
- async def func_call(self, question: str, session_id: str, provider) -> tuple:
- _l = []
- for f in self.func_list:
- if not f.active:
- continue
- _l.append(
- {
- "name": f.name,
- "parameters": f.parameters,
- "description": f.description,
- }
- )
- func_definition = json.dumps(_l, ensure_ascii=False)
-
- prompt = textwrap.dedent(f"""
- ROLE:
- 你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
-
- TOOLS:
- 可用的函数列表:
-
- {func_definition}
-
- LIMIT:
- 1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
- 2. 你的 Json 返回的格式如下:`[{{"name": "", "args": }}, ...]`。参数根据上面提供的函数列表中的参数来填写。
- 3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
- 4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
-
- EXAMPLE:
- 1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
-
- 用户的提问是:{question}
- """)
-
- _c = 0
- while _c < 3:
- try:
- res = await provider.text_chat(prompt, session_id)
- if res.find("```") != -1:
- res = res[res.find("```json") + 7 : res.rfind("```")]
- res = json.loads(res)
- break
- except Exception as e:
- _c += 1
- if _c == 3:
- raise e
- if "The message you submitted was too long" in str(e):
- raise e
-
- if "res" in res and not res["res"]:
- return "", False
-
- tool_call_result = []
- for tool in res:
- # 说明有函数调用
- func_name = tool["name"]
- args = tool["args"]
- # 调用函数
- func_tool = self.get_func(func_name)
- if not func_tool:
- raise Exception(f"Request function {func_name} not found.")
-
- ret = await func_tool.execute(**args)
- if ret:
- tool_call_result.append(str(ret))
- return tool_call_result, True
+ return True
+ return False
def __str__(self):
return str(self.func_list)
def __repr__(self):
return str(self.func_list)
+
+
+FuncCall = FunctionToolManager
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index 370c5322b..7d7779e81 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -7,85 +7,27 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
-from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
+from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
from .register import llm_tools, provider_cls_map
+from ..persona_mgr import PersonaManager
class ProviderManager:
- def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
+ def __init__(
+ self,
+ config: AstrBotConfig,
+ db_helper: BaseDatabase,
+ persona_mgr: PersonaManager,
+ ):
+ self.persona_mgr = persona_mgr
+ self.astrbot_config = config
self.providers_config: List = config["provider"]
self.provider_settings: dict = config["provider_settings"]
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
- self.persona_configs: list = config.get("persona", [])
- self.astrbot_config = config
- # 人格情景管理
- # 目前没有拆成独立的模块
- self.default_persona_name = self.provider_settings.get(
- "default_personality", "default"
- )
- self.personas: List[Personality] = []
- self.selected_default_persona = None
- for persona in self.persona_configs:
- begin_dialogs = persona.get("begin_dialogs", [])
- mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
- bd_processed = []
- mid_processed = ""
- if begin_dialogs:
- if len(begin_dialogs) % 2 != 0:
- logger.error(
- f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。"
- )
- begin_dialogs = []
- user_turn = True
- for dialog in begin_dialogs:
- bd_processed.append(
- {
- "role": "user" if user_turn else "assistant",
- "content": dialog,
- "_no_save": None, # 不持久化到 db
- }
- )
- user_turn = not user_turn
- if mood_imitation_dialogs:
- if len(mood_imitation_dialogs) % 2 != 0:
- logger.error(
- f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。"
- )
- mood_imitation_dialogs = []
- user_turn = True
- for dialog in mood_imitation_dialogs:
- role = "A" if user_turn else "B"
- mid_processed += f"{role}: {dialog}\n"
- if not user_turn:
- mid_processed += "\n"
- user_turn = not user_turn
-
- try:
- persona = Personality(
- **persona,
- _begin_dialogs_processed=bd_processed,
- _mood_imitation_dialogs_processed=mid_processed,
- )
- if persona["name"] == self.default_persona_name:
- self.selected_default_persona = persona
- self.personas.append(persona)
- except Exception as e:
- logger.error(f"解析 Persona 配置失败:{e}")
-
- if not self.selected_default_persona and len(self.personas) > 0:
- # 默认选择第一个
- self.selected_default_persona = self.personas[0]
-
- if not self.selected_default_persona:
- self.selected_default_persona = Personality(
- prompt="You are a helpful and friendly assistant.",
- name="default",
- _begin_dialogs_processed=[],
- _mood_imitation_dialogs_processed="",
- )
- self.personas.append(self.selected_default_persona)
+ # 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
+ self.default_persona_name = persona_mgr.default_persona
self.provider_insts: List[Provider] = []
"""加载的 Provider 的实例"""
@@ -113,6 +55,21 @@ class ProviderManager:
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
+ @property
+ def persona_configs(self) -> list:
+ """动态获取最新的 persona 配置"""
+ return self.persona_mgr.persona_v3_config
+
+ @property
+ def personas(self) -> list:
+ """动态获取最新的 personas 列表"""
+ return self.persona_mgr.personas_v3
+
+ @property
+ def selected_default_persona(self):
+ """动态获取最新的默认选中 persona"""
+ return self.persona_mgr.selected_default_persona_v3
+
async def set_provider(
self, provider_id: str, provider_type: ProviderType, umo: str = None
):
diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py
index 36401b089..2b4f81bb4 100644
--- a/astrbot/core/provider/provider.py
+++ b/astrbot/core/provider/provider.py
@@ -1,23 +1,13 @@
import abc
from typing import List
-from typing import TypedDict, AsyncGenerator
-from astrbot.core.provider.func_tool_manager import FuncCall
+from typing import AsyncGenerator
+from astrbot.core.provider.func_tool_manager import FunctionToolManager, ToolSet
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
from astrbot.core.provider.register import provider_cls_map
+from astrbot.core.db.po import Personality
from dataclasses import dataclass
-class Personality(TypedDict):
- prompt: str = ""
- name: str = ""
- begin_dialogs: List[str] = []
- mood_imitation_dialogs: List[str] = []
-
- # cache
- _begin_dialogs_processed: List[dict] = []
- _mood_imitation_dialogs_processed: str = ""
-
-
@dataclass
class ProviderMeta:
id: str
@@ -90,7 +80,7 @@ class Provider(AbstractProvider):
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
- func_tool: FuncCall = None,
+ func_tool: FunctionToolManager | ToolSet = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
@@ -119,7 +109,7 @@ class Provider(AbstractProvider):
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
- func_tool: FuncCall = None,
+ func_tool: FunctionToolManager | ToolSet = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py
index 6d89da57b..76cdea062 100644
--- a/astrbot/core/star/context.py
+++ b/astrbot/core/star/context.py
@@ -11,12 +11,14 @@ from astrbot.core.provider.provider import (
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
-from astrbot.core.provider.func_tool_manager import FuncCall
+from astrbot.core.provider.func_tool_manager import FunctionToolManager
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform import Platform
from astrbot.core.platform.manager import PlatformManager
+from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
+from astrbot.core.persona_mgr import PersonaManager
from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
@@ -29,6 +31,11 @@ from astrbot.core.star.filter.platform_adapter_type import (
)
from deprecated import deprecated
+from typing_extensions import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from astrbot.core.pipeline.context import PipelineContext
+
class Context:
"""
@@ -50,6 +57,8 @@ class Context:
registered_web_apis: list = []
+ pipeline_ctx: "PipelineContext" = None
+
# back compatibility
_register_tasks: List[Awaitable] = []
_star_manager = None
@@ -62,6 +71,8 @@ class Context:
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
conversation_manager: ConversationManager = None,
+ message_history_manager: PlatformMessageHistoryManager = None,
+ persona_manager: PersonaManager = None,
):
self._event_queue = event_queue
self._config = config
@@ -69,6 +80,8 @@ class Context:
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.conversation_manager = conversation_manager
+ self.message_history_manager = message_history_manager
+ self.persona_manager = persona_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
"""根据插件名获取插件的 Metadata"""
@@ -80,7 +93,7 @@ class Context:
"""获取当前载入的所有插件 Metadata 的列表"""
return star_registry
- def get_llm_tool_manager(self) -> FuncCall:
+ def get_llm_tool_manager(self) -> FunctionToolManager:
"""获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools"""
return self.provider_manager.llm_tools
@@ -90,40 +103,14 @@ class Context:
Returns:
如果没找到,会返回 False
"""
- func_tool = self.provider_manager.llm_tools.get_func(name)
- if func_tool is not None:
- if func_tool.handler_module_path in star_map:
- if not star_map[func_tool.handler_module_path].activated:
- raise ValueError(
- f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
- )
-
- func_tool.active = True
-
- inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
- if name in inactivated_llm_tools:
- inactivated_llm_tools.remove(name)
- sp.put("inactivated_llm_tools", inactivated_llm_tools)
-
- return True
- return False
+ return self.provider_manager.llm_tools.activate_llm_tool(name, star_map)
def deactivate_llm_tool(self, name: str) -> bool:
"""停用一个已经注册的函数调用工具。
Returns:
如果没找到,会返回 False"""
- func_tool = self.provider_manager.llm_tools.get_func(name)
- if func_tool is not None:
- func_tool.active = False
-
- inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
- if name not in inactivated_llm_tools:
- inactivated_llm_tools.append(name)
- sp.put("inactivated_llm_tools", inactivated_llm_tools)
-
- return True
- return False
+ return self.provider_manager.llm_tools.deactivate_llm_tool(name)
def register_provider(self, provider: Provider):
"""
@@ -208,7 +195,9 @@ class Context:
return self._event_queue
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
- def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform | None:
+ def get_platform(
+ self, platform_type: Union[PlatformAdapterType, str]
+ ) -> Platform | None:
"""
获取指定类型的平台适配器。
@@ -268,6 +257,9 @@ class Context:
return True
return False
+ def get_pipeline_context(self) -> "PipelineContext":
+ return self.pipeline_ctx
+
"""
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
"""
diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py
index 8d08b9d53..ef2fa3e86 100644
--- a/astrbot/dashboard/routes/__init__.py
+++ b/astrbot/dashboard/routes/__init__.py
@@ -6,11 +6,11 @@ from .stat import StatRoute
from .log import LogRoute
from .static_file import StaticFileRoute
from .chat import ChatRoute
-from .tools import ToolsRoute # 导入新的ToolsRoute
+from .tools import ToolsRoute
from .conversation import ConversationRoute
from .file import FileRoute
from .session_management import SessionManagementRoute
-
+from .persona import PersonaRoute
__all__ = [
"AuthRoute",
@@ -25,4 +25,5 @@ __all__ = [
"ConversationRoute",
"FileRoute",
"SessionManagementRoute",
+ "PersonaRoute",
]
diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py
index 7de720a38..11b474860 100644
--- a/astrbot/dashboard/routes/config.py
+++ b/astrbot/dashboard/routes/config.py
@@ -169,7 +169,6 @@ class ConfigRoute(Route):
"/config/provider/new": ("POST", self.post_new_provider),
"/config/provider/update": ("POST", self.post_update_provider),
"/config/provider/delete": ("POST", self.post_delete_provider),
- "/config/llmtools": ("GET", self.get_llm_tools),
"/config/provider/check_one": ("GET", self.check_one_provider_status),
"/config/provider/list": ("GET", self.get_provider_config_list),
"/config/provider/model_list": ("GET", self.get_provider_model_list),
@@ -509,12 +508,6 @@ class ConfigRoute(Route):
return Response().error(str(e)).__dict__
return Response().ok(None, "删除成功,已经实时生效~").__dict__
- async def get_llm_tools(self):
- """获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
- tool_mgr = self.core_lifecycle.provider_manager.llm_tools
- tools = tool_mgr.get_func_desc_openai_style()
- return Response().ok(tools).__dict__
-
async def _get_astrbot_config(self):
config = self.config
diff --git a/astrbot/dashboard/routes/persona.py b/astrbot/dashboard/routes/persona.py
new file mode 100644
index 000000000..032471ee4
--- /dev/null
+++ b/astrbot/dashboard/routes/persona.py
@@ -0,0 +1,199 @@
+import traceback
+from .route import Route, Response, RouteContext
+from astrbot.core import logger
+from quart import request
+from astrbot.core.db import BaseDatabase
+from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
+
+
+class PersonaRoute(Route):
+ def __init__(
+ self,
+ context: RouteContext,
+ db_helper: BaseDatabase,
+ core_lifecycle: AstrBotCoreLifecycle,
+ ) -> None:
+ super().__init__(context)
+ self.routes = {
+ "/persona/list": ("GET", self.list_personas),
+ "/persona/detail": ("POST", self.get_persona_detail),
+ "/persona/create": ("POST", self.create_persona),
+ "/persona/update": ("POST", self.update_persona),
+ "/persona/delete": ("POST", self.delete_persona),
+ }
+ self.db_helper = db_helper
+ self.persona_mgr = core_lifecycle.persona_mgr
+ self.register_routes()
+
+ async def list_personas(self):
+ """获取所有人格列表"""
+ try:
+ personas = await self.persona_mgr.get_all_personas()
+ return (
+ Response()
+ .ok(
+ [
+ {
+ "persona_id": persona.persona_id,
+ "system_prompt": persona.system_prompt,
+ "begin_dialogs": persona.begin_dialogs or [],
+ "tools": persona.tools,
+ "created_at": persona.created_at.isoformat()
+ if persona.created_at
+ else None,
+ "updated_at": persona.updated_at.isoformat()
+ if persona.updated_at
+ else None,
+ }
+ for persona in personas
+ ]
+ )
+ .__dict__
+ )
+ except Exception as e:
+ logger.error(f"获取人格列表失败: {str(e)}\n{traceback.format_exc()}")
+ return Response().error(f"获取人格列表失败: {str(e)}").__dict__
+
+ async def get_persona_detail(self):
+ """获取指定人格的详细信息"""
+ try:
+ data = await request.get_json()
+ persona_id = data.get("persona_id")
+
+ if not persona_id:
+ return Response().error("缺少必要参数: persona_id").__dict__
+
+ persona = await self.persona_mgr.get_persona(persona_id)
+ if not persona:
+ return Response().error("人格不存在").__dict__
+
+ return (
+ Response()
+ .ok(
+ {
+ "persona_id": persona.persona_id,
+ "system_prompt": persona.system_prompt,
+ "begin_dialogs": persona.begin_dialogs or [],
+ "tools": persona.tools,
+ "created_at": persona.created_at.isoformat()
+ if persona.created_at
+ else None,
+ "updated_at": persona.updated_at.isoformat()
+ if persona.updated_at
+ else None,
+ }
+ )
+ .__dict__
+ )
+ except Exception as e:
+ logger.error(f"获取人格详情失败: {str(e)}\n{traceback.format_exc()}")
+ return Response().error(f"获取人格详情失败: {str(e)}").__dict__
+
+ async def create_persona(self):
+ """创建新人格"""
+ try:
+ data = await request.get_json()
+ persona_id = data.get("persona_id", "").strip()
+ system_prompt = data.get("system_prompt", "").strip()
+ begin_dialogs = data.get("begin_dialogs", [])
+ tools = data.get("tools")
+
+ if not persona_id:
+ return Response().error("人格ID不能为空").__dict__
+
+ if not system_prompt:
+ return Response().error("系统提示词不能为空").__dict__
+
+ # 验证 begin_dialogs 格式
+ if begin_dialogs and len(begin_dialogs) % 2 != 0:
+ return (
+ Response()
+ .error("预设对话数量必须为偶数(用户和助手轮流对话)")
+ .__dict__
+ )
+
+ persona = await self.persona_mgr.create_persona(
+ persona_id=persona_id,
+ system_prompt=system_prompt,
+ begin_dialogs=begin_dialogs if begin_dialogs else None,
+ tools=tools if tools else None,
+ )
+
+ return (
+ Response()
+ .ok(
+ {
+ "message": "人格创建成功",
+ "persona": {
+ "persona_id": persona.persona_id,
+ "system_prompt": persona.system_prompt,
+ "begin_dialogs": persona.begin_dialogs or [],
+ "tools": persona.tools or [],
+ "created_at": persona.created_at.isoformat()
+ if persona.created_at
+ else None,
+ "updated_at": persona.updated_at.isoformat()
+ if persona.updated_at
+ else None,
+ },
+ }
+ )
+ .__dict__
+ )
+ except ValueError as e:
+ return Response().error(str(e)).__dict__
+ except Exception as e:
+ logger.error(f"创建人格失败: {str(e)}\n{traceback.format_exc()}")
+ return Response().error(f"创建人格失败: {str(e)}").__dict__
+
+ async def update_persona(self):
+ """更新人格信息"""
+ try:
+ data = await request.get_json()
+ persona_id = data.get("persona_id")
+ system_prompt = data.get("system_prompt")
+ begin_dialogs = data.get("begin_dialogs")
+ tools = data.get("tools")
+
+ if not persona_id:
+ return Response().error("缺少必要参数: persona_id").__dict__
+
+ # 验证 begin_dialogs 格式
+ if begin_dialogs is not None and len(begin_dialogs) % 2 != 0:
+ return (
+ Response()
+ .error("预设对话数量必须为偶数(用户和助手轮流对话)")
+ .__dict__
+ )
+
+ await self.persona_mgr.update_persona(
+ persona_id=persona_id,
+ system_prompt=system_prompt,
+ begin_dialogs=begin_dialogs,
+ tools=tools,
+ )
+
+ return Response().ok({"message": "人格更新成功"}).__dict__
+ except ValueError as e:
+ return Response().error(str(e)).__dict__
+ except Exception as e:
+ logger.error(f"更新人格失败: {str(e)}\n{traceback.format_exc()}")
+ return Response().error(f"更新人格失败: {str(e)}").__dict__
+
+ async def delete_persona(self):
+ """删除人格"""
+ try:
+ data = await request.get_json()
+ persona_id = data.get("persona_id")
+
+ if not persona_id:
+ return Response().error("缺少必要参数: persona_id").__dict__
+
+ await self.persona_mgr.delete_persona(persona_id)
+
+ return Response().ok({"message": "人格删除成功"}).__dict__
+ except ValueError as e:
+ return Response().error(str(e)).__dict__
+ except Exception as e:
+ logger.error(f"删除人格失败: {str(e)}\n{traceback.format_exc()}")
+ return Response().error(f"删除人格失败: {str(e)}").__dict__
diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py
index 5dad2576b..324fc62d3 100644
--- a/astrbot/dashboard/routes/tools.py
+++ b/astrbot/dashboard/routes/tools.py
@@ -8,6 +8,7 @@ from quart import request
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
+from astrbot.core.star import star_map
from .route import Response, Route, RouteContext
@@ -27,6 +28,8 @@ class ToolsRoute(Route):
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
"/tools/mcp/market": ("GET", self.get_mcp_markets),
"/tools/mcp/test": ("POST", self.test_mcp_connection),
+ "/tools/list": ("GET", self.get_tool_list),
+ "/tools/toggle-tool": ("POST", self.toggle_tool),
}
self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
@@ -336,3 +339,40 @@ class ToolsRoute(Route):
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__
+
+ async def get_tool_list(self):
+ """获取所有注册的工具列表"""
+ try:
+ tools = self.tool_mgr.func_list
+ tools_dict = [tool.__dict__() for tool in tools]
+ return Response().ok(data=tools_dict).__dict__
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return Response().error(f"获取工具列表失败: {str(e)}").__dict__
+
+ async def toggle_tool(self):
+ """启用或停用指定的工具"""
+ try:
+ data = await request.json
+ tool_name = data.get("name")
+ action = data.get("activate") # True or False
+
+ if not tool_name or action is None:
+ return Response().error("缺少必要参数: name 或 action").__dict__
+
+ if action:
+ try:
+ ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
+ except ValueError as e:
+ return Response().error(f"启用工具失败: {str(e)}").__dict__
+ else:
+ ok = self.tool_mgr.deactivate_llm_tool(tool_name)
+
+ if ok:
+ return Response().ok(None, "操作成功。").__dict__
+ else:
+ return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return Response().error(f"操作工具失败: {str(e)}").__dict__
diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py
index 06f6f8e60..e22b20524 100644
--- a/astrbot/dashboard/server.py
+++ b/astrbot/dashboard/server.py
@@ -60,6 +60,9 @@ class AstrBotDashboard:
self.session_management_route = SessionManagementRoute(
self.context, db, core_lifecycle
)
+ self.persona_route = PersonaRoute(
+ self.context, db, core_lifecycle
+ )
self.app.add_url_rule(
"/api/plug/",
diff --git a/dashboard/src/i18n/loader.ts b/dashboard/src/i18n/loader.ts
index 275110ddc..ec76b2c0c 100644
--- a/dashboard/src/i18n/loader.ts
+++ b/dashboard/src/i18n/loader.ts
@@ -52,6 +52,7 @@ export class I18nLoader {
{ name: 'features/alkaid/index', path: 'features/alkaid/index.json' },
{ name: 'features/alkaid/knowledge-base', path: 'features/alkaid/knowledge-base.json' },
{ name: 'features/alkaid/memory', path: 'features/alkaid/memory.json' },
+ { name: 'features/persona', path: 'features/persona.json' },
// 消息模块
{ name: 'messages/errors', path: 'messages/errors.json' },
diff --git a/dashboard/src/i18n/locales/en-US/core/navigation.json b/dashboard/src/i18n/locales/en-US/core/navigation.json
index 501383020..23276163a 100644
--- a/dashboard/src/i18n/locales/en-US/core/navigation.json
+++ b/dashboard/src/i18n/locales/en-US/core/navigation.json
@@ -2,6 +2,7 @@
"dashboard": "Dashboard",
"platforms": "Platforms",
"providers": "Providers",
+ "persona": "Persona",
"toolUse": "MCP Tools",
"config": "Config",
"extension": "Extensions",
diff --git a/dashboard/src/i18n/locales/en-US/features/persona.json b/dashboard/src/i18n/locales/en-US/features/persona.json
new file mode 100644
index 000000000..94708ee56
--- /dev/null
+++ b/dashboard/src/i18n/locales/en-US/features/persona.json
@@ -0,0 +1,67 @@
+{
+ "page": {
+ "description": "Manage and configure chat bot personality settings"
+ },
+ "buttons": {
+ "create": "Create Persona",
+ "createFirst": "Create First Persona",
+ "edit": "Edit",
+ "delete": "Delete",
+ "cancel": "Cancel",
+ "save": "Save",
+ "addDialogPair": "Add Dialog Pair"
+ },
+ "labels": {
+ "presetDialogs": "Preset Dialogs ({count} pairs)",
+ "createdAt": "Created At",
+ "updatedAt": "Updated At"
+ },
+ "form": {
+ "personaId": "Persona ID",
+ "systemPrompt": "System Prompt",
+ "presetDialogs": "Preset Dialogs",
+ "presetDialogsHelp": "Add some preset dialogs to help the bot better understand the role settings. The number of dialogs must be even (users and assistants take turns).",
+ "userMessage": "User Message",
+ "assistantMessage": "Assistant Message",
+ "tools": "Tool Selection",
+ "toolsHelp": "Select available tools for this persona. Tools allow the bot to perform specific functions such as searching, calculating, getting information, etc.",
+ "toolsSelection": "Tool Selection Actions",
+ "selectAllTools": "Select All Tools",
+ "clearAllTools": "Clear Selection",
+ "allSelected": "All Selected",
+ "mcpServersQuickSelect": "MCP Servers Quick Select",
+ "searchTools": "Search Tools",
+ "selectedTools": "Selected Tools",
+ "noToolsAvailable": "No tools available",
+ "noToolsFound": "No matching tools found",
+ "loadingTools": "Loading tools...",
+ "allToolsAvailable": "Use all available tools",
+ "noToolsSelected": "No tools selected"
+ },
+ "dialog": {
+ "create": {
+ "title": "Create New Persona"
+ },
+ "edit": {
+ "title": "Edit Persona"
+ }
+ },
+ "empty": {
+ "title": "No Persona Configured",
+ "description": "Create your first persona to start using personalized chatbots"
+ },
+ "validation": {
+ "required": "This field is required",
+ "minLength": "Minimum {min} characters required",
+ "alphanumeric": "Only letters, numbers, underscores and hyphens are allowed",
+ "dialogRequired": "{type} cannot be empty"
+ },
+ "messages": {
+ "loadError": "Failed to load persona list",
+ "saveSuccess": "Saved successfully",
+ "saveError": "Save failed",
+ "deleteConfirm": "Are you sure you want to delete persona \"{id}\"? This action cannot be undone.",
+ "deleteSuccess": "Deleted successfully",
+ "deleteError": "Delete failed"
+ }
+}
diff --git a/dashboard/src/i18n/locales/en-US/features/tool-use.json b/dashboard/src/i18n/locales/en-US/features/tool-use.json
index 96c4760e8..1c4c54e97 100644
--- a/dashboard/src/i18n/locales/en-US/features/tool-use.json
+++ b/dashboard/src/i18n/locales/en-US/features/tool-use.json
@@ -107,6 +107,9 @@
"failed": "Import configuration failed: {error}"
},
"configParseError": "Configuration parse error: {error}",
- "noAvailableConfig": "No available configuration"
+ "noAvailableConfig": "No available configuration",
+ "toggleToolSuccess": "Tool status toggled successfully!",
+ "toggleToolError": "Failed to toggle tool status: {error}",
+ "testError": "Test connection failed: {error}"
}
}
\ No newline at end of file
diff --git a/dashboard/src/i18n/locales/zh-CN/core/navigation.json b/dashboard/src/i18n/locales/zh-CN/core/navigation.json
index d0ed8453a..ccb4be5a1 100644
--- a/dashboard/src/i18n/locales/zh-CN/core/navigation.json
+++ b/dashboard/src/i18n/locales/zh-CN/core/navigation.json
@@ -2,6 +2,7 @@
"dashboard": "统计",
"platforms": "消息平台",
"providers": "服务提供商",
+ "persona": "人格管理",
"toolUse": "MCP",
"config": "配置文件",
"extension": "插件管理",
diff --git a/dashboard/src/i18n/locales/zh-CN/features/persona.json b/dashboard/src/i18n/locales/zh-CN/features/persona.json
new file mode 100644
index 000000000..619f1b607
--- /dev/null
+++ b/dashboard/src/i18n/locales/zh-CN/features/persona.json
@@ -0,0 +1,67 @@
+{
+ "page": {
+ "description": "管理和配置聊天机器人的人格角色设定"
+ },
+ "buttons": {
+ "create": "创建人格",
+ "createFirst": "创建第一个人格",
+ "edit": "编辑",
+ "delete": "删除",
+ "cancel": "取消",
+ "save": "保存",
+ "addDialogPair": "添加对话对"
+ },
+ "labels": {
+ "presetDialogs": "预设对话 ({count} 对)",
+ "createdAt": "创建时间",
+ "updatedAt": "更新时间"
+ },
+ "form": {
+ "personaId": "人格 ID",
+ "systemPrompt": "系统提示词",
+ "presetDialogs": "预设对话",
+ "presetDialogsHelp": "添加一些预设的对话来帮助机器人更好地理解角色设定。对话数量必须为偶数(用户和助手轮流对话)。",
+ "userMessage": "用户消息",
+ "assistantMessage": "助手消息",
+ "tools": "工具选择",
+ "toolsHelp": "为这个人格选择可用的工具。工具可以让机器人执行特定的功能,如搜索、计算、获取信息等。",
+ "toolsSelection": "工具选择操作",
+ "selectAllTools": "选择所有工具",
+ "clearAllTools": "清空选择",
+ "allSelected": "全选",
+ "mcpServersQuickSelect": "MCP 服务器快速选择",
+ "searchTools": "搜索工具",
+ "selectedTools": "已选择的工具",
+ "noToolsAvailable": "暂无可用工具",
+ "noToolsFound": "未找到匹配的工具",
+ "loadingTools": "正在加载工具...",
+ "allToolsAvailable": "使用所有可用工具",
+ "noToolsSelected": "未选择任何工具"
+ },
+ "dialog": {
+ "create": {
+ "title": "创建新人格"
+ },
+ "edit": {
+ "title": "编辑人格"
+ }
+ },
+ "empty": {
+ "title": "暂无人格配置",
+ "description": "创建您的第一个人格来开始使用个性化的聊天机器人"
+ },
+ "validation": {
+ "required": "此字段为必填项",
+ "minLength": "最少需要 {min} 个字符",
+ "alphanumeric": "只能包含字母、数字、下划线和连字符",
+ "dialogRequired": "{type}不能为空"
+ },
+ "messages": {
+ "loadError": "加载人格列表失败",
+ "saveSuccess": "保存成功",
+ "saveError": "保存失败",
+ "deleteConfirm": "确定要删除人格 \"{id}\" 吗?此操作不可撤销。",
+ "deleteSuccess": "删除成功",
+ "deleteError": "删除失败"
+ }
+}
diff --git a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json
index 61b8691bc..663488497 100644
--- a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json
+++ b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json
@@ -107,6 +107,9 @@
"failed": "导入配置失败: {error}"
},
"configParseError": "配置解析错误: {error}",
- "noAvailableConfig": "无可用配置"
+ "noAvailableConfig": "无可用配置",
+ "toggleToolSuccess": "工具状态切换成功!",
+ "toggleToolError": "工具状态切换失败: {error}",
+ "testError": "测试连接失败: {error}"
}
-}
\ No newline at end of file
+}
\ No newline at end of file
diff --git a/dashboard/src/i18n/translations.ts b/dashboard/src/i18n/translations.ts
index c46ca4d0d..f5ea8c97e 100644
--- a/dashboard/src/i18n/translations.ts
+++ b/dashboard/src/i18n/translations.ts
@@ -25,6 +25,7 @@ import zhCNDashboard from './locales/zh-CN/features/dashboard.json';
import zhCNAlkaidIndex from './locales/zh-CN/features/alkaid/index.json';
import zhCNAlkaidKnowledgeBase from './locales/zh-CN/features/alkaid/knowledge-base.json';
import zhCNAlkaidMemory from './locales/zh-CN/features/alkaid/memory.json';
+import zhCNPersona from './locales/zh-CN/features/persona.json';
import zhCNErrors from './locales/zh-CN/messages/errors.json';
import zhCNSuccess from './locales/zh-CN/messages/success.json';
@@ -54,6 +55,7 @@ import enUSDashboard from './locales/en-US/features/dashboard.json';
import enUSAlkaidIndex from './locales/en-US/features/alkaid/index.json';
import enUSAlkaidKnowledgeBase from './locales/en-US/features/alkaid/knowledge-base.json';
import enUSAlkaidMemory from './locales/en-US/features/alkaid/memory.json';
+import enUSPersona from './locales/en-US/features/persona.json';
import enUSErrors from './locales/en-US/messages/errors.json';
import enUSSuccess from './locales/en-US/messages/success.json';
@@ -88,7 +90,8 @@ export const translations = {
index: zhCNAlkaidIndex,
'knowledge-base': zhCNAlkaidKnowledgeBase,
memory: zhCNAlkaidMemory
- }
+ },
+ persona: zhCNPersona
},
messages: {
errors: zhCNErrors,
@@ -123,7 +126,8 @@ export const translations = {
index: enUSAlkaidIndex,
'knowledge-base': enUSAlkaidKnowledgeBase,
memory: enUSAlkaidMemory
- }
+ },
+ persona: enUSPersona
},
messages: {
errors: enUSErrors,
diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts
index 43f062d1c..09f0e33e0 100644
--- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts
+++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts
@@ -63,9 +63,13 @@ const sidebarItem: menu[] = [
icon: 'mdi-account-group',
to: '/session-management'
},
+ {
+ title: 'core.navigation.persona',
+ icon: 'mdi-heart',
+ to: '/persona'
+ },
{
title: 'core.navigation.console',
-
icon: 'mdi-console',
to: '/console'
},
diff --git a/dashboard/src/router/MainRoutes.ts b/dashboard/src/router/MainRoutes.ts
index 8beff1f26..9706f2b73 100644
--- a/dashboard/src/router/MainRoutes.ts
+++ b/dashboard/src/router/MainRoutes.ts
@@ -56,6 +56,11 @@ const MainRoutes = {
path: '/session-management',
component: () => import('@/views/SessionManagementPage.vue')
},
+ {
+ name: 'Persona',
+ path: '/persona',
+ component: () => import('@/views/PersonaPage.vue')
+ },
{
name: 'Console',
path: '/console',
diff --git a/dashboard/src/views/PersonaPage.vue b/dashboard/src/views/PersonaPage.vue
new file mode 100644
index 000000000..3094a5b09
--- /dev/null
+++ b/dashboard/src/views/PersonaPage.vue
@@ -0,0 +1,808 @@
+
+
+
+
+
+
+
+ mdi-heart{{ t('core.navigation.persona') }}
+
+
+ {{ tm('page.description') }}
+
+
+
+
+ {{ tm('buttons.create') }}
+
+
+
+
+
+
+
+
+
+
+
+ {{ persona.persona_id }}
+
+
+
+
+
+
+
+
+ mdi-pencil
+ {{ tm('buttons.edit') }}
+
+
+
+
+ mdi-delete
+ {{ tm('buttons.delete') }}
+
+
+
+
+
+
+
+
+ {{ truncateText(persona.system_prompt, 100) }}
+
+
+
+
+ {{ tm('labels.presetDialogs', { count: persona.begin_dialogs.length / 2 }) }}
+
+
+
+
+ {{ tm('labels.createdAt') }}: {{ formatDate(persona.created_at) }}
+
+
+
+
+
+
+
+
+ mdi-account-group
+ {{ tm('empty.title') }}
+ {{ tm('empty.description') }}
+
+ {{ tm('buttons.createFirst') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ editingPersona ? tm('dialog.edit.title') : tm('dialog.create.title') }}
+
+
+
+
+
+
+
+
+
+
+
+
+ mdi-tools
+ {{ tm('form.tools') }}
+
+ {{ personaForm.tools.length }}
+
+
+
+
+
+
+ {{ tm('form.toolsHelp') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
{{ tm('form.mcpServersQuickSelect') }}
+
+
+ mdi-server
+ {{ server.name }}
+
+ ({{ server.tools.length }})
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ item.name }}
+
+ {{ item.mcp_server_name }}
+
+
+
+
+ {{ truncateText(item.description, 100) }}
+
+
+
+
+
+
+
+
mdi-tools
+
{{ tm('form.noToolsAvailable')
+ }}
+
+
+
+
+
mdi-magnify
+
{{ tm('form.noToolsFound') }}
+
+
+
+
+
+
+
{{ tm('form.loadingTools')
+ }}
+
+
+
+
+
+
+ {{ tm('form.selectedTools') }}
+
+ ({{ tm('form.allSelected') }})
+
+
+ ({{ personaForm.tools.length }})
+
+
+
+
+ {{ toolName }}
+
+
+
+ {{ tm('form.noToolsSelected') }}
+
+
+
+
+
+
+
+
+
+
+ mdi-chat
+ {{ tm('form.presetDialogs') }}
+
+ {{ personaForm.begin_dialogs.length / 2 }}
+
+
+
+
+
+
+ {{ tm('form.presetDialogsHelp') }}
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ tm('buttons.addDialogPair') }}
+
+
+
+
+
+
+
+
+
+
+ {{ tm('buttons.cancel') }}
+
+
+ {{ tm('buttons.save') }}
+
+
+
+
+
+
+
+
+
+ {{ viewingPersona.persona_id }}
+
+
+
+
+
+
{{ tm('form.systemPrompt') }}
+
+ {{ viewingPersona.system_prompt }}
+
+
+
+
+
{{ tm('form.presetDialogs') }}
+
+
+ {{ index % 2 === 0 ? tm('form.userMessage') : tm('form.assistantMessage') }}
+
+
+ {{ dialog }}
+
+
+
+
+
+
{{ tm('form.tools') }}
+
+
+ {{ tm('form.allToolsAvailable') }}
+
+
+
+
+ {{ toolName }}
+
+
+
+ {{ tm('form.noToolsSelected') }}
+
+
+
+
+
{{ tm('labels.createdAt') }}: {{ formatDate(viewingPersona.created_at) }}
+
{{ tm('labels.updatedAt') }}: {{
+ formatDate(viewingPersona.updated_at) }}
+
+
+
+
+
+
+
+ {{ message }}
+
+
+
+
+
+
+
diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue
index 6060088a4..c89634153 100644
--- a/dashboard/src/views/ToolUsePage.vue
+++ b/dashboard/src/views/ToolUsePage.vue
@@ -405,24 +405,36 @@
+ 复选框代表该工具是否被启用。
+
+
+
+
- {{ tool.function.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
+ {{ tool.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
- {{ formatToolName(tool.function.name) }}
+ :title="tool.name">
+ {{ formatToolName(tool.name) }}
-
- {{ tool.function.description }}
+
+ {{ tool.description }}
@@ -434,9 +446,9 @@
mdi-information
{{ tm('functionTools.description') }}
- {{ tool.function.description }}
+ {{ tool.description }}
-
+
mdi-code-json
{{ tm('functionTools.parameters') }}
@@ -451,7 +463,7 @@
-
+
| {{ paramName }} |
@@ -562,8 +574,8 @@ export default {
const searchTerm = this.toolSearch.toLowerCase();
return this.tools.filter(tool =>
- tool.function.name.toLowerCase().includes(searchTerm) ||
- tool.function.description.toLowerCase().includes(searchTerm)
+ tool.name.toLowerCase().includes(searchTerm) ||
+ tool.description.toLowerCase().includes(searchTerm)
);
},
@@ -658,7 +670,7 @@ export default {
},
getTools() {
- axios.get('/api/config/llmtools')
+ axios.get('/api/tools/list')
.then(response => {
this.tools = response.data.data || [];
})
@@ -976,6 +988,28 @@ export default {
} catch (e) {
this.showError(this.tm('messages.importError.failed', { error: e.message }));
}
+ },
+
+ // 切换工具状态
+ async toggleToolStatus(tool) {
+ try {
+ const response = await axios.post('/api/tools/toggle-tool', {
+ name: tool.name,
+ activate: tool.active
+ });
+
+ if (response.data.status === 'ok') {
+ this.showSuccess(response.data.message || this.tm('messages.toggleToolSuccess'));
+ } else {
+ // 如果失败,恢复原状态
+ tool.active = !tool.active;
+ this.showError(response.data.message || this.tm('messages.toggleToolError'));
+ }
+ } catch (error) {
+ // 如果失败,恢复原状态
+ tool.active = !tool.active;
+ this.showError(this.tm('messages.toggleToolError', { error: error.response?.data?.message || error.message }));
+ }
}
}
}
diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py
index 8a73945f5..b3b4954d3 100644
--- a/packages/astrbot/main.py
+++ b/packages/astrbot/main.py
@@ -26,6 +26,7 @@ from .long_term_memory import LongTermMemory
from astrbot.core import logger
from astrbot.api.message_components import Plain, Image, Reply
from astrbot.core.star.session_llm_manager import SessionServiceManager
+from astrbot.core.provider.func_tool_manager import ToolSet
from typing import Union
from enum import Enum
@@ -128,7 +129,6 @@ class Main(star.Star):
/reset: 重置 LLM 会话
/history: 当前对话的对话记录
/persona: 人格情景(op)
-/tool ls: 函数工具
/key: API Key(op)
/websearch: 网页搜索
{notice}"""
@@ -157,50 +157,22 @@ class Main(star.Star):
@tool.command("ls")
async def tool_ls(self, event: AstrMessageEvent):
"""查看函数工具列表"""
- tm = self.context.get_llm_tool_manager()
- msg = "函数工具:\n"
- for tool in tm.func_list:
- active = " (启用)" if tool.active else "(停用)"
- msg += f"- {tool.name}: {tool.description} {active}\n"
-
- msg += "\n使用 /tool on/off <工具名> 激活或者停用函数工具。/tool off_all 停用所有函数工具。"
- event.set_result(MessageEventResult().message(msg).use_t2i(False))
+ event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
@tool.command("on")
async def tool_on(self, event: AstrMessageEvent, tool_name: str):
"""启用一个函数工具"""
- if self.context.activate_llm_tool(tool_name):
- event.set_result(
- MessageEventResult().message(f"激活工具 {tool_name} 成功。")
- )
- else:
- event.set_result(
- MessageEventResult().message(
- f"激活工具 {tool_name} 失败,未找到此工具。"
- )
- )
+ event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
@tool.command("off")
async def tool_off(self, event: AstrMessageEvent, tool_name: str):
"""停用一个函数工具"""
- if self.context.deactivate_llm_tool(tool_name):
- event.set_result(
- MessageEventResult().message(f"停用工具 {tool_name} 成功。")
- )
- else:
- event.set_result(
- MessageEventResult().message(
- f"停用工具 {tool_name} 失败,未找到此工具。"
- )
- )
+ event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
@tool.command("off_all")
async def tool_all_off(self, event: AstrMessageEvent):
"""停用所有函数工具"""
- tm = self.context.get_llm_tool_manager()
- for tool in tm.func_list:
- self.context.deactivate_llm_tool(tool.name)
- event.set_result(MessageEventResult().message("停用所有工具成功。"))
+ event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
@filter.command_group("plugin")
def plugin(self):
@@ -1264,26 +1236,34 @@ UID: {user_id} 此 ID 可用于设置管理员。
if req.conversation:
persona_id = req.conversation.persona_id
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
- persona_id = self.context.provider_manager.selected_default_persona[
+ persona_id = self.context.persona_manager.selected_default_persona_v3[
"name"
]
persona = next(
builtins.filter(
lambda persona: persona["name"] == persona_id,
- self.context.provider_manager.personas,
+ self.context.persona_manager.personas_v3,
),
None,
)
if persona:
if prompt := persona["prompt"]:
req.system_prompt += prompt
- if mood_dialogs := persona["_mood_imitation_dialogs_processed"]:
- req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
- req.system_prompt += mood_dialogs
- if (
- begin_dialogs := persona["_begin_dialogs_processed"]
- ) and not req.contexts:
+ if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
+ # tools select
+ tmgr = self.context.get_llm_tool_manager()
+ if (persona and persona.get("tools") is None) or not persona:
+ # select all
+ toolset = tmgr.get_full_tool_set()
+ else:
+ toolset = ToolSet()
+ for tool_name in persona["tools"]:
+ tool = tmgr.get_func(tool_name)
+ if tool:
+ toolset.add_tool(tool)
+ req.func_tool = toolset
+ logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}")
if quote:
sender_info = ""
|