Improve: 引入全新的人格管理模式以及重构函数工具管理器 (#2305)

* feat: add persona management

* refactor:  重构函数工具管理器,引入 ToolSet,并让 Persona 支持绑定 Tools

* feat: 更新 Persona 工具选择逻辑,支持全选和指定工具的切换

* feat: 更新 BaseDatabase 中的 persona 方法返回类型,支持返回 None
This commit is contained in:
Soulter
2025-08-04 00:56:26 +08:00
committed by GitHub
parent 87f05fce66
commit b1e3018b6b
34 changed files with 2112 additions and 580 deletions
+1 -39
View File
@@ -115,8 +115,7 @@ DEFAULT_CONFIG = {
"log_level": "INFO",
"pip_install_arg": "",
"pypi_index_url": "https://mirrors.aliyun.com/pypi/simple/",
"knowledge_db": {},
"persona": [],
"persona": [], # deprecated
"timezone": "",
"callback_api_base": "",
}
@@ -1701,43 +1700,6 @@ CONFIG_METADATA_2 = {
},
},
},
"persona": {
"description": "人格情景设置",
"type": "list",
"config_template": {
"新人格情景": {
"name": "",
"prompt": "",
"begin_dialogs": [],
"mood_imitation_dialogs": [],
}
},
"tmpl_display_title": "name",
"items": {
"name": {
"description": "人格名称",
"type": "string",
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
},
"prompt": {
"description": "设定(系统提示词)",
"type": "text",
"hint": "填写人格的身份背景、性格特征、兴趣爱好、个人经历、口头禅等。",
},
"begin_dialogs": {
"description": "预设对话",
"type": "list",
"items": {"type": "string"},
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
},
"mood_imitation_dialogs": {
"description": "对话风格模仿",
"type": "list",
"items": {"type": "string"},
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
},
},
},
"provider_stt_settings": {
"description": "语音转文本(STT)",
"type": "object",
+20 -3
View File
@@ -22,6 +22,7 @@ from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.star.context import Context
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core import LogBroker
from astrbot.core.db import BaseDatabase
@@ -69,13 +70,23 @@ class AstrBotCoreLifecycle:
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
await self.db.initialize()
await do_migration_v4(self.db, {})
try:
await do_migration_v4(self.db, {}, self.astrbot_config)
except Exception as e:
logger.error(f"迁移到 v4.0.0 新版本数据格式失败: {e}")
# 初始化事件队列
self.event_queue = Queue()
# 初始化人格管理器
self.persona_mgr = PersonaManager(self.db, self.astrbot_config)
await self.persona_mgr.initialize()
# 初始化供应商管理器
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
self.provider_manager = ProviderManager(
self.astrbot_config, self.db, self.persona_mgr
)
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
@@ -94,6 +105,8 @@ class AstrBotCoreLifecycle:
self.provider_manager,
self.platform_manager,
self.conversation_manager,
self.platform_message_history_manager,
self.persona_mgr,
)
# 初始化插件管理器
@@ -110,6 +123,7 @@ class AstrBotCoreLifecycle:
PipelineContext(self.astrbot_config, self.plugin_manager)
)
await self.pipeline_scheduler.initialize()
self.star_context.pipeline_ctx = self.pipeline_scheduler.ctx
# 初始化更新器
self.astrbot_updator = AstrBotUpdator()
@@ -232,6 +246,9 @@ class AstrBotCoreLifecycle:
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
tasks.append(
asyncio.create_task(platform_inst.run(), name=f"{platform_inst.meta().id}({platform_inst.meta().name})")
asyncio.create_task(
platform_inst.run(),
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
)
)
return tasks
+17
View File
@@ -205,6 +205,7 @@ class BaseDatabase(abc.ABC):
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
tools: list[str] = None,
) -> Persona:
"""Insert a new persona record."""
...
@@ -219,6 +220,22 @@ class BaseDatabase(abc.ABC):
"""Get all personas for a specific bot."""
...
@abc.abstractmethod
async def update_persona(
self,
persona_id: str,
system_prompt: str = None,
begin_dialogs: list[str] = None,
tools: list[str] = None,
) -> Persona | None:
"""Update a persona's system prompt or begin dialogs."""
...
@abc.abstractmethod
async def delete_persona(self, persona_id: str) -> None:
"""Delete a persona by its ID."""
...
@abc.abstractmethod
async def insert_preference_or_update(self, key: str, value: str) -> Preference:
"""Insert a new preference record."""
+8 -1
View File
@@ -1,11 +1,13 @@
import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.db import BaseDatabase
from astrbot.core.config import AstrBotConfig
from astrbot.api import logger
from .migra_3_to_4 import (
migration_conversation_table,
migration_platform_table,
migration_webchat_data,
migration_persona_data,
)
@@ -24,7 +26,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
async def do_migration_v4(
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
astrbot_config: AstrBotConfig,
):
"""
执行数据库迁移
@@ -45,6 +49,9 @@ async def do_migration_v4(
# 执行 WebChat 数据迁移
await migration_webchat_data(db_helper, platform_id_map)
# 执行人格数据迁移
await migration_persona_data(db_helper, astrbot_config)
# 标记迁移完成
await db_helper.insert_preference_or_update("migration_done_v4", "true")
+68 -22
View File
@@ -4,12 +4,10 @@ from .. import BaseDatabase
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
from astrbot.core.config.default import DB_PATH
from astrbot.api import logger
from astrbot.core.config import AstrBotConfig
from astrbot.core.platform.astr_message_event import MessageSesion
from sqlalchemy.ext.asyncio import AsyncSession
from astrbot.core.db.po import (
ConversationV2,
PlatformMessageHistory,
)
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
from sqlalchemy import text
"""
@@ -50,20 +48,21 @@ async def migration_conversation_table(
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for conversation in conversations:
for idx, conversation in enumerate(conversations):
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
progress = int((idx + 1) / total_cnt * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
logger.warning(
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" not in conv.user_id:
logger.warning(
f"跳过 user_id 为 {conv.user_id} 的会话,它可能是 WebChat 的消息历史记录。"
)
continue
session = MessageSesion.from_str(session_str=conv.user_id)
platform_id = get_platform_id(
@@ -81,13 +80,12 @@ async def migration_conversation_table(
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
)
dbsession.add(conv_v2)
if conv_v2:
logger.info(f"迁移旧会话 {conv.cid} 到新表成功。")
except Exception as e:
logger.error(
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
exc_info=True,
)
logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
async def migration_platform_table(
@@ -107,7 +105,7 @@ async def migration_platform_table(
platform_stats_v3 = stats.platform
if not platform_stats_v3:
logger.warning("没有找到旧平台数据,跳过迁移。")
logger.info("没有找到旧平台数据,跳过迁移。")
return
first_time_stamp = platform_stats_v3[0].timestamp
@@ -120,7 +118,13 @@ async def migration_platform_table(
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for bucket_end in range(start_time, end_time, 3600):
total_buckets = (end_time - start_time) // 3600
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
if bucket_idx % 500 == 0:
progress = int((bucket_idx + 1) / total_buckets * 100)
logger.info(
f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})"
)
cnt = 0
while (
idx < len(platform_stats_v3)
@@ -136,9 +140,6 @@ async def migration_platform_table(
platform_type = get_platform_type(
platform_id_map, platform_stats_v3[idx].name
)
logger.info(
f"迁移平台统计数据: {platform_id}, {platform_type}, 时间戳: {bucket_end}, 计数: {cnt}"
)
try:
await dbsession.execute(
text("""
@@ -161,6 +162,7 @@ async def migration_platform_table(
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
exc_info=True,
)
logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
async def migration_webchat_data(
@@ -178,20 +180,21 @@ async def migration_webchat_data(
async with db_helper.get_db() as dbsession:
dbsession: AsyncSession
async with dbsession.begin():
for conversation in conversations:
for idx, conversation in enumerate(conversations):
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
progress = int((idx + 1) / total_cnt * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
try:
conv = db_helper_v3.get_conversation_by_user_id(
user_id=conversation.get("user_id", "unknown"),
cid=conversation.get("cid", "unknown"),
)
if not conv:
logger.warning(
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
)
if ":" in conv.user_id:
logger.warning(
f"跳过 user_id 为 {conv.user_id} 的会话,它不是 WebChat 的消息历史记录。"
)
continue
platform_id = "webchat"
history = json.loads(conv.history) if conv.history else []
@@ -206,9 +209,52 @@ async def migration_webchat_data(
)
dbsession.add(new_history)
logger.info(f"迁移旧 WebChat 会话 {conv.cid} 到新表成功。")
except Exception:
logger.error(
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
exc_info=True,
)
logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
async def migration_persona_data(
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
):
"""
迁移 Persona 数据到新的表中。
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
"""
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
total_personas = len(v3_persona_config)
logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
for idx, persona in enumerate(v3_persona_config):
if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
progress = int((idx + 1) / total_personas * 100)
if progress % 10 == 0:
logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
try:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
mood_prompt = ""
user_turn = True
for mood_dialog in mood_imitation_dialogs:
if user_turn:
mood_prompt += f"A: {mood_dialog}\n"
else:
mood_prompt += f"B: {mood_dialog}\n"
user_turn = not user_turn
system_prompt = persona.get("prompt", "")
if mood_prompt:
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
persona_new = await db_helper.insert_persona(
persona_id=persona["name"],
system_prompt=system_prompt,
begin_dialogs=begin_dialogs,
)
logger.info(
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
+33 -6
View File
@@ -9,7 +9,7 @@ from sqlmodel import (
UniqueConstraint,
Field,
)
from typing import Optional
from typing import Optional, TypedDict
class PlatformStat(SQLModel, table=True):
@@ -39,12 +39,14 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations"
inner_conversation_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
inner_conversation_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
)
conversation_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4())
default_factory=lambda: str(uuid.uuid4()),
)
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False)
@@ -78,6 +80,8 @@ class Persona(SQLModel, table=True):
system_prompt: str = Field(sa_type=Text, nullable=False)
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
"""a list of strings, each representing a dialog to start with"""
tools: Optional[list] = Field(default=None, sa_type=JSON)
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
@@ -119,7 +123,9 @@ class PlatformMessageHistory(SQLModel, table=True):
platform_id: str = Field(nullable=False)
user_id: str = Field(nullable=False) # An id of group, user in platform
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
sender_name: Optional[str] = Field(default=None) # Name of the sender in the platform
sender_name: Optional[str] = Field(
default=None
) # Name of the sender in the platform
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
@@ -136,12 +142,14 @@ class Attachment(SQLModel, table=True):
__tablename__ = "attachments"
inner_attachment_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
inner_attachment_id: int = Field(
primary_key=True, sa_column_kwargs={"autoincrement": True}
)
attachment_id: str = Field(
max_length=36,
nullable=False,
unique=True,
default_factory=lambda: str(uuid.uuid4())
default_factory=lambda: str(uuid.uuid4()),
)
path: str = Field(nullable=False) # Path to the file on disk
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
@@ -182,6 +190,25 @@ class Conversation:
updated_at: int = 0
class Personality(TypedDict):
"""LLM 人格类。
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
"""
prompt: str = ""
name: str = ""
begin_dialogs: list[str] = []
mood_imitation_dialogs: list[str] = []
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache
_begin_dialogs_processed: list[dict] = []
_mood_imitation_dialogs_processed: str = ""
# ====
# Deprecated, and will be removed in future versions.
# ====
+37 -5
View File
@@ -19,6 +19,7 @@ from sqlalchemy import select, update, delete, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import func
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
class SQLiteDatabase(BaseDatabase):
def __init__(self, db_path: str) -> None:
@@ -99,9 +100,7 @@ class SQLiteDatabase(BaseDatabase):
# Conversation Management
# ====
async def get_conversations(
self, user_id=None, platform_id=None
):
async def get_conversations(self, user_id=None, platform_id=None):
async with self.get_db() as session:
session: AsyncSession
query = select(ConversationV2)
@@ -224,7 +223,7 @@ class SQLiteDatabase(BaseDatabase):
return
query = query.values(**values)
await session.execute(query)
return await self.get_conversation_by_id(cid)
return await self.get_conversation_by_id(cid)
async def delete_conversation(self, cid):
async with self.get_db() as session:
@@ -312,7 +311,9 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalar_one_or_none()
async def insert_persona(self, persona_id, system_prompt, begin_dialogs=None):
async def insert_persona(
self, persona_id, system_prompt, begin_dialogs=None, tools=None
):
"""Insert a new persona record."""
async with self.get_db() as session:
session: AsyncSession
@@ -321,6 +322,7 @@ class SQLiteDatabase(BaseDatabase):
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs or [],
tools=tools,
)
session.add(new_persona)
return new_persona
@@ -341,6 +343,36 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalars().all()
async def update_persona(
self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN
):
"""Update a persona's system prompt or begin dialogs."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = update(Persona).where(Persona.persona_id == persona_id)
values = {}
if system_prompt is not None:
values["system_prompt"] = system_prompt
if begin_dialogs is not None:
values["begin_dialogs"] = begin_dialogs
if tools is not NOT_GIVEN:
values["tools"] = tools
if not values:
return
query = query.values(**values)
await session.execute(query)
return await self.get_persona_by_id(persona_id)
async def delete_persona(self, persona_id):
"""Delete a persona by its ID."""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
await session.execute(
delete(Persona).where(Persona.persona_id == persona_id)
)
async def insert_preference_or_update(self, key, value):
"""Insert a new preference record or update if it exists."""
async with self.get_db() as session:
+162
View File
@@ -0,0 +1,162 @@
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Persona, Personality
from astrbot.core.config import AstrBotConfig
from astrbot import logger
class PersonaManager:
def __init__(self, db_helper: BaseDatabase, astrbot_config: AstrBotConfig):
self.db = db_helper
self.config = astrbot_config
_ps: dict = astrbot_config["provider_settings"]
self.default_persona: str = _ps.get("default_personality", "default")
self.personas: list[Persona] = []
self.selected_default_persona: Persona | None = None
self.personas_v3: list[Personality] = []
self.selected_default_persona_v3: Personality | None = None
self.persona_v3_config: list[dict] = []
async def initialize(self):
self.personas = await self.get_all_personas()
self.get_v3_persona_data()
logger.info(f"已加载 {len(self.personas)} 个人格。")
async def get_persona(self, persona_id: str):
"""获取指定 persona 的信息"""
persona = await self.db.get_persona_by_id(persona_id)
if not persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
return persona
async def delete_persona(self, persona_id: str):
"""删除指定 persona"""
if not await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} does not exist.")
await self.db.delete_persona(persona_id)
self.personas = [p for p in self.personas if p.persona_id != persona_id]
self.get_v3_persona_data()
async def update_persona(
self,
persona_id: str,
system_prompt: str = None,
begin_dialogs: list[str] = None,
tools: list[str] = None,
):
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
existing_persona = await self.db.get_persona_by_id(persona_id)
if not existing_persona:
raise ValueError(f"Persona with ID {persona_id} does not exist.")
persona = await self.db.update_persona(
persona_id, system_prompt, begin_dialogs, tools=tools
)
if persona:
for i, p in enumerate(self.personas):
if p.persona_id == persona_id:
self.personas[i] = persona
break
self.get_v3_persona_data()
return persona
async def get_all_personas(self) -> list[Persona]:
"""获取所有 personas"""
return await self.db.get_personas()
async def create_persona(
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
tools: list[str] = None,
) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} already exists.")
new_persona = await self.db.insert_persona(
persona_id, system_prompt, begin_dialogs, tools=tools
)
self.personas.append(new_persona)
self.get_v3_persona_data()
return new_persona
def get_v3_persona_data(
self,
) -> tuple[list[dict], list[Personality], Personality]:
"""获取 AstrBot <4.0.0 版本的 persona 数据。
Returns:
- list[dict]: 包含 persona 配置的字典列表。
- list[Personality]: 包含 Personality 对象的列表。
- Personality: 默认选择的 Personality 对象。
"""
v3_persona_config = [
{
"prompt": persona.system_prompt,
"name": persona.persona_id,
"begin_dialogs": persona.begin_dialogs or [],
"mood_imitation_dialogs": [], # deprecated
"tools": persona.tools,
}
for persona in self.personas
]
personas_v3: list[Personality] = []
selected_default_persona: Personality | None = None
for persona_cfg in v3_persona_config:
begin_dialogs = persona_cfg.get("begin_dialogs", [])
bd_processed = []
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
)
begin_dialogs = []
user_turn = True
for dialog in begin_dialogs:
bd_processed.append(
{
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None, # 不持久化到 db
}
)
user_turn = not user_turn
try:
persona = Personality(
**persona_cfg,
_begin_dialogs_processed=bd_processed,
_mood_imitation_dialogs_processed="", # deprecated
)
if persona["name"] == self.default_persona:
selected_default_persona = persona
personas_v3.append(persona)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not selected_default_persona and len(personas_v3) > 0:
# 默认选择第一个
selected_default_persona = personas_v3[0]
if not selected_default_persona:
selected_default_persona = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
tools=None,
_begin_dialogs_processed=[],
)
personas_v3.append(selected_default_persona)
self.personas_v3 = personas_v3
self.selected_default_persona_v3 = selected_default_persona
self.persona_v3_config = v3_persona_config
self.selected_default_persona = Persona(
persona_id=selected_default_persona["name"],
system_prompt=selected_default_persona["prompt"],
begin_dialogs=selected_default_persona["begin_dialogs"],
tools=selected_default_persona["tools"] or None,
)
return v3_persona_config, personas_v3, selected_default_persona
+8 -2
View File
@@ -55,7 +55,7 @@ class PipelineContext:
handler: T.Awaitable,
*args,
**kwargs,
) -> T.AsyncGenerator[None, None]:
) -> T.AsyncGenerator[T.Any, None]:
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
@@ -77,10 +77,16 @@ class PipelineContext:
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as _:
if self.plugin_manager:
context = self.plugin_manager.context
else:
raise ValueError(
"Cannot call handler without a valid context or plugin manager."
)
# 向下兼容
trace_ = traceback.format_exc()
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
ready_to_call = handler(event, context, *args, **kwargs)
if inspect.isasyncgen(ready_to_call):
_has_yielded = False
@@ -21,6 +21,7 @@ from mcp.types import (
EmbeddedResource,
TextResourceContents,
BlobResourceContents,
CallToolResult,
)
from astrbot.core.star.star_handler import EventType
from astrbot import logger
@@ -193,50 +194,25 @@ class ToolLoopAgent(BaseAgentRunner):
if not req.func_tool:
return
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
res = await client.session.call_tool(func_tool.name, func_tool_args)
if not res:
continue
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
executor = func_tool.execute(
event=self.event,
pipeline_context=self.pipeline_ctx,
**func_tool_args,
)
async for resp in executor:
if isinstance(resp, CallToolResult):
res = resp
if isinstance(res.content[0], TextContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
content=res.content[0].text,
)
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
@@ -247,41 +223,54 @@ class ToolLoopAgent(BaseAgentRunner):
yield MessageChain(type="tool_direct_result").base64_image(
res.content[0].data
)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
yield MessageChain().message("返回的数据类型不受支持。")
else:
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
# 尝试调用工具函数
wrapper = self.pipeline_ctx.call_handler(
self.event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None:
# Tool 返回结果
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
yield MessageChain().message(resp)
else:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
elif isinstance(res.content[0], EmbeddedResource):
resource = res.content[0].resource
if isinstance(resource, TextResourceContents):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resource.text,
)
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
and resource.mimeType.startswith("image/")
):
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回了图片(已直接发送给用户)",
)
)
yield MessageChain(
type="tool_direct_result"
).base64_image(res.content[0].data)
else:
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="返回的数据类型不受支持",
)
)
yield MessageChain().message("返回的数据类型不受支持。")
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
self._transition_state(AgentState.DONE)
if res := self.event.get_result():
if res.chain:
yield MessageChain(
chain=res.chain, type="tool_direct_result"
)
else:
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
)
self.event.clear_result()
except Exception as e:
@@ -100,7 +100,8 @@ class LLMRequestSubStage(Stage):
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
@@ -274,7 +275,6 @@ class LLMRequestSubStage(Stage):
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
async def _handle_webchat(
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
):
@@ -307,19 +307,10 @@ class LLMRequestSubStage(Stage):
if not title or "<None>" in title:
return
await self.conv_manager.update_conversation_title(
event.unified_msg_origin, title=title
unified_msg_origin=event.unified_msg_origin,
title=title,
conversation_id=req.conversation.cid,
)
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
# webchat adapter 中,session_id 的格式是 f"webchat!{username}!{cid}"
# TODO: 优化 WebChat 适配器的对话管理
if event.session_id:
username, cid = event.session_id.split("!")[1:3]
db_helper = self.ctx.plugin_manager.context._db
db_helper.update_conversation_title(
user_id=username,
cid=cid,
title=title,
)
async def _save_to_history(
self,
+2 -2
View File
@@ -5,7 +5,7 @@ from astrbot.core.utils.io import download_image_by_url
from astrbot import logger
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from .func_tool_manager import FunctionToolManager, ToolSet
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
@@ -97,7 +97,7 @@ class ProviderRequest:
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
func_tool: FuncCall | None = None
func_tool: FunctionToolManager | ToolSet | None = None
"""可用的函数工具"""
contexts: list[dict] = field(default_factory=list)
"""上下文。格式与 openai 的上下文格式一致:
+354 -231
View File
@@ -1,16 +1,17 @@
from __future__ import annotations
import json
import textwrap
import os
import asyncio
import logging
from datetime import timedelta
from deprecated import deprecated
from typing import Dict, List, Awaitable, Literal, Any
from typing import Dict, List, Awaitable, Literal, Any, AsyncGenerator
from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
from astrbot.core import sp
from astrbot.core.utils.log_pipe import LogPipe
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
@@ -28,6 +29,13 @@ except (ModuleNotFoundError, ImportError):
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
)
from typing_extensions import TYPE_CHECKING
if TYPE_CHECKING:
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.pipeline.context import PipelineContext
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
SUPPORTED_TYPES = [
@@ -39,6 +47,302 @@ SUPPORTED_TYPES = [
] # json schema 支持的数据类型
@dataclass
class FunctionTool:
"""A class representing a function tool that can be used in function calling."""
name: str
parameters: Dict
description: str
handler: Awaitable = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装导致 handler __module__ functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
async def execute(
self,
event: AstrMessageEvent = None,
pipeline_context: "PipelineContext" = None,
**tool_args,
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]:
"""执行函数调用。
Args:
event (AstrMessageEvent): 事件对象, origin local 时必须提供
pipeline_context (PipelineContext): 流水线调度器上下文, origin local 时必须提供
**kwargs: 函数调用的参数
Returns:
AsyncGenerator[None | mcp.types.CallToolResult, None]
"""
if self.origin == "local":
if not event:
raise ValueError("Event must be provided for local function tools.")
wrapper = pipeline_context.call_handler(
event=event,
handler=self.handler,
**tool_args,
)
async for resp in wrapper:
if resp is not None:
text_content = mcp.types.TextContent(
type="text",
text=str(resp),
)
yield mcp.types.CallToolResult(content=[text_content])
else:
# NOTE: Tool 在这里直接请求发送消息给用户
# TODO: 是否需要判断 event.get_result() 是否为空?
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
yield None
elif self.origin == "mcp":
if not self.mcp_client:
raise ValueError("MCP client is not available for MCP function tools.")
res = await self.mcp_client.session.call_tool(
name=self.name,
arguments=tool_args,
)
if not res:
return
yield res
else:
raise Exception(f"Unknown function origin: {self.origin}")
def __dict__(self) -> dict[str, Any]:
"""将 FunctionTool 转换为字典格式"""
return {
"name": self.name,
"parameters": self.parameters,
"description": self.description,
"active": self.active,
"origin": self.origin,
"mcp_server_name": self.mcp_server_name,
}
# alias for FunctionTool
FuncTool = FunctionTool
class ToolSet:
"""A set of function tools that can be used in function calling.
This class provides methods to add, remove, and retrieve tools, as well as
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
def __init__(self, tools: List[FunctionTool] = None):
self.tools: List[FunctionTool] = tools or []
def empty(self) -> bool:
"""Check if the tool set is empty."""
return len(self.tools) == 0
def add_tool(self, tool: FunctionTool):
"""Add a tool to the set."""
# 检查是否已存在同名工具
for i, existing_tool in enumerate(self.tools):
if existing_tool.name == tool.name:
self.tools[i] = tool
return
self.tools.append(tool)
def remove_tool(self, name: str):
"""Remove a tool by its name."""
self.tools = [tool for tool in self.tools if tool.name != name]
def get_tool(self, name: str) -> Optional[FunctionTool]:
"""Get a tool by its name."""
for tool in self.tools:
if tool.name == name:
return tool
return None
@deprecated(reason="Use add_tool() instead", version="4.0.0")
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
"properties": {},
}
for param in func_args:
params["properties"][param["name"]] = {
"type": param["type"],
"description": param["description"],
}
_func = FunctionTool(
name=name,
parameters=params,
description=desc,
handler=handler,
)
self.add_tool(_func)
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
def remove_func(self, name: str):
"""Remove a function tool by its name."""
self.remove_tool(name)
@deprecated(reason="Use get_tool() instead", version="4.0.0")
def get_func(self, name: str) -> List[FunctionTool]:
"""Get all function tools."""
return self.get_tool(name)
def openai_schema(self, omit_empty_parameters: bool = False) -> List[Dict]:
"""Convert tools to OpenAI API function calling schema format."""
result = []
for tool in self.tools:
func_def = {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
},
}
if tool.parameters.get("properties") or not omit_empty_parameters:
func_def["function"]["parameters"] = tool.parameters
result.append(func_def)
return result
def anthropic_schema(self) -> List[Dict]:
"""Convert tools to Anthropic API format."""
result = []
for tool in self.tools:
tool_def = {
"name": tool.name,
"description": tool.description,
"input_schema": {
"type": "object",
"properties": tool.parameters.get("properties", {}),
"required": tool.parameters.get("required", []),
},
}
result.append(tool_def)
return result
def google_schema(self) -> Dict:
"""Convert tools to Google GenAI API format."""
def convert_schema(schema: dict) -> dict:
"""Convert schema to Gemini API format."""
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
result = {}
if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get(
result["type"], set()
):
result["format"] = schema["format"]
else:
result["type"] = "null"
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
properties[key] = prop_value
if properties:
result["properties"] = properties
if "items" in schema:
result["items"] = convert_schema(schema["items"])
return result
tools = [
{
"name": tool.name,
"description": tool.description,
"parameters": convert_schema(tool.parameters),
}
for tool in self.tools
]
declarations = {}
if tools:
declarations["function_declarations"] = tools
return declarations
@deprecated(reason="Use openai_schema() instead", version="4.0.0")
def get_func_desc_openai_style(self, omit_empty_parameters: bool = False):
return self.openai_schema(omit_empty_parameters)
@deprecated(reason="Use anthropic_schema() instead", version="4.0.0")
def get_func_desc_anthropic_style(self):
return self.anthropic_schema()
@deprecated(reason="Use google_schema() instead", version="4.0.0")
def get_func_desc_google_genai_style(self):
return self.google_schema()
def names(self) -> List[str]:
"""获取所有工具的名称列表"""
return [tool.name for tool in self.tools]
def __len__(self):
return len(self.tools)
def __bool__(self):
return len(self.tools) > 0
def __iter__(self):
return iter(self.tools)
def _prepare_config(config: dict) -> dict:
"""准备配置,处理嵌套格式"""
if "mcpServers" in config and config["mcpServers"]:
@@ -105,55 +409,6 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
return False, f"{e!s}"
@dataclass
class FuncTool:
"""
用于描述一个函数调用工具
"""
name: str
parameters: Dict
description: str
handler: Awaitable = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装导致 handler __module__ functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
async def execute(self, **args) -> Any:
"""执行函数调用"""
if self.origin == "local":
if not self.handler:
raise Exception(f"Local function {self.name} has no handler")
return await self.handler(**args)
elif self.origin == "mcp":
if not self.mcp_client or not self.mcp_client.session:
raise Exception(f"MCP client for {self.name} is not available")
# 使用name属性而不是额外的mcp_tool_name
actual_tool_name = (
self.name.split(":")[-1] if ":" in self.name else self.name
)
return await self.mcp_client.session.call_tool(actual_tool_name, args)
else:
raise Exception(f"Unknown function origin: {self.origin}")
class MCPClient:
def __init__(self):
# Initialize session and client objects
@@ -276,10 +531,9 @@ class MCPClient:
self.running_event.set() # Set the running event to indicate cleanup is done
class FuncCall:
class FunctionToolManager:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
"""内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
@@ -331,11 +585,15 @@ class FuncCall:
self.func_list.pop(i)
break
def get_func(self, name) -> FuncTool:
def get_func(self, name) -> FuncTool | None:
for f in self.func_list:
if f.name == name:
return f
return None
def get_full_tool_set(self) -> ToolSet:
"""获取完整工具集"""
tool_set = ToolSet(self.func_list.copy())
return tool_set
async def init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
@@ -556,203 +814,68 @@ class FuncCall:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
"""
_l = []
# 处理所有工具(包括本地和MCP工具)
for f in self.func_list:
if not f.active:
continue
func_ = {
"type": "function",
"function": {
"name": f.name,
# "parameters": f.parameters,
"description": f.description,
},
}
func_["function"]["parameters"] = f.parameters
if not f.parameters.get("properties") and omit_empty_parameter_field:
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段
del func_["function"]["parameters"]
_l.append(func_)
return _l
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.openai_schema(omit_empty_parameters=omit_empty_parameter_field)
def get_func_desc_anthropic_style(self) -> list:
"""
获得 Anthropic API 风格的**已经激活**的工具描述
"""
tools = []
for f in self.func_list:
if not f.active:
continue
# Convert internal format to Anthropic style
tool = {
"name": f.name,
"description": f.description,
"input_schema": {
"type": "object",
"properties": f.parameters.get("properties", {}),
# Keep the required field from the original parameters if it exists
"required": f.parameters.get("required", []),
},
}
tools.append(tool)
return tools
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.anthropic_schema()
def get_func_desc_google_genai_style(self) -> dict:
"""
获得 Google GenAI API 风格的**已经激活**的工具描述
"""
tools = [f for f in self.func_list if f.active]
toolset = ToolSet(tools)
return toolset.google_schema()
# Gemini API 支持的数据类型和格式
supported_types = {
"string",
"number",
"integer",
"boolean",
"array",
"object",
"null",
}
supported_formats = {
"string": {"enum", "date-time"},
"integer": {"int32", "int64"},
"number": {"float", "double"},
}
def deactivate_llm_tool(self, name: str) -> bool:
"""停用一个已经注册的函数调用工具。
def convert_schema(schema: dict) -> dict:
"""转换 schema 为 Gemini API 格式"""
Returns:
如果没找到会返回 False"""
func_tool = self.get_func(name)
if func_tool is not None:
func_tool.active = False
# 如果 schema 包含 anyOf,则只返回 anyOf 字段
if "anyOf" in schema:
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
result = {}
return True
return False
if "type" in schema and schema["type"] in supported_types:
result["type"] = schema["type"]
if "format" in schema and schema["format"] in supported_formats.get(
result["type"], set()
):
result["format"] = schema["format"]
else:
# 暂时指定默认为null
result["type"] = "null"
# 因为不想解决循环引用,所以这里直接传入 star_map 先了...
def activate_llm_tool(self, name: str, star_map: dict) -> bool:
func_tool = self.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
)
support_fields = {
"title",
"description",
"enum",
"minimum",
"maximum",
"maxItems",
"minItems",
"nullable",
"required",
}
result.update({k: schema[k] for k in support_fields if k in schema})
func_tool.active = True
if "properties" in schema:
properties = {}
for key, value in schema["properties"].items():
prop_value = convert_schema(value)
if "default" in prop_value:
del prop_value["default"]
properties[key] = prop_value
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
if properties: # 只在有非空属性时添加
result["properties"] = properties
if "items" in schema:
result["items"] = convert_schema(schema["items"])
return result
tools = [
{
"name": f.name,
"description": f.description,
**({"parameters": convert_schema(f.parameters)}),
}
for f in self.func_list
if f.active
]
declarations = {}
if tools:
declarations["function_declarations"] = tools
return declarations
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
)
func_definition = json.dumps(_l, ensure_ascii=False)
prompt = textwrap.dedent(f"""
ROLE:
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用
TOOLS:
可用的函数列表:
{func_definition}
LIMIT:
1. 你返回的内容应当能够被 Python json 模块解析的 Json 格式字符串
2. 你的 Json 返回的格式如下`[{{"name": "<func_name>", "args": <arg_dict>}}, ...]`参数根据上面提供的函数列表中的参数来填写
3. 允许必要时返回多个函数调用但需保证这些函数调用的顺序正确
4. 如果用户的提问中不需要用到给定的函数请直接返回 `{{"res": False}}`
EXAMPLE:
1. `用户提问`请问一下天气怎么样 `函数调用`[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
用户的提问是{question}
""")
_c = 0
while _c < 3:
try:
res = await provider.text_chat(prompt, session_id)
if res.find("```") != -1:
res = res[res.find("```json") + 7 : res.rfind("```")]
res = json.loads(res)
break
except Exception as e:
_c += 1
if _c == 3:
raise e
if "The message you submitted was too long" in str(e):
raise e
if "res" in res and not res["res"]:
return "", False
tool_call_result = []
for tool in res:
# 说明有函数调用
func_name = tool["name"]
args = tool["args"]
# 调用函数
func_tool = self.get_func(func_name)
if not func_tool:
raise Exception(f"Request function {func_name} not found.")
ret = await func_tool.execute(**args)
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
return True
return False
def __str__(self):
return str(self.func_list)
def __repr__(self):
return str(self.func_list)
FuncCall = FunctionToolManager
+27 -70
View File
@@ -7,85 +7,27 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from .entities import ProviderType
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
from .register import llm_tools, provider_cls_map
from ..persona_mgr import PersonaManager
class ProviderManager:
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
def __init__(
self,
config: AstrBotConfig,
db_helper: BaseDatabase,
persona_mgr: PersonaManager,
):
self.persona_mgr = persona_mgr
self.astrbot_config = config
self.providers_config: List = config["provider"]
self.provider_settings: dict = config["provider_settings"]
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
self.persona_configs: list = config.get("persona", [])
self.astrbot_config = config
# 人格情景管理
# 目前没有拆成独立的模块
self.default_persona_name = self.provider_settings.get(
"default_personality", "default"
)
self.personas: List[Personality] = []
self.selected_default_persona = None
for persona in self.persona_configs:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
bd_processed = []
mid_processed = ""
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(
f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。"
)
begin_dialogs = []
user_turn = True
for dialog in begin_dialogs:
bd_processed.append(
{
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None, # 不持久化到 db
}
)
user_turn = not user_turn
if mood_imitation_dialogs:
if len(mood_imitation_dialogs) % 2 != 0:
logger.error(
f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。"
)
mood_imitation_dialogs = []
user_turn = True
for dialog in mood_imitation_dialogs:
role = "A" if user_turn else "B"
mid_processed += f"{role}: {dialog}\n"
if not user_turn:
mid_processed += "\n"
user_turn = not user_turn
try:
persona = Personality(
**persona,
_begin_dialogs_processed=bd_processed,
_mood_imitation_dialogs_processed=mid_processed,
)
if persona["name"] == self.default_persona_name:
self.selected_default_persona = persona
self.personas.append(persona)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
if not self.selected_default_persona and len(self.personas) > 0:
# 默认选择第一个
self.selected_default_persona = self.personas[0]
if not self.selected_default_persona:
self.selected_default_persona = Personality(
prompt="You are a helpful and friendly assistant.",
name="default",
_begin_dialogs_processed=[],
_mood_imitation_dialogs_processed="",
)
self.personas.append(self.selected_default_persona)
# 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
self.default_persona_name = persona_mgr.default_persona
self.provider_insts: List[Provider] = []
"""加载的 Provider 的实例"""
@@ -113,6 +55,21 @@ class ProviderManager:
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
@property
def persona_configs(self) -> list:
"""动态获取最新的 persona 配置"""
return self.persona_mgr.persona_v3_config
@property
def personas(self) -> list:
"""动态获取最新的 personas 列表"""
return self.persona_mgr.personas_v3
@property
def selected_default_persona(self):
"""动态获取最新的默认选中 persona"""
return self.persona_mgr.selected_default_persona_v3
async def set_provider(
self, provider_id: str, provider_type: ProviderType, umo: str = None
):
+5 -15
View File
@@ -1,23 +1,13 @@
import abc
from typing import List
from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import AsyncGenerator
from astrbot.core.provider.func_tool_manager import FunctionToolManager, ToolSet
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
from astrbot.core.provider.register import provider_cls_map
from astrbot.core.db.po import Personality
from dataclasses import dataclass
class Personality(TypedDict):
prompt: str = ""
name: str = ""
begin_dialogs: List[str] = []
mood_imitation_dialogs: List[str] = []
# cache
_begin_dialogs_processed: List[dict] = []
_mood_imitation_dialogs_processed: str = ""
@dataclass
class ProviderMeta:
id: str
@@ -90,7 +80,7 @@ class Provider(AbstractProvider):
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
func_tool: FunctionToolManager | ToolSet = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
@@ -119,7 +109,7 @@ class Provider(AbstractProvider):
prompt: str,
session_id: str = None,
image_urls: list[str] = None,
func_tool: FuncCall = None,
func_tool: FunctionToolManager | ToolSet = None,
contexts: list = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
+23 -31
View File
@@ -11,12 +11,14 @@ from astrbot.core.provider.provider import (
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.func_tool_manager import FunctionToolManager
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform import Platform
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
from astrbot.core.persona_mgr import PersonaManager
from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
@@ -29,6 +31,11 @@ from astrbot.core.star.filter.platform_adapter_type import (
)
from deprecated import deprecated
from typing_extensions import TYPE_CHECKING
if TYPE_CHECKING:
from astrbot.core.pipeline.context import PipelineContext
class Context:
"""
@@ -50,6 +57,8 @@ class Context:
registered_web_apis: list = []
pipeline_ctx: "PipelineContext" = None
# back compatibility
_register_tasks: List[Awaitable] = []
_star_manager = None
@@ -62,6 +71,8 @@ class Context:
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
conversation_manager: ConversationManager = None,
message_history_manager: PlatformMessageHistoryManager = None,
persona_manager: PersonaManager = None,
):
self._event_queue = event_queue
self._config = config
@@ -69,6 +80,8 @@ class Context:
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.conversation_manager = conversation_manager
self.message_history_manager = message_history_manager
self.persona_manager = persona_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
"""根据插件名获取插件的 Metadata"""
@@ -80,7 +93,7 @@ class Context:
"""获取当前载入的所有插件 Metadata 的列表"""
return star_registry
def get_llm_tool_manager(self) -> FuncCall:
def get_llm_tool_manager(self) -> FunctionToolManager:
"""获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools"""
return self.provider_manager.llm_tools
@@ -90,40 +103,14 @@ class Context:
Returns:
如果没找到会返回 False
"""
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
)
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
return self.provider_manager.llm_tools.activate_llm_tool(name, star_map)
def deactivate_llm_tool(self, name: str) -> bool:
"""停用一个已经注册的函数调用工具。
Returns:
如果没找到会返回 False"""
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
def register_provider(self, provider: Provider):
"""
@@ -208,7 +195,9 @@ class Context:
return self._event_queue
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform | None:
def get_platform(
self, platform_type: Union[PlatformAdapterType, str]
) -> Platform | None:
"""
获取指定类型的平台适配器
@@ -268,6 +257,9 @@ class Context:
return True
return False
def get_pipeline_context(self) -> "PipelineContext":
return self.pipeline_ctx
"""
以下的方法已经不推荐使用请从 AstrBot 文档查看更好的注册方式
"""
+3 -2
View File
@@ -6,11 +6,11 @@ from .stat import StatRoute
from .log import LogRoute
from .static_file import StaticFileRoute
from .chat import ChatRoute
from .tools import ToolsRoute # 导入新的ToolsRoute
from .tools import ToolsRoute
from .conversation import ConversationRoute
from .file import FileRoute
from .session_management import SessionManagementRoute
from .persona import PersonaRoute
__all__ = [
"AuthRoute",
@@ -25,4 +25,5 @@ __all__ = [
"ConversationRoute",
"FileRoute",
"SessionManagementRoute",
"PersonaRoute",
]
-7
View File
@@ -169,7 +169,6 @@ class ConfigRoute(Route):
"/config/provider/new": ("POST", self.post_new_provider),
"/config/provider/update": ("POST", self.post_update_provider),
"/config/provider/delete": ("POST", self.post_delete_provider),
"/config/llmtools": ("GET", self.get_llm_tools),
"/config/provider/check_one": ("GET", self.check_one_provider_status),
"/config/provider/list": ("GET", self.get_provider_config_list),
"/config/provider/model_list": ("GET", self.get_provider_model_list),
@@ -509,12 +508,6 @@ class ConfigRoute(Route):
return Response().error(str(e)).__dict__
return Response().ok(None, "删除成功,已经实时生效~").__dict__
async def get_llm_tools(self):
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
tools = tool_mgr.get_func_desc_openai_style()
return Response().ok(tools).__dict__
async def _get_astrbot_config(self):
config = self.config
+199
View File
@@ -0,0 +1,199 @@
import traceback
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
from astrbot.core.db import BaseDatabase
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class PersonaRoute(Route):
def __init__(
self,
context: RouteContext,
db_helper: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.routes = {
"/persona/list": ("GET", self.list_personas),
"/persona/detail": ("POST", self.get_persona_detail),
"/persona/create": ("POST", self.create_persona),
"/persona/update": ("POST", self.update_persona),
"/persona/delete": ("POST", self.delete_persona),
}
self.db_helper = db_helper
self.persona_mgr = core_lifecycle.persona_mgr
self.register_routes()
async def list_personas(self):
"""获取所有人格列表"""
try:
personas = await self.persona_mgr.get_all_personas()
return (
Response()
.ok(
[
{
"persona_id": persona.persona_id,
"system_prompt": persona.system_prompt,
"begin_dialogs": persona.begin_dialogs or [],
"tools": persona.tools,
"created_at": persona.created_at.isoformat()
if persona.created_at
else None,
"updated_at": persona.updated_at.isoformat()
if persona.updated_at
else None,
}
for persona in personas
]
)
.__dict__
)
except Exception as e:
logger.error(f"获取人格列表失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"获取人格列表失败: {str(e)}").__dict__
async def get_persona_detail(self):
"""获取指定人格的详细信息"""
try:
data = await request.get_json()
persona_id = data.get("persona_id")
if not persona_id:
return Response().error("缺少必要参数: persona_id").__dict__
persona = await self.persona_mgr.get_persona(persona_id)
if not persona:
return Response().error("人格不存在").__dict__
return (
Response()
.ok(
{
"persona_id": persona.persona_id,
"system_prompt": persona.system_prompt,
"begin_dialogs": persona.begin_dialogs or [],
"tools": persona.tools,
"created_at": persona.created_at.isoformat()
if persona.created_at
else None,
"updated_at": persona.updated_at.isoformat()
if persona.updated_at
else None,
}
)
.__dict__
)
except Exception as e:
logger.error(f"获取人格详情失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"获取人格详情失败: {str(e)}").__dict__
async def create_persona(self):
"""创建新人格"""
try:
data = await request.get_json()
persona_id = data.get("persona_id", "").strip()
system_prompt = data.get("system_prompt", "").strip()
begin_dialogs = data.get("begin_dialogs", [])
tools = data.get("tools")
if not persona_id:
return Response().error("人格ID不能为空").__dict__
if not system_prompt:
return Response().error("系统提示词不能为空").__dict__
# 验证 begin_dialogs 格式
if begin_dialogs and len(begin_dialogs) % 2 != 0:
return (
Response()
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
.__dict__
)
persona = await self.persona_mgr.create_persona(
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs if begin_dialogs else None,
tools=tools if tools else None,
)
return (
Response()
.ok(
{
"message": "人格创建成功",
"persona": {
"persona_id": persona.persona_id,
"system_prompt": persona.system_prompt,
"begin_dialogs": persona.begin_dialogs or [],
"tools": persona.tools or [],
"created_at": persona.created_at.isoformat()
if persona.created_at
else None,
"updated_at": persona.updated_at.isoformat()
if persona.updated_at
else None,
},
}
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"创建人格失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"创建人格失败: {str(e)}").__dict__
async def update_persona(self):
"""更新人格信息"""
try:
data = await request.get_json()
persona_id = data.get("persona_id")
system_prompt = data.get("system_prompt")
begin_dialogs = data.get("begin_dialogs")
tools = data.get("tools")
if not persona_id:
return Response().error("缺少必要参数: persona_id").__dict__
# 验证 begin_dialogs 格式
if begin_dialogs is not None and len(begin_dialogs) % 2 != 0:
return (
Response()
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
.__dict__
)
await self.persona_mgr.update_persona(
persona_id=persona_id,
system_prompt=system_prompt,
begin_dialogs=begin_dialogs,
tools=tools,
)
return Response().ok({"message": "人格更新成功"}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"更新人格失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"更新人格失败: {str(e)}").__dict__
async def delete_persona(self):
"""删除人格"""
try:
data = await request.get_json()
persona_id = data.get("persona_id")
if not persona_id:
return Response().error("缺少必要参数: persona_id").__dict__
await self.persona_mgr.delete_persona(persona_id)
return Response().ok({"message": "人格删除成功"}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"删除人格失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"删除人格失败: {str(e)}").__dict__
+40
View File
@@ -8,6 +8,7 @@ from quart import request
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.star import star_map
from .route import Response, Route, RouteContext
@@ -27,6 +28,8 @@ class ToolsRoute(Route):
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
"/tools/mcp/market": ("GET", self.get_mcp_markets),
"/tools/mcp/test": ("POST", self.test_mcp_connection),
"/tools/list": ("GET", self.get_tool_list),
"/tools/toggle-tool": ("POST", self.toggle_tool),
}
self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
@@ -336,3 +339,40 @@ class ToolsRoute(Route):
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__
async def get_tool_list(self):
"""获取所有注册的工具列表"""
try:
tools = self.tool_mgr.func_list
tools_dict = [tool.__dict__() for tool in tools]
return Response().ok(data=tools_dict).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取工具列表失败: {str(e)}").__dict__
async def toggle_tool(self):
"""启用或停用指定的工具"""
try:
data = await request.json
tool_name = data.get("name")
action = data.get("activate") # True or False
if not tool_name or action is None:
return Response().error("缺少必要参数: name 或 action").__dict__
if action:
try:
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
except ValueError as e:
return Response().error(f"启用工具失败: {str(e)}").__dict__
else:
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
if ok:
return Response().ok(None, "操作成功。").__dict__
else:
return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"操作工具失败: {str(e)}").__dict__
+3
View File
@@ -60,6 +60,9 @@ class AstrBotDashboard:
self.session_management_route = SessionManagementRoute(
self.context, db, core_lifecycle
)
self.persona_route = PersonaRoute(
self.context, db, core_lifecycle
)
self.app.add_url_rule(
"/api/plug/<path:subpath>",
+1
View File
@@ -52,6 +52,7 @@ export class I18nLoader {
{ name: 'features/alkaid/index', path: 'features/alkaid/index.json' },
{ name: 'features/alkaid/knowledge-base', path: 'features/alkaid/knowledge-base.json' },
{ name: 'features/alkaid/memory', path: 'features/alkaid/memory.json' },
{ name: 'features/persona', path: 'features/persona.json' },
// 消息模块
{ name: 'messages/errors', path: 'messages/errors.json' },
@@ -2,6 +2,7 @@
"dashboard": "Dashboard",
"platforms": "Platforms",
"providers": "Providers",
"persona": "Persona",
"toolUse": "MCP Tools",
"config": "Config",
"extension": "Extensions",
@@ -0,0 +1,67 @@
{
"page": {
"description": "Manage and configure chat bot personality settings"
},
"buttons": {
"create": "Create Persona",
"createFirst": "Create First Persona",
"edit": "Edit",
"delete": "Delete",
"cancel": "Cancel",
"save": "Save",
"addDialogPair": "Add Dialog Pair"
},
"labels": {
"presetDialogs": "Preset Dialogs ({count} pairs)",
"createdAt": "Created At",
"updatedAt": "Updated At"
},
"form": {
"personaId": "Persona ID",
"systemPrompt": "System Prompt",
"presetDialogs": "Preset Dialogs",
"presetDialogsHelp": "Add some preset dialogs to help the bot better understand the role settings. The number of dialogs must be even (users and assistants take turns).",
"userMessage": "User Message",
"assistantMessage": "Assistant Message",
"tools": "Tool Selection",
"toolsHelp": "Select available tools for this persona. Tools allow the bot to perform specific functions such as searching, calculating, getting information, etc.",
"toolsSelection": "Tool Selection Actions",
"selectAllTools": "Select All Tools",
"clearAllTools": "Clear Selection",
"allSelected": "All Selected",
"mcpServersQuickSelect": "MCP Servers Quick Select",
"searchTools": "Search Tools",
"selectedTools": "Selected Tools",
"noToolsAvailable": "No tools available",
"noToolsFound": "No matching tools found",
"loadingTools": "Loading tools...",
"allToolsAvailable": "Use all available tools",
"noToolsSelected": "No tools selected"
},
"dialog": {
"create": {
"title": "Create New Persona"
},
"edit": {
"title": "Edit Persona"
}
},
"empty": {
"title": "No Persona Configured",
"description": "Create your first persona to start using personalized chatbots"
},
"validation": {
"required": "This field is required",
"minLength": "Minimum {min} characters required",
"alphanumeric": "Only letters, numbers, underscores and hyphens are allowed",
"dialogRequired": "{type} cannot be empty"
},
"messages": {
"loadError": "Failed to load persona list",
"saveSuccess": "Saved successfully",
"saveError": "Save failed",
"deleteConfirm": "Are you sure you want to delete persona \"{id}\"? This action cannot be undone.",
"deleteSuccess": "Deleted successfully",
"deleteError": "Delete failed"
}
}
@@ -107,6 +107,9 @@
"failed": "Import configuration failed: {error}"
},
"configParseError": "Configuration parse error: {error}",
"noAvailableConfig": "No available configuration"
"noAvailableConfig": "No available configuration",
"toggleToolSuccess": "Tool status toggled successfully!",
"toggleToolError": "Failed to toggle tool status: {error}",
"testError": "Test connection failed: {error}"
}
}
@@ -2,6 +2,7 @@
"dashboard": "统计",
"platforms": "消息平台",
"providers": "服务提供商",
"persona": "人格管理",
"toolUse": "MCP",
"config": "配置文件",
"extension": "插件管理",
@@ -0,0 +1,67 @@
{
"page": {
"description": "管理和配置聊天机器人的人格角色设定"
},
"buttons": {
"create": "创建人格",
"createFirst": "创建第一个人格",
"edit": "编辑",
"delete": "删除",
"cancel": "取消",
"save": "保存",
"addDialogPair": "添加对话对"
},
"labels": {
"presetDialogs": "预设对话 ({count} 对)",
"createdAt": "创建时间",
"updatedAt": "更新时间"
},
"form": {
"personaId": "人格 ID",
"systemPrompt": "系统提示词",
"presetDialogs": "预设对话",
"presetDialogsHelp": "添加一些预设的对话来帮助机器人更好地理解角色设定。对话数量必须为偶数(用户和助手轮流对话)。",
"userMessage": "用户消息",
"assistantMessage": "助手消息",
"tools": "工具选择",
"toolsHelp": "为这个人格选择可用的工具。工具可以让机器人执行特定的功能,如搜索、计算、获取信息等。",
"toolsSelection": "工具选择操作",
"selectAllTools": "选择所有工具",
"clearAllTools": "清空选择",
"allSelected": "全选",
"mcpServersQuickSelect": "MCP 服务器快速选择",
"searchTools": "搜索工具",
"selectedTools": "已选择的工具",
"noToolsAvailable": "暂无可用工具",
"noToolsFound": "未找到匹配的工具",
"loadingTools": "正在加载工具...",
"allToolsAvailable": "使用所有可用工具",
"noToolsSelected": "未选择任何工具"
},
"dialog": {
"create": {
"title": "创建新人格"
},
"edit": {
"title": "编辑人格"
}
},
"empty": {
"title": "暂无人格配置",
"description": "创建您的第一个人格来开始使用个性化的聊天机器人"
},
"validation": {
"required": "此字段为必填项",
"minLength": "最少需要 {min} 个字符",
"alphanumeric": "只能包含字母、数字、下划线和连字符",
"dialogRequired": "{type}不能为空"
},
"messages": {
"loadError": "加载人格列表失败",
"saveSuccess": "保存成功",
"saveError": "保存失败",
"deleteConfirm": "确定要删除人格 \"{id}\" 吗?此操作不可撤销。",
"deleteSuccess": "删除成功",
"deleteError": "删除失败"
}
}
@@ -107,6 +107,9 @@
"failed": "导入配置失败: {error}"
},
"configParseError": "配置解析错误: {error}",
"noAvailableConfig": "无可用配置"
"noAvailableConfig": "无可用配置",
"toggleToolSuccess": "工具状态切换成功!",
"toggleToolError": "工具状态切换失败: {error}",
"testError": "测试连接失败: {error}"
}
}
}
+6 -2
View File
@@ -25,6 +25,7 @@ import zhCNDashboard from './locales/zh-CN/features/dashboard.json';
import zhCNAlkaidIndex from './locales/zh-CN/features/alkaid/index.json';
import zhCNAlkaidKnowledgeBase from './locales/zh-CN/features/alkaid/knowledge-base.json';
import zhCNAlkaidMemory from './locales/zh-CN/features/alkaid/memory.json';
import zhCNPersona from './locales/zh-CN/features/persona.json';
import zhCNErrors from './locales/zh-CN/messages/errors.json';
import zhCNSuccess from './locales/zh-CN/messages/success.json';
@@ -54,6 +55,7 @@ import enUSDashboard from './locales/en-US/features/dashboard.json';
import enUSAlkaidIndex from './locales/en-US/features/alkaid/index.json';
import enUSAlkaidKnowledgeBase from './locales/en-US/features/alkaid/knowledge-base.json';
import enUSAlkaidMemory from './locales/en-US/features/alkaid/memory.json';
import enUSPersona from './locales/en-US/features/persona.json';
import enUSErrors from './locales/en-US/messages/errors.json';
import enUSSuccess from './locales/en-US/messages/success.json';
@@ -88,7 +90,8 @@ export const translations = {
index: zhCNAlkaidIndex,
'knowledge-base': zhCNAlkaidKnowledgeBase,
memory: zhCNAlkaidMemory
}
},
persona: zhCNPersona
},
messages: {
errors: zhCNErrors,
@@ -123,7 +126,8 @@ export const translations = {
index: enUSAlkaidIndex,
'knowledge-base': enUSAlkaidKnowledgeBase,
memory: enUSAlkaidMemory
}
},
persona: enUSPersona
},
messages: {
errors: enUSErrors,
@@ -63,9 +63,13 @@ const sidebarItem: menu[] = [
icon: 'mdi-account-group',
to: '/session-management'
},
{
title: 'core.navigation.persona',
icon: 'mdi-heart',
to: '/persona'
},
{
title: 'core.navigation.console',
icon: 'mdi-console',
to: '/console'
},
+5
View File
@@ -56,6 +56,11 @@ const MainRoutes = {
path: '/session-management',
component: () => import('@/views/SessionManagementPage.vue')
},
{
name: 'Persona',
path: '/persona',
component: () => import('@/views/PersonaPage.vue')
},
{
name: 'Console',
path: '/console',
+808
View File
@@ -0,0 +1,808 @@
<template>
<div class="persona-page">
<v-container fluid class="pa-0">
<!-- 页面标题 -->
<v-row class="d-flex justify-space-between align-center px-4 py-3 pb-8">
<div>
<h1 class="text-h1 font-weight-bold mb-2">
<v-icon color="black" class="me-2">mdi-heart</v-icon>{{ t('core.navigation.persona') }}
</h1>
<p class="text-subtitle-1 text-medium-emphasis mb-4">
{{ tm('page.description') }}
</p>
</div>
<div>
<v-btn color="primary" variant="tonal" prepend-icon="mdi-plus" @click="openCreateDialog"
rounded="xl" size="x-large">
{{ tm('buttons.create') }}
</v-btn>
</div>
</v-row>
<!-- 人格卡片网格 -->
<v-row>
<v-col v-for="persona in personas" :key="persona.persona_id" cols="12" md="6" lg="4" xl="3">
<v-card class="persona-card" elevation="2" rounded="lg" @click="viewPersona(persona)">
<v-card-title class="d-flex justify-space-between align-center">
<div class="text-truncate ml-2">
{{ persona.persona_id }}
</div>
<v-menu offset-y>
<template v-slot:activator="{ props }">
<v-btn icon="mdi-dots-vertical" variant="text" size="small" v-bind="props"
@click.stop />
</template>
<v-list density="compact">
<v-list-item @click="editPersona(persona)">
<v-list-item-title>
<v-icon class="mr-2" size="small">mdi-pencil</v-icon>
{{ tm('buttons.edit') }}
</v-list-item-title>
</v-list-item>
<v-list-item @click="deletePersona(persona)" class="text-error">
<v-list-item-title>
<v-icon class="mr-2" size="small">mdi-delete</v-icon>
{{ tm('buttons.delete') }}
</v-list-item-title>
</v-list-item>
</v-list>
</v-menu>
</v-card-title>
<v-card-text>
<div class="system-prompt-preview">
{{ truncateText(persona.system_prompt, 100) }}
</div>
<div class="mt-3" v-if="persona.begin_dialogs && persona.begin_dialogs.length > 0">
<v-chip size="small" color="secondary" variant="tonal" prepend-icon="mdi-chat">
{{ tm('labels.presetDialogs', { count: persona.begin_dialogs.length / 2 }) }}
</v-chip>
</div>
<div class="mt-3 text-caption text-medium-emphasis">
{{ tm('labels.createdAt') }}: {{ formatDate(persona.created_at) }}
</div>
</v-card-text>
</v-card>
</v-col>
<!-- 空状态 -->
<v-col v-if="personas.length === 0 && !loading" cols="12">
<v-card class="text-center pa-8" elevation="0">
<v-icon size="64" color="grey-lighten-1" class="mb-4">mdi-account-group</v-icon>
<h3 class="text-h5 mb-2">{{ tm('empty.title') }}</h3>
<p class="text-body-1 text-medium-emphasis mb-4">{{ tm('empty.description') }}</p>
<v-btn color="primary" variant="flat" prepend-icon="mdi-plus" @click="openCreateDialog">
{{ tm('buttons.createFirst') }}
</v-btn>
</v-card>
</v-col>
</v-row>
<!-- 加载状态 -->
<v-row v-if="loading">
<v-col v-for="n in 6" :key="n" cols="12" md="6" lg="4" xl="3">
<v-skeleton-loader type="card" rounded="lg"></v-skeleton-loader>
</v-col>
</v-row>
</v-container>
<!-- 创建/编辑人格对话框 -->
<v-dialog v-model="showPersonaDialog" max-width="800px" persistent>
<v-card>
<v-card-title class="text-h5">
{{ editingPersona ? tm('dialog.edit.title') : tm('dialog.create.title') }}
</v-card-title>
<v-card-text>
<v-form ref="personaForm" v-model="formValid">
<v-text-field v-model="personaForm.persona_id" :label="tm('form.personaId')"
:rules="personaIdRules" :disabled="editingPersona" variant="outlined" density="comfortable"
class="mb-4" />
<v-textarea v-model="personaForm.system_prompt" :label="tm('form.systemPrompt')"
:rules="systemPromptRules" variant="outlined" rows="6" class="mb-4" />
<v-expansion-panels v-model="expandedPanels" multiple>
<!-- 工具选择面板 -->
<v-expansion-panel value="tools">
<v-expansion-panel-title>
<v-icon class="mr-2">mdi-tools</v-icon>
{{ tm('form.tools') }}
<v-chip v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
size="small" color="primary" variant="tonal" class="ml-2">
{{ personaForm.tools.length }}
</v-chip>
</v-expansion-panel-title>
<v-expansion-panel-text>
<div class="mb-3">
<p class="text-body-2 text-medium-emphasis">
{{ tm('form.toolsHelp') }}
</p>
</div>
<v-radio-group class="mt-2" v-model="toolSelectValue" hide-details="true">
<v-radio label="默认使用全部函数工具" value="0"></v-radio>
<v-radio label="选择指定函数工具" value="1">
</v-radio>
</v-radio-group>
<div v-if="toolSelectValue === '1'" class="mt-3 ml-8">
<!-- 工具搜索 -->
<v-text-field v-model="toolSearch" :label="tm('form.searchTools')"
prepend-inner-icon="mdi-magnify" variant="outlined" density="compact"
hide-details clearable class="mb-3" />
<!-- MCP 服务器 -->
<div v-if="mcpServers.length > 0" class="mb-4">
<h4 class="text-subtitle-2 mb-2">{{ tm('form.mcpServersQuickSelect') }}</h4>
<div class="d-flex flex-wrap ga-2">
<v-chip v-for="server in mcpServers" :key="server.name"
:color="isServerSelected(server) ? 'primary' : 'default'"
:variant="isServerSelected(server) ? 'flat' : 'outlined'"
size="small" clickable @click="toggleMcpServer(server)"
:disabled="!server.tools || server.tools.length === 0">
<v-icon start size="small">mdi-server</v-icon>
{{ server.name }}
<v-chip-text v-if="server.tools" class="ml-1">
({{ server.tools.length }})
</v-chip-text>
</v-chip>
</div>
</div>
<!-- 工具选择列表 -->
<div v-if="filteredTools.length > 0" class="tools-selection">
<v-virtual-scroll :items="filteredTools" height="300" item-height="48">
<template v-slot:default="{ item }">
<v-list-item :key="item.name" density="comfortable"
@click="toggleTool(item.name)">
<template v-slot:prepend>
<v-checkbox-btn :model-value="isToolSelected(item.name)"
@click.stop="toggleTool(item.name)" />
</template>
<v-list-item-title>
{{ item.name }}
<v-chip v-if="item.mcp_server_name" size="x-small"
color="secondary" variant="tonal" class="ml-2">
{{ item.mcp_server_name }}
</v-chip>
</v-list-item-title>
<v-list-item-subtitle v-if="item.description">
{{ truncateText(item.description, 100) }}
</v-list-item-subtitle>
</v-list-item>
</template>
</v-virtual-scroll>
</div>
<div v-else-if="!loadingTools && availableTools.length === 0"
class="text-center pa-4">
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-tools</v-icon>
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsAvailable')
}}
</p>
</div>
<div v-else-if="!loadingTools && filteredTools.length === 0"
class="text-center pa-4">
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-magnify</v-icon>
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsFound') }}
</p>
</div>
<!-- 加载状态 -->
<div v-if="loadingTools" class="text-center pa-4">
<v-progress-circular indeterminate color="primary" />
<p class="text-body-2 text-medium-emphasis mt-2">{{ tm('form.loadingTools')
}}
</p>
</div>
<!-- 已选择的工具 -->
<div class="mt-4">
<h4 class="text-subtitle-2 mb-2">
{{ tm('form.selectedTools') }}
<span v-if="personaForm.tools === null" class="text-success">
({{ tm('form.allSelected') }})
</span>
<span v-else-if="Array.isArray(personaForm.tools)">
({{ personaForm.tools.length }})
</span>
</h4>
<div v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
class="d-flex flex-wrap ga-1" style="max-height: 100px; overflow-y: auto;">
<v-chip v-for="toolName in personaForm.tools" :key="toolName"
size="small" color="primary" variant="tonal" closable
@click:close="removeTool(toolName)">
{{ toolName }}
</v-chip>
</div>
<div v-else class="text-body-2 text-medium-emphasis">
{{ tm('form.noToolsSelected') }}
</div>
</div>
</div>
</v-expansion-panel-text>
</v-expansion-panel>
<!-- 预设对话面板 -->
<v-expansion-panel value="dialogs">
<v-expansion-panel-title>
<v-icon class="mr-2">mdi-chat</v-icon>
{{ tm('form.presetDialogs') }}
<v-chip v-if="personaForm.begin_dialogs.length > 0" size="small" color="primary"
variant="tonal" class="ml-2">
{{ personaForm.begin_dialogs.length / 2 }}
</v-chip>
</v-expansion-panel-title>
<v-expansion-panel-text>
<div class="mb-3">
<p class="text-body-2 text-medium-emphasis">
{{ tm('form.presetDialogsHelp') }}
</p>
</div>
<div v-for="(dialog, index) in personaForm.begin_dialogs" :key="index" class="mb-3">
<v-textarea v-model="personaForm.begin_dialogs[index]"
:label="index % 2 === 0 ? tm('form.userMessage') : tm('form.assistantMessage')"
:rules="getDialogRules(index)" variant="outlined" rows="2"
density="comfortable">
<template v-slot:append>
<v-btn icon="mdi-delete" variant="text" size="small" color="error"
@click="removeDialog(index)" />
</template>
</v-textarea>
</div>
<v-btn variant="outlined" prepend-icon="mdi-plus" @click="addDialogPair" block>
{{ tm('buttons.addDialogPair') }}
</v-btn>
</v-expansion-panel-text>
</v-expansion-panel>
</v-expansion-panels>
</v-form>
</v-card-text>
<v-card-actions>
<v-spacer />
<v-btn color="grey" variant="text" @click="closePersonaDialog">
{{ tm('buttons.cancel') }}
</v-btn>
<v-btn color="primary" variant="flat" @click="savePersona" :loading="saving" :disabled="!formValid">
{{ tm('buttons.save') }}
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
<!-- 查看人格详情对话框 -->
<v-dialog v-model="showViewDialog" max-width="700px">
<v-card v-if="viewingPersona">
<v-card-title class="d-flex justify-space-between align-center">
<span class="text-h5">{{ viewingPersona.persona_id }}</span>
<v-btn icon="mdi-close" variant="text" @click="showViewDialog = false" />
</v-card-title>
<v-card-text>
<div class="mb-4">
<h4 class="text-h6 mb-2">{{ tm('form.systemPrompt') }}</h4>
<div class="system-prompt-content">
{{ viewingPersona.system_prompt }}
</div>
</div>
<div v-if="viewingPersona.begin_dialogs && viewingPersona.begin_dialogs.length > 0" class="mb-4">
<h4 class="text-h6 mb-2">{{ tm('form.presetDialogs') }}</h4>
<div v-for="(dialog, index) in viewingPersona.begin_dialogs" :key="index" class="mb-2">
<v-chip :color="index % 2 === 0 ? 'primary' : 'secondary'" variant="tonal" size="small"
class="mb-1">
{{ index % 2 === 0 ? tm('form.userMessage') : tm('form.assistantMessage') }}
</v-chip>
<div class="dialog-content ml-2">
{{ dialog }}
</div>
</div>
</div>
<div class="mb-4">
<h4 class="text-h6 mb-2">{{ tm('form.tools') }}</h4>
<div v-if="viewingPersona.tools === null" class="text-body-2 text-medium-emphasis">
<v-chip size="small" color="success" variant="tonal" prepend-icon="mdi-check-all">
{{ tm('form.allToolsAvailable') }}
</v-chip>
</div>
<div v-else-if="viewingPersona.tools && viewingPersona.tools.length > 0"
class="d-flex flex-wrap ga-1">
<v-chip v-for="toolName in viewingPersona.tools" :key="toolName" size="small"
color="primary" variant="tonal">
{{ toolName }}
</v-chip>
</div>
<div v-else class="text-body-2 text-medium-emphasis">
{{ tm('form.noToolsSelected') }}
</div>
</div>
<div class="text-caption text-medium-emphasis">
<div>{{ tm('labels.createdAt') }}: {{ formatDate(viewingPersona.created_at) }}</div>
<div v-if="viewingPersona.updated_at">{{ tm('labels.updatedAt') }}: {{
formatDate(viewingPersona.updated_at) }}</div>
</div>
</v-card-text>
</v-card>
</v-dialog>
<!-- 消息提示 -->
<v-snackbar :timeout="3000" elevation="24" :color="messageType" v-model="showMessage" location="top">
{{ message }}
</v-snackbar>
</div>
</template>
<script>
import axios from 'axios';
import { useI18n, useModuleI18n } from '@/i18n/composables';
export default {
name: 'PersonaPage',
setup() {
const { t } = useI18n();
const { tm } = useModuleI18n('features/persona');
return { t, tm };
},
data() {
return {
toolSelectValue: '0', //
personas: [],
loading: false,
saving: false,
showPersonaDialog: false,
showViewDialog: false,
editingPersona: null,
viewingPersona: null,
expandedPanels: [],
formValid: false,
personaForm: {
persona_id: '',
system_prompt: '',
begin_dialogs: [],
tools: []
},
showMessage: false,
message: '',
messageType: 'success',
personaIdRules: [
v => !!v || this.tm('validation.required'),
v => (v && v.length >= 2) || this.tm('validation.minLength', { min: 2 }),
v => /^[a-zA-Z0-9_-]+$/.test(v) || this.tm('validation.alphanumeric')
],
systemPromptRules: [
v => !!v || this.tm('validation.required'),
v => (v && v.length >= 10) || this.tm('validation.minLength', { min: 10 })
],
mcpServers: [],
availableTools: [],
loadingTools: false,
toolSearch: ''
}
},
computed: {
filteredTools() {
if (!this.toolSearch) {
return this.availableTools;
}
const search = this.toolSearch.toLowerCase();
return this.availableTools.filter(tool =>
tool.name.toLowerCase().includes(search) ||
(tool.description && tool.description.toLowerCase().includes(search)) ||
(tool.mcp_server_name && tool.mcp_server_name.toLowerCase().includes(search))
);
}
},
watch: {
toolSearch() {
//
},
toolSelectValue(newValue) {
if (newValue === '0') {
//
this.personaForm.tools = null;
} else if (newValue === '1') {
// null
if (this.personaForm.tools === null) {
this.personaForm.tools = [];
}
}
}
},
mounted() {
this.loadPersonas();
this.loadMcpServers();
this.loadTools();
},
methods: {
async loadPersonas() {
this.loading = true;
try {
const response = await axios.get('/api/persona/list');
if (response.data.status === 'ok') {
this.personas = response.data.data;
} else {
this.showError(response.data.message || this.tm('messages.loadError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.loadError'));
}
this.loading = false;
},
async loadMcpServers() {
try {
const response = await axios.get('/api/tools/mcp/servers');
if (response.data.status === 'ok') {
this.mcpServers = response.data.data;
} else {
this.showError(response.data.message || this.tm('messages.loadError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.loadError'));
}
},
async loadTools() {
this.loadingTools = true;
try {
const response = await axios.get('/api/tools/list');
if (response.data.status === 'ok') {
this.availableTools = response.data.data;
} else {
this.showError(response.data.message || this.tm('messages.loadError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.loadError'));
}
this.loadingTools = false;
},
openCreateDialog() {
this.editingPersona = null;
this.personaForm = {
persona_id: '',
system_prompt: '',
begin_dialogs: [],
tools: []
};
this.toolSelectValue = '1'; //
this.expandedPanels = [];
this.showPersonaDialog = true;
},
editPersona(persona) {
this.editingPersona = persona;
this.personaForm = {
persona_id: persona.persona_id,
system_prompt: persona.system_prompt,
begin_dialogs: [...(persona.begin_dialogs || [])],
tools: persona.tools === null ? null : [...(persona.tools || [])]
};
// tools toolSelectValue
this.toolSelectValue = persona.tools === null ? '0' : '1';
this.expandedPanels = [];
this.showPersonaDialog = true;
},
viewPersona(persona) {
this.viewingPersona = persona;
this.showViewDialog = true;
},
closePersonaDialog() {
this.showPersonaDialog = false;
this.editingPersona = null;
this.personaForm = {
persona_id: '',
system_prompt: '',
begin_dialogs: [],
tools: []
};
this.toolSelectValue = '1'; //
},
async savePersona() {
if (!this.formValid) return;
//
if (this.personaForm.begin_dialogs.length > 0) {
for (let i = 0; i < this.personaForm.begin_dialogs.length; i++) {
if (!this.personaForm.begin_dialogs[i] || this.personaForm.begin_dialogs[i].trim() === '') {
const dialogType = i % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
this.showError(this.tm('validation.dialogRequired', { type: dialogType }));
return;
}
}
}
this.saving = true;
try {
const url = this.editingPersona ? '/api/persona/update' : '/api/persona/create';
const response = await axios.post(url, this.personaForm);
if (response.data.status === 'ok') {
this.showSuccess(response.data.message || this.tm('messages.saveSuccess'));
this.closePersonaDialog();
await this.loadPersonas();
} else {
this.showError(response.data.message || this.tm('messages.saveError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.saveError'));
}
this.saving = false;
},
async deletePersona(persona) {
if (!confirm(this.tm('messages.deleteConfirm', { id: persona.persona_id }))) {
return;
}
try {
const response = await axios.post('/api/persona/delete', {
persona_id: persona.persona_id
});
if (response.data.status === 'ok') {
this.showSuccess(response.data.message || this.tm('messages.deleteSuccess'));
await this.loadPersonas();
} else {
this.showError(response.data.message || this.tm('messages.deleteError'));
}
} catch (error) {
this.showError(error.response?.data?.message || this.tm('messages.deleteError'));
}
},
addDialogPair() {
this.personaForm.begin_dialogs.push('', '');
//
if (!this.expandedPanels.includes('dialogs')) {
this.expandedPanels.push('dialogs');
}
},
removeDialog(index) {
//
if (index % 2 === 0 && index + 1 < this.personaForm.begin_dialogs.length) {
this.personaForm.begin_dialogs.splice(index, 2);
}
//
else if (index % 2 === 1 && index - 1 >= 0) {
this.personaForm.begin_dialogs.splice(index - 1, 2);
}
},
toggleMcpServer(server) {
if (!server.tools || server.tools.length === 0) return;
//
if (this.personaForm.tools === null) {
//
this.personaForm.tools = this.availableTools.map(tool => tool.name)
.filter(toolName => !server.tools.includes(toolName));
this.toolSelectValue = '1'; //
return;
}
// tools
if (!Array.isArray(this.personaForm.tools)) {
this.personaForm.tools = [];
this.toolSelectValue = '1';
}
//
const serverTools = server.tools;
const allSelected = serverTools.every(toolName => this.personaForm.tools.includes(toolName));
if (allSelected) {
//
this.personaForm.tools = this.personaForm.tools.filter(
toolName => !serverTools.includes(toolName)
);
} else {
//
serverTools.forEach(toolName => {
if (!this.personaForm.tools.includes(toolName)) {
this.personaForm.tools.push(toolName);
}
});
}
},
toggleTool(toolName) {
//
if (this.personaForm.tools === null) {
//
//
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
this.toolSelectValue = '1'; //
} else if (Array.isArray(this.personaForm.tools)) {
const index = this.personaForm.tools.indexOf(toolName);
if (index !== -1) {
//
this.personaForm.tools.splice(index, 1);
} else {
//
this.personaForm.tools.push(toolName);
}
} else {
// toolsnull
this.personaForm.tools = [toolName];
this.toolSelectValue = '1';
}
},
toggleAllTools() {
//
if (this.isAllToolsSelected()) {
this.personaForm.tools = [];
} else {
// null
this.personaForm.tools = null;
}
},
clearAllTools() {
//
this.personaForm.tools = [];
},
isAllToolsSelected() {
// toolsnull
return this.personaForm.tools === null;
},
isNoToolsSelected() {
//
return Array.isArray(this.personaForm.tools) && this.personaForm.tools.length === 0;
},
removeTool(toolName) {
//
if (this.personaForm.tools === null) {
//
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
this.toolSelectValue = '1'; //
} else if (Array.isArray(this.personaForm.tools)) {
const index = this.personaForm.tools.indexOf(toolName);
if (index !== -1) {
this.personaForm.tools.splice(index, 1);
}
}
},
truncateText(text, maxLength) {
if (!text) return '';
return text.length > maxLength ? text.substring(0, maxLength) + '...' : text;
},
formatDate(dateString) {
if (!dateString) return '';
return new Date(dateString).toLocaleString();
},
showSuccess(message) {
this.message = message;
this.messageType = 'success';
this.showMessage = true;
},
showError(message) {
this.message = message;
this.messageType = 'error';
this.showMessage = true;
},
getDialogRules(index) {
const dialogType = index % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
return [
v => !!v || this.tm('validation.dialogRequired', { type: dialogType }),
v => (v && v.trim().length > 0) || this.tm('validation.dialogRequired', { type: dialogType })
];
},
isToolSelected(toolName) {
//
if (this.personaForm.tools === null) {
return true;
}
return Array.isArray(this.personaForm.tools) && this.personaForm.tools.includes(toolName);
},
isServerSelected(server) {
if (!server.tools || server.tools.length === 0) return false;
//
if (this.personaForm.tools === null) {
return true;
}
//
return Array.isArray(this.personaForm.tools) &&
server.tools.every(toolName => this.personaForm.tools.includes(toolName));
}
}
}
</script>
<style scoped>
.persona-page {
padding: 20px;
padding-top: 8px;
}
.persona-card {
transition: all 0.3s ease;
height: 100%;
cursor: pointer;
}
.persona-card:hover {
box-shadow: 0 8px 25px 0 rgba(0, 0, 0, 0.15);
}
.system-prompt-preview {
font-size: 14px;
line-height: 1.4;
color: rgba(var(--v-theme-on-surface), 0.7);
overflow: hidden;
display: -webkit-box;
-webkit-line-clamp: 3;
line-clamp: 3;
-webkit-box-orient: vertical;
}
.system-prompt-content {
background-color: rgba(var(--v-theme-surface-variant), 0.3);
padding: 12px;
border-radius: 8px;
font-family: 'Roboto Mono', monospace;
font-size: 14px;
line-height: 1.5;
white-space: pre-wrap;
word-break: break-word;
}
.dialog-content {
background-color: rgba(var(--v-theme-surface-variant), 0.3);
padding: 8px 12px;
border-radius: 8px;
font-size: 14px;
line-height: 1.4;
margin-bottom: 8px;
white-space: pre-wrap;
word-break: break-word;
}
.tools-selection {
max-height: 300px;
overflow-y: auto;
}
.v-virtual-scroll {
padding-bottom: 16px;
}
</style>
+45 -11
View File
@@ -405,24 +405,36 @@
<v-text-field v-model="toolSearch" prepend-inner-icon="mdi-magnify" :label="tm('functionTools.search')"
variant="outlined" density="compact" class="mb-4" hide-details clearable></v-text-field>
<small>复选框代表该工具是否被启用</small>
<v-expansion-panels v-model="openedPanel" multiple style="max-height: 500px; overflow-y: auto;">
<v-expansion-panel v-for="(tool, index) in filteredTools" :key="index" :value="index"
class="mb-2 tool-panel" rounded="lg">
<v-expansion-panel-title>
<v-row no-gutters align="center">
<v-col cols="1">
<v-checkbox
v-model="tool.active"
color="primary"
hide-details
density="compact"
@click.stop
@change="toggleToolStatus(tool)"
></v-checkbox>
</v-col>
<v-col cols="3">
<div class="d-flex align-center">
<v-icon color="primary" class="me-2" size="small">
{{ tool.function.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
{{ tool.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
</v-icon>
<span class="text-body-1 text-high-emphasis font-weight-medium text-truncate"
:title="tool.function.name">
{{ formatToolName(tool.function.name) }}
:title="tool.name">
{{ formatToolName(tool.name) }}
</span>
</div>
</v-col>
<v-col cols="9" class="text-grey">
{{ tool.function.description }}
<v-col cols="8" class="text-grey">
{{ tool.description }}
</v-col>
</v-row>
</v-expansion-panel-title>
@@ -434,9 +446,9 @@
<v-icon color="primary" size="small" class="me-1">mdi-information</v-icon>
{{ tm('functionTools.description') }}
</p>
<p class="text-body-2 ml-6 mb-4">{{ tool.function.description }}</p>
<p class="text-body-2 ml-6 mb-4">{{ tool.description }}</p>
<template v-if="tool.function.parameters && tool.function.parameters.properties">
<template v-if="tool.parameters && tool.parameters.properties">
<p class="text-body-1 font-weight-medium mb-3">
<v-icon color="primary" size="small" class="me-1">mdi-code-json</v-icon>
{{ tm('functionTools.parameters') }}
@@ -451,7 +463,7 @@
</tr>
</thead>
<tbody>
<tr v-for="(param, paramName) in tool.function.parameters.properties" :key="paramName">
<tr v-for="(param, paramName) in tool.parameters.properties" :key="paramName">
<td class="font-weight-medium">{{ paramName }}</td>
<td>
<v-chip size="x-small" color="primary" text class="text-caption">
@@ -562,8 +574,8 @@ export default {
const searchTerm = this.toolSearch.toLowerCase();
return this.tools.filter(tool =>
tool.function.name.toLowerCase().includes(searchTerm) ||
tool.function.description.toLowerCase().includes(searchTerm)
tool.name.toLowerCase().includes(searchTerm) ||
tool.description.toLowerCase().includes(searchTerm)
);
},
@@ -658,7 +670,7 @@ export default {
},
getTools() {
axios.get('/api/config/llmtools')
axios.get('/api/tools/list')
.then(response => {
this.tools = response.data.data || [];
})
@@ -976,6 +988,28 @@ export default {
} catch (e) {
this.showError(this.tm('messages.importError.failed', { error: e.message }));
}
},
//
async toggleToolStatus(tool) {
try {
const response = await axios.post('/api/tools/toggle-tool', {
name: tool.name,
activate: tool.active
});
if (response.data.status === 'ok') {
this.showSuccess(response.data.message || this.tm('messages.toggleToolSuccess'));
} else {
//
tool.active = !tool.active;
this.showError(response.data.message || this.tm('messages.toggleToolError'));
}
} catch (error) {
//
tool.active = !tool.active;
this.showError(this.tm('messages.toggleToolError', { error: error.response?.data?.message || error.message }));
}
}
}
}
+21 -41
View File
@@ -26,6 +26,7 @@ from .long_term_memory import LongTermMemory
from astrbot.core import logger
from astrbot.api.message_components import Plain, Image, Reply
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.provider.func_tool_manager import ToolSet
from typing import Union
from enum import Enum
@@ -128,7 +129,6 @@ class Main(star.Star):
/reset: 重置 LLM 会话
/history: 当前对话的对话记录
/persona: 人格情景(op)
/tool ls: 函数工具
/key: API Key(op)
/websearch: 网页搜索
{notice}"""
@@ -157,50 +157,22 @@ class Main(star.Star):
@tool.command("ls")
async def tool_ls(self, event: AstrMessageEvent):
"""查看函数工具列表"""
tm = self.context.get_llm_tool_manager()
msg = "函数工具:\n"
for tool in tm.func_list:
active = " (启用)" if tool.active else "(停用)"
msg += f"- {tool.name}: {tool.description} {active}\n"
msg += "\n使用 /tool on/off <工具名> 激活或者停用函数工具。/tool off_all 停用所有函数工具。"
event.set_result(MessageEventResult().message(msg).use_t2i(False))
event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
@tool.command("on")
async def tool_on(self, event: AstrMessageEvent, tool_name: str):
"""启用一个函数工具"""
if self.context.activate_llm_tool(tool_name):
event.set_result(
MessageEventResult().message(f"激活工具 {tool_name} 成功。")
)
else:
event.set_result(
MessageEventResult().message(
f"激活工具 {tool_name} 失败,未找到此工具。"
)
)
event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
@tool.command("off")
async def tool_off(self, event: AstrMessageEvent, tool_name: str):
"""停用一个函数工具"""
if self.context.deactivate_llm_tool(tool_name):
event.set_result(
MessageEventResult().message(f"停用工具 {tool_name} 成功。")
)
else:
event.set_result(
MessageEventResult().message(
f"停用工具 {tool_name} 失败,未找到此工具。"
)
)
event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
@tool.command("off_all")
async def tool_all_off(self, event: AstrMessageEvent):
"""停用所有函数工具"""
tm = self.context.get_llm_tool_manager()
for tool in tm.func_list:
self.context.deactivate_llm_tool(tool.name)
event.set_result(MessageEventResult().message("停用所有工具成功。"))
event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
@filter.command_group("plugin")
def plugin(self):
@@ -1264,26 +1236,34 @@ UID: {user_id} 此 ID 可用于设置管理员。
if req.conversation:
persona_id = req.conversation.persona_id
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
persona_id = self.context.provider_manager.selected_default_persona[
persona_id = self.context.persona_manager.selected_default_persona_v3[
"name"
]
persona = next(
builtins.filter(
lambda persona: persona["name"] == persona_id,
self.context.provider_manager.personas,
self.context.persona_manager.personas_v3,
),
None,
)
if persona:
if prompt := persona["prompt"]:
req.system_prompt += prompt
if mood_dialogs := persona["_mood_imitation_dialogs_processed"]:
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
req.system_prompt += mood_dialogs
if (
begin_dialogs := persona["_begin_dialogs_processed"]
) and not req.contexts:
if begin_dialogs := persona["_begin_dialogs_processed"]:
req.contexts[:0] = begin_dialogs
# tools select
tmgr = self.context.get_llm_tool_manager()
if (persona and persona.get("tools") is None) or not persona:
# select all
toolset = tmgr.get_full_tool_set()
else:
toolset = ToolSet()
for tool_name in persona["tools"]:
tool = tmgr.get_func(tool_name)
if tool:
toolset.add_tool(tool)
req.func_tool = toolset
logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}")
if quote:
sender_info = ""