Improve: 引入全新的人格管理模式以及重构函数工具管理器 (#2305)
* feat: add persona management * refactor: 重构函数工具管理器,引入 ToolSet,并让 Persona 支持绑定 Tools * feat: 更新 Persona 工具选择逻辑,支持全选和指定工具的切换 * feat: 更新 BaseDatabase 中的 persona 方法返回类型,支持返回 None
This commit is contained in:
@@ -115,8 +115,7 @@ DEFAULT_CONFIG = {
|
||||
"log_level": "INFO",
|
||||
"pip_install_arg": "",
|
||||
"pypi_index_url": "https://mirrors.aliyun.com/pypi/simple/",
|
||||
"knowledge_db": {},
|
||||
"persona": [],
|
||||
"persona": [], # deprecated
|
||||
"timezone": "",
|
||||
"callback_api_base": "",
|
||||
}
|
||||
@@ -1701,43 +1700,6 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"persona": {
|
||||
"description": "人格情景设置",
|
||||
"type": "list",
|
||||
"config_template": {
|
||||
"新人格情景": {
|
||||
"name": "",
|
||||
"prompt": "",
|
||||
"begin_dialogs": [],
|
||||
"mood_imitation_dialogs": [],
|
||||
}
|
||||
},
|
||||
"tmpl_display_title": "name",
|
||||
"items": {
|
||||
"name": {
|
||||
"description": "人格名称",
|
||||
"type": "string",
|
||||
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
|
||||
},
|
||||
"prompt": {
|
||||
"description": "设定(系统提示词)",
|
||||
"type": "text",
|
||||
"hint": "填写人格的身份背景、性格特征、兴趣爱好、个人经历、口头禅等。",
|
||||
},
|
||||
"begin_dialogs": {
|
||||
"description": "预设对话",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||
},
|
||||
"mood_imitation_dialogs": {
|
||||
"description": "对话风格模仿",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一致。对话需要成对(用户和助手),输入完一个角色的内容之后按【回车】。需要偶数个对话",
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"description": "语音转文本(STT)",
|
||||
"type": "object",
|
||||
|
||||
@@ -22,6 +22,7 @@ from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
|
||||
from astrbot.core.star import PluginManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core import LogBroker
|
||||
from astrbot.core.db import BaseDatabase
|
||||
@@ -69,13 +70,23 @@ class AstrBotCoreLifecycle:
|
||||
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
|
||||
|
||||
await self.db.initialize()
|
||||
await do_migration_v4(self.db, {})
|
||||
|
||||
try:
|
||||
await do_migration_v4(self.db, {}, self.astrbot_config)
|
||||
except Exception as e:
|
||||
logger.error(f"迁移到 v4.0.0 新版本数据格式失败: {e}")
|
||||
|
||||
# 初始化事件队列
|
||||
self.event_queue = Queue()
|
||||
|
||||
# 初始化人格管理器
|
||||
self.persona_mgr = PersonaManager(self.db, self.astrbot_config)
|
||||
await self.persona_mgr.initialize()
|
||||
|
||||
# 初始化供应商管理器
|
||||
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
||||
self.provider_manager = ProviderManager(
|
||||
self.astrbot_config, self.db, self.persona_mgr
|
||||
)
|
||||
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
@@ -94,6 +105,8 @@ class AstrBotCoreLifecycle:
|
||||
self.provider_manager,
|
||||
self.platform_manager,
|
||||
self.conversation_manager,
|
||||
self.platform_message_history_manager,
|
||||
self.persona_mgr,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
@@ -110,6 +123,7 @@ class AstrBotCoreLifecycle:
|
||||
PipelineContext(self.astrbot_config, self.plugin_manager)
|
||||
)
|
||||
await self.pipeline_scheduler.initialize()
|
||||
self.star_context.pipeline_ctx = self.pipeline_scheduler.ctx
|
||||
|
||||
# 初始化更新器
|
||||
self.astrbot_updator = AstrBotUpdator()
|
||||
@@ -232,6 +246,9 @@ class AstrBotCoreLifecycle:
|
||||
platform_insts = self.platform_manager.get_insts()
|
||||
for platform_inst in platform_insts:
|
||||
tasks.append(
|
||||
asyncio.create_task(platform_inst.run(), name=f"{platform_inst.meta().id}({platform_inst.meta().name})")
|
||||
asyncio.create_task(
|
||||
platform_inst.run(),
|
||||
name=f"{platform_inst.meta().id}({platform_inst.meta().name})",
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
@@ -205,6 +205,7 @@ class BaseDatabase(abc.ABC):
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
) -> Persona:
|
||||
"""Insert a new persona record."""
|
||||
...
|
||||
@@ -219,6 +220,22 @@ class BaseDatabase(abc.ABC):
|
||||
"""Get all personas for a specific bot."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str = None,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
) -> Persona | None:
|
||||
"""Update a persona's system prompt or begin dialogs."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_persona(self, persona_id: str) -> None:
|
||||
"""Delete a persona by its ID."""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
async def insert_preference_or_update(self, key: str, value: str) -> Preference:
|
||||
"""Insert a new preference record."""
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.api import logger
|
||||
from .migra_3_to_4 import (
|
||||
migration_conversation_table,
|
||||
migration_platform_table,
|
||||
migration_webchat_data,
|
||||
migration_persona_data,
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +26,9 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:
|
||||
|
||||
|
||||
async def do_migration_v4(
|
||||
db_helper: BaseDatabase, platform_id_map: dict[str, dict[str, str]]
|
||||
db_helper: BaseDatabase,
|
||||
platform_id_map: dict[str, dict[str, str]],
|
||||
astrbot_config: AstrBotConfig,
|
||||
):
|
||||
"""
|
||||
执行数据库迁移
|
||||
@@ -45,6 +49,9 @@ async def do_migration_v4(
|
||||
# 执行 WebChat 数据迁移
|
||||
await migration_webchat_data(db_helper, platform_id_map)
|
||||
|
||||
# 执行人格数据迁移
|
||||
await migration_persona_data(db_helper, astrbot_config)
|
||||
|
||||
# 标记迁移完成
|
||||
await db_helper.insert_preference_or_update("migration_done_v4", "true")
|
||||
|
||||
|
||||
@@ -4,12 +4,10 @@ from .. import BaseDatabase
|
||||
from .sqlite_v3 import SQLiteDatabase as SQLiteV3DatabaseV3
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from astrbot.core.db.po import (
|
||||
ConversationV2,
|
||||
PlatformMessageHistory,
|
||||
)
|
||||
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory
|
||||
from sqlalchemy import text
|
||||
|
||||
"""
|
||||
@@ -50,20 +48,21 @@ async def migration_conversation_table(
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
for conversation in conversations:
|
||||
for idx, conversation in enumerate(conversations):
|
||||
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
|
||||
progress = int((idx + 1) / total_cnt * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
|
||||
try:
|
||||
conv = db_helper_v3.get_conversation_by_user_id(
|
||||
user_id=conversation.get("user_id", "unknown"),
|
||||
cid=conversation.get("cid", "unknown"),
|
||||
)
|
||||
if not conv:
|
||||
logger.warning(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
|
||||
)
|
||||
if ":" not in conv.user_id:
|
||||
logger.warning(
|
||||
f"跳过 user_id 为 {conv.user_id} 的会话,它可能是 WebChat 的消息历史记录。"
|
||||
)
|
||||
continue
|
||||
session = MessageSesion.from_str(session_str=conv.user_id)
|
||||
platform_id = get_platform_id(
|
||||
@@ -81,13 +80,12 @@ async def migration_conversation_table(
|
||||
updated_at=datetime.datetime.fromtimestamp(conv.updated_at),
|
||||
)
|
||||
dbsession.add(conv_v2)
|
||||
if conv_v2:
|
||||
logger.info(f"迁移旧会话 {conv.cid} 到新表成功。")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"迁移旧会话 {conversation.get('cid', 'unknown')} 失败: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(f"成功迁移 {total_cnt} 条旧的会话数据到新表。")
|
||||
|
||||
|
||||
async def migration_platform_table(
|
||||
@@ -107,7 +105,7 @@ async def migration_platform_table(
|
||||
platform_stats_v3 = stats.platform
|
||||
|
||||
if not platform_stats_v3:
|
||||
logger.warning("没有找到旧平台数据,跳过迁移。")
|
||||
logger.info("没有找到旧平台数据,跳过迁移。")
|
||||
return
|
||||
|
||||
first_time_stamp = platform_stats_v3[0].timestamp
|
||||
@@ -120,7 +118,13 @@ async def migration_platform_table(
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
for bucket_end in range(start_time, end_time, 3600):
|
||||
total_buckets = (end_time - start_time) // 3600
|
||||
for bucket_idx, bucket_end in enumerate(range(start_time, end_time, 3600)):
|
||||
if bucket_idx % 500 == 0:
|
||||
progress = int((bucket_idx + 1) / total_buckets * 100)
|
||||
logger.info(
|
||||
f"进度: {progress}% ({bucket_idx + 1}/{total_buckets})"
|
||||
)
|
||||
cnt = 0
|
||||
while (
|
||||
idx < len(platform_stats_v3)
|
||||
@@ -136,9 +140,6 @@ async def migration_platform_table(
|
||||
platform_type = get_platform_type(
|
||||
platform_id_map, platform_stats_v3[idx].name
|
||||
)
|
||||
logger.info(
|
||||
f"迁移平台统计数据: {platform_id}, {platform_type}, 时间戳: {bucket_end}, 计数: {cnt}"
|
||||
)
|
||||
try:
|
||||
await dbsession.execute(
|
||||
text("""
|
||||
@@ -161,6 +162,7 @@ async def migration_platform_table(
|
||||
f"迁移平台统计数据失败: {platform_id}, {platform_type}, 时间戳: {bucket_end}",
|
||||
exc_info=True,
|
||||
)
|
||||
logger.info(f"成功迁移 {len(platform_stats_v3)} 条旧的平台数据到新表。")
|
||||
|
||||
|
||||
async def migration_webchat_data(
|
||||
@@ -178,20 +180,21 @@ async def migration_webchat_data(
|
||||
async with db_helper.get_db() as dbsession:
|
||||
dbsession: AsyncSession
|
||||
async with dbsession.begin():
|
||||
for conversation in conversations:
|
||||
for idx, conversation in enumerate(conversations):
|
||||
if total_cnt > 0 and (idx + 1) % max(1, total_cnt // 10) == 0:
|
||||
progress = int((idx + 1) / total_cnt * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_cnt})")
|
||||
try:
|
||||
conv = db_helper_v3.get_conversation_by_user_id(
|
||||
user_id=conversation.get("user_id", "unknown"),
|
||||
cid=conversation.get("cid", "unknown"),
|
||||
)
|
||||
if not conv:
|
||||
logger.warning(
|
||||
logger.info(
|
||||
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。"
|
||||
)
|
||||
if ":" in conv.user_id:
|
||||
logger.warning(
|
||||
f"跳过 user_id 为 {conv.user_id} 的会话,它不是 WebChat 的消息历史记录。"
|
||||
)
|
||||
continue
|
||||
platform_id = "webchat"
|
||||
history = json.loads(conv.history) if conv.history else []
|
||||
@@ -206,9 +209,52 @@ async def migration_webchat_data(
|
||||
)
|
||||
dbsession.add(new_history)
|
||||
|
||||
logger.info(f"迁移旧 WebChat 会话 {conv.cid} 到新表成功。")
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"迁移旧 WebChat 会话 {conversation.get('cid', 'unknown')} 失败",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
logger.info(f"成功迁移 {total_cnt} 条旧的 WebChat 会话数据到新表。")
|
||||
|
||||
|
||||
async def migration_persona_data(
|
||||
db_helper: BaseDatabase, astrbot_config: AstrBotConfig
|
||||
):
|
||||
"""
|
||||
迁移 Persona 数据到新的表中。
|
||||
旧的 Persona 数据存储在 preference 中,新的 Persona 数据存储在 persona 表中。
|
||||
"""
|
||||
v3_persona_config: list[dict] = astrbot_config.get("persona", [])
|
||||
total_personas = len(v3_persona_config)
|
||||
logger.info(f"迁移 {total_personas} 个 Persona 配置到新表中...")
|
||||
|
||||
for idx, persona in enumerate(v3_persona_config):
|
||||
if total_personas > 0 and (idx + 1) % max(1, total_personas // 10) == 0:
|
||||
progress = int((idx + 1) / total_personas * 100)
|
||||
if progress % 10 == 0:
|
||||
logger.info(f"进度: {progress}% ({idx + 1}/{total_personas})")
|
||||
try:
|
||||
begin_dialogs = persona.get("begin_dialogs", [])
|
||||
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
|
||||
mood_prompt = ""
|
||||
user_turn = True
|
||||
for mood_dialog in mood_imitation_dialogs:
|
||||
if user_turn:
|
||||
mood_prompt += f"A: {mood_dialog}\n"
|
||||
else:
|
||||
mood_prompt += f"B: {mood_dialog}\n"
|
||||
user_turn = not user_turn
|
||||
system_prompt = persona.get("prompt", "")
|
||||
if mood_prompt:
|
||||
system_prompt += f"Here are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n {mood_prompt}"
|
||||
persona_new = await db_helper.insert_persona(
|
||||
persona_id=persona["name"],
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs,
|
||||
)
|
||||
logger.info(
|
||||
f"迁移 Persona {persona['name']}({persona_new.system_prompt[:30]}...) 到新表成功。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
+33
-6
@@ -9,7 +9,7 @@ from sqlmodel import (
|
||||
UniqueConstraint,
|
||||
Field,
|
||||
)
|
||||
from typing import Optional
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
|
||||
class PlatformStat(SQLModel, table=True):
|
||||
@@ -39,12 +39,14 @@ class PlatformStat(SQLModel, table=True):
|
||||
class ConversationV2(SQLModel, table=True):
|
||||
__tablename__ = "conversations"
|
||||
|
||||
inner_conversation_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
inner_conversation_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
conversation_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4())
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False)
|
||||
@@ -78,6 +80,8 @@ class Persona(SQLModel, table=True):
|
||||
system_prompt: str = Field(sa_type=Text, nullable=False)
|
||||
begin_dialogs: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
"""a list of strings, each representing a dialog to start with"""
|
||||
tools: Optional[list] = Field(default=None, sa_type=JSON)
|
||||
"""None means use ALL tools for default, empty list means no tools, otherwise a list of tool names."""
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc),
|
||||
@@ -119,7 +123,9 @@ class PlatformMessageHistory(SQLModel, table=True):
|
||||
platform_id: str = Field(nullable=False)
|
||||
user_id: str = Field(nullable=False) # An id of group, user in platform
|
||||
sender_id: Optional[str] = Field(default=None) # ID of the sender in the platform
|
||||
sender_name: Optional[str] = Field(default=None) # Name of the sender in the platform
|
||||
sender_name: Optional[str] = Field(
|
||||
default=None
|
||||
) # Name of the sender in the platform
|
||||
content: dict = Field(sa_type=JSON, nullable=False) # a message chain list
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
updated_at: datetime = Field(
|
||||
@@ -136,12 +142,14 @@ class Attachment(SQLModel, table=True):
|
||||
|
||||
__tablename__ = "attachments"
|
||||
|
||||
inner_attachment_id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
|
||||
inner_attachment_id: int = Field(
|
||||
primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||
)
|
||||
attachment_id: str = Field(
|
||||
max_length=36,
|
||||
nullable=False,
|
||||
unique=True,
|
||||
default_factory=lambda: str(uuid.uuid4())
|
||||
default_factory=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
path: str = Field(nullable=False) # Path to the file on disk
|
||||
type: str = Field(nullable=False) # Type of the file (e.g., 'image', 'file')
|
||||
@@ -182,6 +190,25 @@ class Conversation:
|
||||
updated_at: int = 0
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
"""LLM 人格类。
|
||||
|
||||
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
|
||||
"""
|
||||
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: list[str] = []
|
||||
mood_imitation_dialogs: list[str] = []
|
||||
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
|
||||
tools: list[str] | None = None
|
||||
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: list[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
|
||||
|
||||
# ====
|
||||
# Deprecated, and will be removed in future versions.
|
||||
# ====
|
||||
|
||||
@@ -19,6 +19,7 @@ from sqlalchemy import select, update, delete, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||
|
||||
class SQLiteDatabase(BaseDatabase):
|
||||
def __init__(self, db_path: str) -> None:
|
||||
@@ -99,9 +100,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
# Conversation Management
|
||||
# ====
|
||||
|
||||
async def get_conversations(
|
||||
self, user_id=None, platform_id=None
|
||||
):
|
||||
async def get_conversations(self, user_id=None, platform_id=None):
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
query = select(ConversationV2)
|
||||
@@ -224,7 +223,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
return
|
||||
query = query.values(**values)
|
||||
await session.execute(query)
|
||||
return await self.get_conversation_by_id(cid)
|
||||
return await self.get_conversation_by_id(cid)
|
||||
|
||||
async def delete_conversation(self, cid):
|
||||
async with self.get_db() as session:
|
||||
@@ -312,7 +311,9 @@ class SQLiteDatabase(BaseDatabase):
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def insert_persona(self, persona_id, system_prompt, begin_dialogs=None):
|
||||
async def insert_persona(
|
||||
self, persona_id, system_prompt, begin_dialogs=None, tools=None
|
||||
):
|
||||
"""Insert a new persona record."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
@@ -321,6 +322,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
persona_id=persona_id,
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs or [],
|
||||
tools=tools,
|
||||
)
|
||||
session.add(new_persona)
|
||||
return new_persona
|
||||
@@ -341,6 +343,36 @@ class SQLiteDatabase(BaseDatabase):
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_persona(
|
||||
self, persona_id, system_prompt=None, begin_dialogs=None, tools=NOT_GIVEN
|
||||
):
|
||||
"""Update a persona's system prompt or begin dialogs."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
query = update(Persona).where(Persona.persona_id == persona_id)
|
||||
values = {}
|
||||
if system_prompt is not None:
|
||||
values["system_prompt"] = system_prompt
|
||||
if begin_dialogs is not None:
|
||||
values["begin_dialogs"] = begin_dialogs
|
||||
if tools is not NOT_GIVEN:
|
||||
values["tools"] = tools
|
||||
if not values:
|
||||
return
|
||||
query = query.values(**values)
|
||||
await session.execute(query)
|
||||
return await self.get_persona_by_id(persona_id)
|
||||
|
||||
async def delete_persona(self, persona_id):
|
||||
"""Delete a persona by its ID."""
|
||||
async with self.get_db() as session:
|
||||
session: AsyncSession
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(Persona).where(Persona.persona_id == persona_id)
|
||||
)
|
||||
|
||||
async def insert_preference_or_update(self, key, value):
|
||||
"""Insert a new preference record or update if it exists."""
|
||||
async with self.get_db() as session:
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.db.po import Persona, Personality
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
class PersonaManager:
|
||||
def __init__(self, db_helper: BaseDatabase, astrbot_config: AstrBotConfig):
|
||||
self.db = db_helper
|
||||
self.config = astrbot_config
|
||||
_ps: dict = astrbot_config["provider_settings"]
|
||||
self.default_persona: str = _ps.get("default_personality", "default")
|
||||
self.personas: list[Persona] = []
|
||||
self.selected_default_persona: Persona | None = None
|
||||
|
||||
self.personas_v3: list[Personality] = []
|
||||
self.selected_default_persona_v3: Personality | None = None
|
||||
self.persona_v3_config: list[dict] = []
|
||||
|
||||
async def initialize(self):
|
||||
self.personas = await self.get_all_personas()
|
||||
self.get_v3_persona_data()
|
||||
logger.info(f"已加载 {len(self.personas)} 个人格。")
|
||||
|
||||
async def get_persona(self, persona_id: str):
|
||||
"""获取指定 persona 的信息"""
|
||||
persona = await self.db.get_persona_by_id(persona_id)
|
||||
if not persona:
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
return persona
|
||||
|
||||
async def delete_persona(self, persona_id: str):
|
||||
"""删除指定 persona"""
|
||||
if not await self.db.get_persona_by_id(persona_id):
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
await self.db.delete_persona(persona_id)
|
||||
self.personas = [p for p in self.personas if p.persona_id != persona_id]
|
||||
self.get_v3_persona_data()
|
||||
|
||||
async def update_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str = None,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
):
|
||||
"""更新指定 persona 的信息。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
existing_persona = await self.db.get_persona_by_id(persona_id)
|
||||
if not existing_persona:
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist.")
|
||||
persona = await self.db.update_persona(
|
||||
persona_id, system_prompt, begin_dialogs, tools=tools
|
||||
)
|
||||
if persona:
|
||||
for i, p in enumerate(self.personas):
|
||||
if p.persona_id == persona_id:
|
||||
self.personas[i] = persona
|
||||
break
|
||||
self.get_v3_persona_data()
|
||||
return persona
|
||||
|
||||
async def get_all_personas(self) -> list[Persona]:
|
||||
"""获取所有 personas"""
|
||||
return await self.db.get_personas()
|
||||
|
||||
async def create_persona(
|
||||
self,
|
||||
persona_id: str,
|
||||
system_prompt: str,
|
||||
begin_dialogs: list[str] = None,
|
||||
tools: list[str] = None,
|
||||
) -> Persona:
|
||||
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
|
||||
if await self.db.get_persona_by_id(persona_id):
|
||||
raise ValueError(f"Persona with ID {persona_id} already exists.")
|
||||
new_persona = await self.db.insert_persona(
|
||||
persona_id, system_prompt, begin_dialogs, tools=tools
|
||||
)
|
||||
self.personas.append(new_persona)
|
||||
self.get_v3_persona_data()
|
||||
return new_persona
|
||||
|
||||
def get_v3_persona_data(
|
||||
self,
|
||||
) -> tuple[list[dict], list[Personality], Personality]:
|
||||
"""获取 AstrBot <4.0.0 版本的 persona 数据。
|
||||
|
||||
Returns:
|
||||
- list[dict]: 包含 persona 配置的字典列表。
|
||||
- list[Personality]: 包含 Personality 对象的列表。
|
||||
- Personality: 默认选择的 Personality 对象。
|
||||
"""
|
||||
v3_persona_config = [
|
||||
{
|
||||
"prompt": persona.system_prompt,
|
||||
"name": persona.persona_id,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"mood_imitation_dialogs": [], # deprecated
|
||||
"tools": persona.tools,
|
||||
}
|
||||
for persona in self.personas
|
||||
]
|
||||
|
||||
personas_v3: list[Personality] = []
|
||||
selected_default_persona: Personality | None = None
|
||||
|
||||
for persona_cfg in v3_persona_config:
|
||||
begin_dialogs = persona_cfg.get("begin_dialogs", [])
|
||||
bd_processed = []
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(
|
||||
f"{persona_cfg['name']} 人格情景预设对话格式不对,条数应该为偶数。"
|
||||
)
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append(
|
||||
{
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
}
|
||||
)
|
||||
user_turn = not user_turn
|
||||
|
||||
try:
|
||||
persona = Personality(
|
||||
**persona_cfg,
|
||||
_begin_dialogs_processed=bd_processed,
|
||||
_mood_imitation_dialogs_processed="", # deprecated
|
||||
)
|
||||
if persona["name"] == self.default_persona:
|
||||
selected_default_persona = persona
|
||||
personas_v3.append(persona)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
if not selected_default_persona and len(personas_v3) > 0:
|
||||
# 默认选择第一个
|
||||
selected_default_persona = personas_v3[0]
|
||||
|
||||
if not selected_default_persona:
|
||||
selected_default_persona = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
tools=None,
|
||||
_begin_dialogs_processed=[],
|
||||
)
|
||||
personas_v3.append(selected_default_persona)
|
||||
|
||||
self.personas_v3 = personas_v3
|
||||
self.selected_default_persona_v3 = selected_default_persona
|
||||
self.persona_v3_config = v3_persona_config
|
||||
self.selected_default_persona = Persona(
|
||||
persona_id=selected_default_persona["name"],
|
||||
system_prompt=selected_default_persona["prompt"],
|
||||
begin_dialogs=selected_default_persona["begin_dialogs"],
|
||||
tools=selected_default_persona["tools"] or None,
|
||||
)
|
||||
|
||||
return v3_persona_config, personas_v3, selected_default_persona
|
||||
@@ -55,7 +55,7 @@ class PipelineContext:
|
||||
handler: T.Awaitable,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> T.AsyncGenerator[None, None]:
|
||||
) -> T.AsyncGenerator[T.Any, None]:
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
@@ -77,10 +77,16 @@ class PipelineContext:
|
||||
try:
|
||||
ready_to_call = handler(event, *args, **kwargs)
|
||||
except TypeError as _:
|
||||
if self.plugin_manager:
|
||||
context = self.plugin_manager.context
|
||||
else:
|
||||
raise ValueError(
|
||||
"Cannot call handler without a valid context or plugin manager."
|
||||
)
|
||||
# 向下兼容
|
||||
trace_ = traceback.format_exc()
|
||||
# 以前的 handler 会额外传入一个参数, 但是 context 对象实际上在插件实例中有一份
|
||||
ready_to_call = handler(event, self.plugin_manager.context, *args, **kwargs)
|
||||
ready_to_call = handler(event, context, *args, **kwargs)
|
||||
|
||||
if inspect.isasyncgen(ready_to_call):
|
||||
_has_yielded = False
|
||||
|
||||
@@ -21,6 +21,7 @@ from mcp.types import (
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
CallToolResult,
|
||||
)
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot import logger
|
||||
@@ -193,50 +194,25 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
if not req.func_tool:
|
||||
return
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
if func_tool.origin == "mcp":
|
||||
logger.info(
|
||||
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||
)
|
||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||
if not res:
|
||||
continue
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain(type="tool_direct_result").base64_image(
|
||||
res.content[0].data
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
executor = func_tool.execute(
|
||||
event=self.event,
|
||||
pipeline_context=self.pipeline_ctx,
|
||||
**func_tool_args,
|
||||
)
|
||||
async for resp in executor:
|
||||
if isinstance(resp, CallToolResult):
|
||||
res = resp
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
yield MessageChain().message(res.content[0].text)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
@@ -247,41 +223,54 @@ class ToolLoopAgent(BaseAgentRunner):
|
||||
yield MessageChain(type="tool_direct_result").base64_image(
|
||||
res.content[0].data
|
||||
)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
else:
|
||||
logger.info(f"使用工具:{func_tool_name},参数:{func_tool_args}")
|
||||
# 尝试调用工具函数
|
||||
wrapper = self.pipeline_ctx.call_handler(
|
||||
self.event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
# Tool 返回结果
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resp,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resp)
|
||||
else:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
yield MessageChain().message(resource.text)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
yield MessageChain(
|
||||
type="tool_direct_result"
|
||||
).base64_image(res.content[0].data)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
yield MessageChain().message("返回的数据类型不受支持。")
|
||||
elif resp is None:
|
||||
# Tool 直接请求发送消息给用户
|
||||
# 这里我们将直接结束 Agent Loop。
|
||||
self._transition_state(AgentState.DONE)
|
||||
if res := self.event.get_result():
|
||||
if res.chain:
|
||||
yield MessageChain(
|
||||
chain=res.chain, type="tool_direct_result"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||
)
|
||||
|
||||
self.event.clear_result()
|
||||
except Exception as e:
|
||||
|
||||
@@ -100,7 +100,8 @@ class LLMRequestSubStage(Stage):
|
||||
if not event.message_str.startswith(self.provider_wake_prefix):
|
||||
return
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
|
||||
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
|
||||
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_path = await comp.convert_to_file_path()
|
||||
@@ -274,7 +275,6 @@ class LLMRequestSubStage(Stage):
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(self._handle_webchat(event, req, provider))
|
||||
|
||||
|
||||
async def _handle_webchat(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, prov: Provider
|
||||
):
|
||||
@@ -307,19 +307,10 @@ class LLMRequestSubStage(Stage):
|
||||
if not title or "<None>" in title:
|
||||
return
|
||||
await self.conv_manager.update_conversation_title(
|
||||
event.unified_msg_origin, title=title
|
||||
unified_msg_origin=event.unified_msg_origin,
|
||||
title=title,
|
||||
conversation_id=req.conversation.cid,
|
||||
)
|
||||
# 由于 WebChat 平台特殊性,其有两个对话,因此我们要更新两个对话的标题
|
||||
# webchat adapter 中,session_id 的格式是 f"webchat!{username}!{cid}"
|
||||
# TODO: 优化 WebChat 适配器的对话管理
|
||||
if event.session_id:
|
||||
username, cid = event.session_id.split("!")[1:3]
|
||||
db_helper = self.ctx.plugin_manager.context._db
|
||||
db_helper.update_conversation_title(
|
||||
user_id=username,
|
||||
cid=cid,
|
||||
title=title,
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self,
|
||||
|
||||
@@ -5,7 +5,7 @@ from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot import logger
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
from .func_tool_manager import FunctionToolManager, ToolSet
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
@@ -97,7 +97,7 @@ class ProviderRequest:
|
||||
"""会话 ID"""
|
||||
image_urls: list[str] = field(default_factory=list)
|
||||
"""图片 URL 列表"""
|
||||
func_tool: FuncCall | None = None
|
||||
func_tool: FunctionToolManager | ToolSet | None = None
|
||||
"""可用的函数工具"""
|
||||
contexts: list[dict] = field(default_factory=list)
|
||||
"""上下文。格式与 openai 的上下文格式一致:
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from __future__ import annotations
|
||||
import json
|
||||
import textwrap
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from deprecated import deprecated
|
||||
|
||||
from typing import Dict, List, Awaitable, Literal, Any
|
||||
from typing import Dict, List, Awaitable, Literal, Any, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
@@ -28,6 +29,13 @@ except (ModuleNotFoundError, ImportError):
|
||||
"警告: 缺少依赖库 'mcp' 或者 mcp 库版本过低,无法使用 Streamable HTTP 连接方式。"
|
||||
)
|
||||
|
||||
from typing_extensions import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.pipeline.context import PipelineContext
|
||||
|
||||
|
||||
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
|
||||
|
||||
SUPPORTED_TYPES = [
|
||||
@@ -39,6 +47,302 @@ SUPPORTED_TYPES = [
|
||||
] # json schema 支持的数据类型
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionTool:
|
||||
"""A class representing a function tool that can be used in function calling."""
|
||||
|
||||
name: str
|
||||
parameters: Dict
|
||||
description: str
|
||||
handler: Awaitable = None
|
||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||
handler_module_path: str = None
|
||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||
|
||||
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||
"""
|
||||
active: bool = True
|
||||
"""是否激活"""
|
||||
|
||||
origin: Literal["local", "mcp"] = "local"
|
||||
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
|
||||
|
||||
# MCP 相关字段
|
||||
mcp_server_name: str = None
|
||||
"""MCP 服务名称,当 origin 为 mcp 时有效"""
|
||||
mcp_client: MCPClient = None
|
||||
"""MCP 客户端,当 origin 为 mcp 时有效"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
event: AstrMessageEvent = None,
|
||||
pipeline_context: "PipelineContext" = None,
|
||||
**tool_args,
|
||||
) -> AsyncGenerator[Any | mcp.types.CallToolResult, None]:
|
||||
"""执行函数调用。
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。
|
||||
pipeline_context (PipelineContext): 流水线调度器上下文, 当 origin 为 local 时必须提供。
|
||||
**kwargs: 函数调用的参数。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None | mcp.types.CallToolResult, None]
|
||||
"""
|
||||
if self.origin == "local":
|
||||
if not event:
|
||||
raise ValueError("Event must be provided for local function tools.")
|
||||
wrapper = pipeline_context.call_handler(
|
||||
event=event,
|
||||
handler=self.handler,
|
||||
**tool_args,
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None:
|
||||
text_content = mcp.types.TextContent(
|
||||
type="text",
|
||||
text=str(resp),
|
||||
)
|
||||
yield mcp.types.CallToolResult(content=[text_content])
|
||||
else:
|
||||
# NOTE: Tool 在这里直接请求发送消息给用户
|
||||
# TODO: 是否需要判断 event.get_result() 是否为空?
|
||||
# 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容"
|
||||
yield None
|
||||
|
||||
elif self.origin == "mcp":
|
||||
if not self.mcp_client:
|
||||
raise ValueError("MCP client is not available for MCP function tools.")
|
||||
res = await self.mcp_client.session.call_tool(
|
||||
name=self.name,
|
||||
arguments=tool_args,
|
||||
)
|
||||
if not res:
|
||||
return
|
||||
yield res
|
||||
|
||||
else:
|
||||
raise Exception(f"Unknown function origin: {self.origin}")
|
||||
|
||||
def __dict__(self) -> dict[str, Any]:
|
||||
"""将 FunctionTool 转换为字典格式"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"parameters": self.parameters,
|
||||
"description": self.description,
|
||||
"active": self.active,
|
||||
"origin": self.origin,
|
||||
"mcp_server_name": self.mcp_server_name,
|
||||
}
|
||||
|
||||
|
||||
# alias for FunctionTool
|
||||
FuncTool = FunctionTool
|
||||
|
||||
|
||||
class ToolSet:
|
||||
"""A set of function tools that can be used in function calling.
|
||||
|
||||
This class provides methods to add, remove, and retrieve tools, as well as
|
||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
||||
|
||||
def __init__(self, tools: List[FunctionTool] = None):
|
||||
self.tools: List[FunctionTool] = tools or []
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the tool set is empty."""
|
||||
return len(self.tools) == 0
|
||||
|
||||
def add_tool(self, tool: FunctionTool):
|
||||
"""Add a tool to the set."""
|
||||
# 检查是否已存在同名工具
|
||||
for i, existing_tool in enumerate(self.tools):
|
||||
if existing_tool.name == tool.name:
|
||||
self.tools[i] = tool
|
||||
return
|
||||
self.tools.append(tool)
|
||||
|
||||
def remove_tool(self, name: str):
|
||||
"""Remove a tool by its name."""
|
||||
self.tools = [tool for tool in self.tools if tool.name != name]
|
||||
|
||||
def get_tool(self, name: str) -> Optional[FunctionTool]:
|
||||
"""Get a tool by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.name == name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
||||
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
|
||||
"""Add a function tool to the set."""
|
||||
params = {
|
||||
"type": "object", # hard-coded here
|
||||
"properties": {},
|
||||
}
|
||||
for param in func_args:
|
||||
params["properties"][param["name"]] = {
|
||||
"type": param["type"],
|
||||
"description": param["description"],
|
||||
}
|
||||
_func = FunctionTool(
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
handler=handler,
|
||||
)
|
||||
self.add_tool(_func)
|
||||
|
||||
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
|
||||
def remove_func(self, name: str):
|
||||
"""Remove a function tool by its name."""
|
||||
self.remove_tool(name)
|
||||
|
||||
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
||||
def get_func(self, name: str) -> List[FunctionTool]:
|
||||
"""Get all function tools."""
|
||||
return self.get_tool(name)
|
||||
|
||||
def openai_schema(self, omit_empty_parameters: bool = False) -> List[Dict]:
|
||||
"""Convert tools to OpenAI API function calling schema format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
func_def = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
},
|
||||
}
|
||||
|
||||
if tool.parameters.get("properties") or not omit_empty_parameters:
|
||||
func_def["function"]["parameters"] = tool.parameters
|
||||
|
||||
result.append(func_def)
|
||||
return result
|
||||
|
||||
def anthropic_schema(self) -> List[Dict]:
|
||||
"""Convert tools to Anthropic API format."""
|
||||
result = []
|
||||
for tool in self.tools:
|
||||
tool_def = {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": tool.parameters.get("properties", {}),
|
||||
"required": tool.parameters.get("required", []),
|
||||
},
|
||||
}
|
||||
result.append(tool_def)
|
||||
return result
|
||||
|
||||
def google_schema(self) -> Dict:
|
||||
"""Convert tools to Google GenAI API format."""
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
"""Convert schema to Gemini API format."""
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"], set()
|
||||
):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
properties[key] = prop_value
|
||||
|
||||
if properties:
|
||||
result["properties"] = properties
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert_schema(schema["items"])
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": convert_schema(tool.parameters),
|
||||
}
|
||||
for tool in self.tools
|
||||
]
|
||||
|
||||
declarations = {}
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
@deprecated(reason="Use openai_schema() instead", version="4.0.0")
|
||||
def get_func_desc_openai_style(self, omit_empty_parameters: bool = False):
|
||||
return self.openai_schema(omit_empty_parameters)
|
||||
|
||||
@deprecated(reason="Use anthropic_schema() instead", version="4.0.0")
|
||||
def get_func_desc_anthropic_style(self):
|
||||
return self.anthropic_schema()
|
||||
|
||||
@deprecated(reason="Use google_schema() instead", version="4.0.0")
|
||||
def get_func_desc_google_genai_style(self):
|
||||
return self.google_schema()
|
||||
|
||||
def names(self) -> List[str]:
|
||||
"""获取所有工具的名称列表"""
|
||||
return [tool.name for tool in self.tools]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.tools)
|
||||
|
||||
def __bool__(self):
|
||||
return len(self.tools) > 0
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.tools)
|
||||
|
||||
|
||||
def _prepare_config(config: dict) -> dict:
|
||||
"""准备配置,处理嵌套格式"""
|
||||
if "mcpServers" in config and config["mcpServers"]:
|
||||
@@ -105,55 +409,6 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
|
||||
return False, f"{e!s}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuncTool:
|
||||
"""
|
||||
用于描述一个函数调用工具。
|
||||
"""
|
||||
|
||||
name: str
|
||||
parameters: Dict
|
||||
description: str
|
||||
handler: Awaitable = None
|
||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||
handler_module_path: str = None
|
||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||
|
||||
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
|
||||
"""
|
||||
active: bool = True
|
||||
"""是否激活"""
|
||||
|
||||
origin: Literal["local", "mcp"] = "local"
|
||||
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
|
||||
|
||||
# MCP 相关字段
|
||||
mcp_server_name: str = None
|
||||
"""MCP 服务名称,当 origin 为 mcp 时有效"""
|
||||
mcp_client: MCPClient = None
|
||||
"""MCP 客户端,当 origin 为 mcp 时有效"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
|
||||
|
||||
async def execute(self, **args) -> Any:
|
||||
"""执行函数调用"""
|
||||
if self.origin == "local":
|
||||
if not self.handler:
|
||||
raise Exception(f"Local function {self.name} has no handler")
|
||||
return await self.handler(**args)
|
||||
elif self.origin == "mcp":
|
||||
if not self.mcp_client or not self.mcp_client.session:
|
||||
raise Exception(f"MCP client for {self.name} is not available")
|
||||
# 使用name属性而不是额外的mcp_tool_name
|
||||
actual_tool_name = (
|
||||
self.name.split(":")[-1] if ":" in self.name else self.name
|
||||
)
|
||||
return await self.mcp_client.session.call_tool(actual_tool_name, args)
|
||||
else:
|
||||
raise Exception(f"Unknown function origin: {self.origin}")
|
||||
|
||||
|
||||
class MCPClient:
|
||||
def __init__(self):
|
||||
# Initialize session and client objects
|
||||
@@ -276,10 +531,9 @@ class MCPClient:
|
||||
self.running_event.set() # Set the running event to indicate cleanup is done
|
||||
|
||||
|
||||
class FuncCall:
|
||||
class FunctionToolManager:
|
||||
def __init__(self) -> None:
|
||||
self.func_list: List[FuncTool] = []
|
||||
"""内部加载的 func tools"""
|
||||
self.mcp_client_dict: Dict[str, MCPClient] = {}
|
||||
"""MCP 服务列表"""
|
||||
self.mcp_client_event: Dict[str, asyncio.Event] = {}
|
||||
@@ -331,11 +585,15 @@ class FuncCall:
|
||||
self.func_list.pop(i)
|
||||
break
|
||||
|
||||
def get_func(self, name) -> FuncTool:
|
||||
def get_func(self, name) -> FuncTool | None:
|
||||
for f in self.func_list:
|
||||
if f.name == name:
|
||||
return f
|
||||
return None
|
||||
|
||||
def get_full_tool_set(self) -> ToolSet:
|
||||
"""获取完整工具集"""
|
||||
tool_set = ToolSet(self.func_list.copy())
|
||||
return tool_set
|
||||
|
||||
async def init_mcp_clients(self) -> None:
|
||||
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
|
||||
@@ -556,203 +814,68 @@ class FuncCall:
|
||||
"""
|
||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
_l = []
|
||||
# 处理所有工具(包括本地和MCP工具)
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
func_ = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f.name,
|
||||
# "parameters": f.parameters,
|
||||
"description": f.description,
|
||||
},
|
||||
}
|
||||
func_["function"]["parameters"] = f.parameters
|
||||
if not f.parameters.get("properties") and omit_empty_parameter_field:
|
||||
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段
|
||||
del func_["function"]["parameters"]
|
||||
_l.append(func_)
|
||||
return _l
|
||||
tools = [f for f in self.func_list if f.active]
|
||||
toolset = ToolSet(tools)
|
||||
return toolset.openai_schema(omit_empty_parameters=omit_empty_parameter_field)
|
||||
|
||||
def get_func_desc_anthropic_style(self) -> list:
|
||||
"""
|
||||
获得 Anthropic API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
tools = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
|
||||
# Convert internal format to Anthropic style
|
||||
tool = {
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": f.parameters.get("properties", {}),
|
||||
# Keep the required field from the original parameters if it exists
|
||||
"required": f.parameters.get("required", []),
|
||||
},
|
||||
}
|
||||
tools.append(tool)
|
||||
return tools
|
||||
tools = [f for f in self.func_list if f.active]
|
||||
toolset = ToolSet(tools)
|
||||
return toolset.anthropic_schema()
|
||||
|
||||
def get_func_desc_google_genai_style(self) -> dict:
|
||||
"""
|
||||
获得 Google GenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
tools = [f for f in self.func_list if f.active]
|
||||
toolset = ToolSet(tools)
|
||||
return toolset.google_schema()
|
||||
|
||||
# Gemini API 支持的数据类型和格式
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
"""停用一个已经注册的函数调用工具。
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
"""转换 schema 为 Gemini API 格式"""
|
||||
Returns:
|
||||
如果没找到,会返回 False"""
|
||||
func_tool = self.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
# 如果 schema 包含 anyOf,则只返回 anyOf 字段
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
result = {}
|
||||
return True
|
||||
return False
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"], set()
|
||||
):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
# 暂时指定默认为null
|
||||
result["type"] = "null"
|
||||
# 因为不想解决循环引用,所以这里直接传入 star_map 先了...
|
||||
def activate_llm_tool(self, name: str, star_map: dict) -> bool:
|
||||
func_tool = self.get_func(name)
|
||||
if func_tool is not None:
|
||||
if func_tool.handler_module_path in star_map:
|
||||
if not star_map[func_tool.handler_module_path].activated:
|
||||
raise ValueError(
|
||||
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
|
||||
)
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
func_tool.active = True
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
properties[key] = prop_value
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
if properties: # 只在有非空属性时添加
|
||||
result["properties"] = properties
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert_schema(schema["items"])
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
**({"parameters": convert_schema(f.parameters)}),
|
||||
}
|
||||
for f in self.func_list
|
||||
if f.active
|
||||
]
|
||||
|
||||
declarations = {}
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
async def func_call(self, question: str, session_id: str, provider) -> tuple:
|
||||
_l = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
_l.append(
|
||||
{
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
}
|
||||
)
|
||||
func_definition = json.dumps(_l, ensure_ascii=False)
|
||||
|
||||
prompt = textwrap.dedent(f"""
|
||||
ROLE:
|
||||
你是一个 Function calling AI Agent, 你的任务是将用户的提问转化为函数调用。
|
||||
|
||||
TOOLS:
|
||||
可用的函数列表:
|
||||
|
||||
{func_definition}
|
||||
|
||||
LIMIT:
|
||||
1. 你返回的内容应当能够被 Python 的 json 模块解析的 Json 格式字符串。
|
||||
2. 你的 Json 返回的格式如下:`[{{"name": "<func_name>", "args": <arg_dict>}}, ...]`。参数根据上面提供的函数列表中的参数来填写。
|
||||
3. 允许必要时返回多个函数调用,但需保证这些函数调用的顺序正确。
|
||||
4. 如果用户的提问中不需要用到给定的函数,请直接返回 `{{"res": False}}`。
|
||||
|
||||
EXAMPLE:
|
||||
1. `用户提问`:请问一下天气怎么样? `函数调用`:[{{"name": "get_weather", "args": {{"city": "北京"}}}}]
|
||||
|
||||
用户的提问是:{question}
|
||||
""")
|
||||
|
||||
_c = 0
|
||||
while _c < 3:
|
||||
try:
|
||||
res = await provider.text_chat(prompt, session_id)
|
||||
if res.find("```") != -1:
|
||||
res = res[res.find("```json") + 7 : res.rfind("```")]
|
||||
res = json.loads(res)
|
||||
break
|
||||
except Exception as e:
|
||||
_c += 1
|
||||
if _c == 3:
|
||||
raise e
|
||||
if "The message you submitted was too long" in str(e):
|
||||
raise e
|
||||
|
||||
if "res" in res and not res["res"]:
|
||||
return "", False
|
||||
|
||||
tool_call_result = []
|
||||
for tool in res:
|
||||
# 说明有函数调用
|
||||
func_name = tool["name"]
|
||||
args = tool["args"]
|
||||
# 调用函数
|
||||
func_tool = self.get_func(func_name)
|
||||
if not func_tool:
|
||||
raise Exception(f"Request function {func_name} not found.")
|
||||
|
||||
ret = await func_tool.execute(**args)
|
||||
if ret:
|
||||
tool_call_result.append(str(ret))
|
||||
return tool_call_result, True
|
||||
return True
|
||||
return False
|
||||
|
||||
def __str__(self):
|
||||
return str(self.func_list)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.func_list)
|
||||
|
||||
|
||||
FuncCall = FunctionToolManager
|
||||
|
||||
@@ -7,85 +7,27 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.db import BaseDatabase
|
||||
|
||||
from .entities import ProviderType
|
||||
from .provider import Personality, Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
|
||||
from .register import llm_tools, provider_cls_map
|
||||
from ..persona_mgr import PersonaManager
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
|
||||
def __init__(
|
||||
self,
|
||||
config: AstrBotConfig,
|
||||
db_helper: BaseDatabase,
|
||||
persona_mgr: PersonaManager,
|
||||
):
|
||||
self.persona_mgr = persona_mgr
|
||||
self.astrbot_config = config
|
||||
self.providers_config: List = config["provider"]
|
||||
self.provider_settings: dict = config["provider_settings"]
|
||||
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
|
||||
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
|
||||
self.persona_configs: list = config.get("persona", [])
|
||||
self.astrbot_config = config
|
||||
|
||||
# 人格情景管理
|
||||
# 目前没有拆成独立的模块
|
||||
self.default_persona_name = self.provider_settings.get(
|
||||
"default_personality", "default"
|
||||
)
|
||||
self.personas: List[Personality] = []
|
||||
self.selected_default_persona = None
|
||||
for persona in self.persona_configs:
|
||||
begin_dialogs = persona.get("begin_dialogs", [])
|
||||
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
|
||||
bd_processed = []
|
||||
mid_processed = ""
|
||||
if begin_dialogs:
|
||||
if len(begin_dialogs) % 2 != 0:
|
||||
logger.error(
|
||||
f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。"
|
||||
)
|
||||
begin_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in begin_dialogs:
|
||||
bd_processed.append(
|
||||
{
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
}
|
||||
)
|
||||
user_turn = not user_turn
|
||||
if mood_imitation_dialogs:
|
||||
if len(mood_imitation_dialogs) % 2 != 0:
|
||||
logger.error(
|
||||
f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。"
|
||||
)
|
||||
mood_imitation_dialogs = []
|
||||
user_turn = True
|
||||
for dialog in mood_imitation_dialogs:
|
||||
role = "A" if user_turn else "B"
|
||||
mid_processed += f"{role}: {dialog}\n"
|
||||
if not user_turn:
|
||||
mid_processed += "\n"
|
||||
user_turn = not user_turn
|
||||
|
||||
try:
|
||||
persona = Personality(
|
||||
**persona,
|
||||
_begin_dialogs_processed=bd_processed,
|
||||
_mood_imitation_dialogs_processed=mid_processed,
|
||||
)
|
||||
if persona["name"] == self.default_persona_name:
|
||||
self.selected_default_persona = persona
|
||||
self.personas.append(persona)
|
||||
except Exception as e:
|
||||
logger.error(f"解析 Persona 配置失败:{e}")
|
||||
|
||||
if not self.selected_default_persona and len(self.personas) > 0:
|
||||
# 默认选择第一个
|
||||
self.selected_default_persona = self.personas[0]
|
||||
|
||||
if not self.selected_default_persona:
|
||||
self.selected_default_persona = Personality(
|
||||
prompt="You are a helpful and friendly assistant.",
|
||||
name="default",
|
||||
_begin_dialogs_processed=[],
|
||||
_mood_imitation_dialogs_processed="",
|
||||
)
|
||||
self.personas.append(self.selected_default_persona)
|
||||
# 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager
|
||||
self.default_persona_name = persona_mgr.default_persona
|
||||
|
||||
self.provider_insts: List[Provider] = []
|
||||
"""加载的 Provider 的实例"""
|
||||
@@ -113,6 +55,21 @@ class ProviderManager:
|
||||
if kdb_cfg and len(kdb_cfg):
|
||||
self.curr_kdb_name = list(kdb_cfg.keys())[0]
|
||||
|
||||
@property
|
||||
def persona_configs(self) -> list:
|
||||
"""动态获取最新的 persona 配置"""
|
||||
return self.persona_mgr.persona_v3_config
|
||||
|
||||
@property
|
||||
def personas(self) -> list:
|
||||
"""动态获取最新的 personas 列表"""
|
||||
return self.persona_mgr.personas_v3
|
||||
|
||||
@property
|
||||
def selected_default_persona(self):
|
||||
"""动态获取最新的默认选中 persona"""
|
||||
return self.persona_mgr.selected_default_persona_v3
|
||||
|
||||
async def set_provider(
|
||||
self, provider_id: str, provider_type: ProviderType, umo: str = None
|
||||
):
|
||||
|
||||
@@ -1,23 +1,13 @@
|
||||
import abc
|
||||
from typing import List
|
||||
from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager, ToolSet
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
|
||||
from astrbot.core.provider.register import provider_cls_map
|
||||
from astrbot.core.db.po import Personality
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class Personality(TypedDict):
|
||||
prompt: str = ""
|
||||
name: str = ""
|
||||
begin_dialogs: List[str] = []
|
||||
mood_imitation_dialogs: List[str] = []
|
||||
|
||||
# cache
|
||||
_begin_dialogs_processed: List[dict] = []
|
||||
_mood_imitation_dialogs_processed: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMeta:
|
||||
id: str
|
||||
@@ -90,7 +80,7 @@ class Provider(AbstractProvider):
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
func_tool: FunctionToolManager | ToolSet = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||
@@ -119,7 +109,7 @@ class Provider(AbstractProvider):
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: list[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
func_tool: FunctionToolManager | ToolSet = None,
|
||||
contexts: list = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None,
|
||||
|
||||
@@ -11,12 +11,14 @@ from astrbot.core.provider.provider import (
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.func_tool_manager import FunctionToolManager
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.platform import Platform
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager
|
||||
from astrbot.core.persona_mgr import PersonaManager
|
||||
from .star import star_registry, StarMetadata, star_map
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
@@ -29,6 +31,11 @@ from astrbot.core.star.filter.platform_adapter_type import (
|
||||
)
|
||||
from deprecated import deprecated
|
||||
|
||||
from typing_extensions import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from astrbot.core.pipeline.context import PipelineContext
|
||||
|
||||
|
||||
class Context:
|
||||
"""
|
||||
@@ -50,6 +57,8 @@ class Context:
|
||||
|
||||
registered_web_apis: list = []
|
||||
|
||||
pipeline_ctx: "PipelineContext" = None
|
||||
|
||||
# back compatibility
|
||||
_register_tasks: List[Awaitable] = []
|
||||
_star_manager = None
|
||||
@@ -62,6 +71,8 @@ class Context:
|
||||
provider_manager: ProviderManager = None,
|
||||
platform_manager: PlatformManager = None,
|
||||
conversation_manager: ConversationManager = None,
|
||||
message_history_manager: PlatformMessageHistoryManager = None,
|
||||
persona_manager: PersonaManager = None,
|
||||
):
|
||||
self._event_queue = event_queue
|
||||
self._config = config
|
||||
@@ -69,6 +80,8 @@ class Context:
|
||||
self.provider_manager = provider_manager
|
||||
self.platform_manager = platform_manager
|
||||
self.conversation_manager = conversation_manager
|
||||
self.message_history_manager = message_history_manager
|
||||
self.persona_manager = persona_manager
|
||||
|
||||
def get_registered_star(self, star_name: str) -> StarMetadata:
|
||||
"""根据插件名获取插件的 Metadata"""
|
||||
@@ -80,7 +93,7 @@ class Context:
|
||||
"""获取当前载入的所有插件 Metadata 的列表"""
|
||||
return star_registry
|
||||
|
||||
def get_llm_tool_manager(self) -> FuncCall:
|
||||
def get_llm_tool_manager(self) -> FunctionToolManager:
|
||||
"""获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools"""
|
||||
return self.provider_manager.llm_tools
|
||||
|
||||
@@ -90,40 +103,14 @@ class Context:
|
||||
Returns:
|
||||
如果没找到,会返回 False
|
||||
"""
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
if func_tool.handler_module_path in star_map:
|
||||
if not star_map[func_tool.handler_module_path].activated:
|
||||
raise ValueError(
|
||||
f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。"
|
||||
)
|
||||
|
||||
func_tool.active = True
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name in inactivated_llm_tools:
|
||||
inactivated_llm_tools.remove(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
return self.provider_manager.llm_tools.activate_llm_tool(name, star_map)
|
||||
|
||||
def deactivate_llm_tool(self, name: str) -> bool:
|
||||
"""停用一个已经注册的函数调用工具。
|
||||
|
||||
Returns:
|
||||
如果没找到,会返回 False"""
|
||||
func_tool = self.provider_manager.llm_tools.get_func(name)
|
||||
if func_tool is not None:
|
||||
func_tool.active = False
|
||||
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
if name not in inactivated_llm_tools:
|
||||
inactivated_llm_tools.append(name)
|
||||
sp.put("inactivated_llm_tools", inactivated_llm_tools)
|
||||
|
||||
return True
|
||||
return False
|
||||
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
"""
|
||||
@@ -208,7 +195,9 @@ class Context:
|
||||
return self._event_queue
|
||||
|
||||
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
|
||||
def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform | None:
|
||||
def get_platform(
|
||||
self, platform_type: Union[PlatformAdapterType, str]
|
||||
) -> Platform | None:
|
||||
"""
|
||||
获取指定类型的平台适配器。
|
||||
|
||||
@@ -268,6 +257,9 @@ class Context:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_pipeline_context(self) -> "PipelineContext":
|
||||
return self.pipeline_ctx
|
||||
|
||||
"""
|
||||
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
|
||||
"""
|
||||
|
||||
@@ -6,11 +6,11 @@ from .stat import StatRoute
|
||||
from .log import LogRoute
|
||||
from .static_file import StaticFileRoute
|
||||
from .chat import ChatRoute
|
||||
from .tools import ToolsRoute # 导入新的ToolsRoute
|
||||
from .tools import ToolsRoute
|
||||
from .conversation import ConversationRoute
|
||||
from .file import FileRoute
|
||||
from .session_management import SessionManagementRoute
|
||||
|
||||
from .persona import PersonaRoute
|
||||
|
||||
__all__ = [
|
||||
"AuthRoute",
|
||||
@@ -25,4 +25,5 @@ __all__ = [
|
||||
"ConversationRoute",
|
||||
"FileRoute",
|
||||
"SessionManagementRoute",
|
||||
"PersonaRoute",
|
||||
]
|
||||
|
||||
@@ -169,7 +169,6 @@ class ConfigRoute(Route):
|
||||
"/config/provider/new": ("POST", self.post_new_provider),
|
||||
"/config/provider/update": ("POST", self.post_update_provider),
|
||||
"/config/provider/delete": ("POST", self.post_delete_provider),
|
||||
"/config/llmtools": ("GET", self.get_llm_tools),
|
||||
"/config/provider/check_one": ("GET", self.check_one_provider_status),
|
||||
"/config/provider/list": ("GET", self.get_provider_config_list),
|
||||
"/config/provider/model_list": ("GET", self.get_provider_model_list),
|
||||
@@ -509,12 +508,6 @@ class ConfigRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
return Response().ok(None, "删除成功,已经实时生效~").__dict__
|
||||
|
||||
async def get_llm_tools(self):
|
||||
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
|
||||
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
tools = tool_mgr.get_func_desc_openai_style()
|
||||
return Response().ok(tools).__dict__
|
||||
|
||||
async def _get_astrbot_config(self):
|
||||
config = self.config
|
||||
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from astrbot.core import logger
|
||||
from quart import request
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
|
||||
|
||||
class PersonaRoute(Route):
|
||||
def __init__(
|
||||
self,
|
||||
context: RouteContext,
|
||||
db_helper: BaseDatabase,
|
||||
core_lifecycle: AstrBotCoreLifecycle,
|
||||
) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
"/persona/list": ("GET", self.list_personas),
|
||||
"/persona/detail": ("POST", self.get_persona_detail),
|
||||
"/persona/create": ("POST", self.create_persona),
|
||||
"/persona/update": ("POST", self.update_persona),
|
||||
"/persona/delete": ("POST", self.delete_persona),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.persona_mgr = core_lifecycle.persona_mgr
|
||||
self.register_routes()
|
||||
|
||||
async def list_personas(self):
|
||||
"""获取所有人格列表"""
|
||||
try:
|
||||
personas = await self.persona_mgr.get_all_personas()
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
[
|
||||
{
|
||||
"persona_id": persona.persona_id,
|
||||
"system_prompt": persona.system_prompt,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"tools": persona.tools,
|
||||
"created_at": persona.created_at.isoformat()
|
||||
if persona.created_at
|
||||
else None,
|
||||
"updated_at": persona.updated_at.isoformat()
|
||||
if persona.updated_at
|
||||
else None,
|
||||
}
|
||||
for persona in personas
|
||||
]
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取人格列表失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return Response().error(f"获取人格列表失败: {str(e)}").__dict__
|
||||
|
||||
async def get_persona_detail(self):
|
||||
"""获取指定人格的详细信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("缺少必要参数: persona_id").__dict__
|
||||
|
||||
persona = await self.persona_mgr.get_persona(persona_id)
|
||||
if not persona:
|
||||
return Response().error("人格不存在").__dict__
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"persona_id": persona.persona_id,
|
||||
"system_prompt": persona.system_prompt,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"tools": persona.tools,
|
||||
"created_at": persona.created_at.isoformat()
|
||||
if persona.created_at
|
||||
else None,
|
||||
"updated_at": persona.updated_at.isoformat()
|
||||
if persona.updated_at
|
||||
else None,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取人格详情失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return Response().error(f"获取人格详情失败: {str(e)}").__dict__
|
||||
|
||||
async def create_persona(self):
|
||||
"""创建新人格"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id", "").strip()
|
||||
system_prompt = data.get("system_prompt", "").strip()
|
||||
begin_dialogs = data.get("begin_dialogs", [])
|
||||
tools = data.get("tools")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("人格ID不能为空").__dict__
|
||||
|
||||
if not system_prompt:
|
||||
return Response().error("系统提示词不能为空").__dict__
|
||||
|
||||
# 验证 begin_dialogs 格式
|
||||
if begin_dialogs and len(begin_dialogs) % 2 != 0:
|
||||
return (
|
||||
Response()
|
||||
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
persona = await self.persona_mgr.create_persona(
|
||||
persona_id=persona_id,
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs if begin_dialogs else None,
|
||||
tools=tools if tools else None,
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"message": "人格创建成功",
|
||||
"persona": {
|
||||
"persona_id": persona.persona_id,
|
||||
"system_prompt": persona.system_prompt,
|
||||
"begin_dialogs": persona.begin_dialogs or [],
|
||||
"tools": persona.tools or [],
|
||||
"created_at": persona.created_at.isoformat()
|
||||
if persona.created_at
|
||||
else None,
|
||||
"updated_at": persona.updated_at.isoformat()
|
||||
if persona.updated_at
|
||||
else None,
|
||||
},
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"创建人格失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return Response().error(f"创建人格失败: {str(e)}").__dict__
|
||||
|
||||
async def update_persona(self):
|
||||
"""更新人格信息"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id")
|
||||
system_prompt = data.get("system_prompt")
|
||||
begin_dialogs = data.get("begin_dialogs")
|
||||
tools = data.get("tools")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("缺少必要参数: persona_id").__dict__
|
||||
|
||||
# 验证 begin_dialogs 格式
|
||||
if begin_dialogs is not None and len(begin_dialogs) % 2 != 0:
|
||||
return (
|
||||
Response()
|
||||
.error("预设对话数量必须为偶数(用户和助手轮流对话)")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
await self.persona_mgr.update_persona(
|
||||
persona_id=persona_id,
|
||||
system_prompt=system_prompt,
|
||||
begin_dialogs=begin_dialogs,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
return Response().ok({"message": "人格更新成功"}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"更新人格失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return Response().error(f"更新人格失败: {str(e)}").__dict__
|
||||
|
||||
async def delete_persona(self):
|
||||
"""删除人格"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
persona_id = data.get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
return Response().error("缺少必要参数: persona_id").__dict__
|
||||
|
||||
await self.persona_mgr.delete_persona(persona_id)
|
||||
|
||||
return Response().ok({"message": "人格删除成功"}).__dict__
|
||||
except ValueError as e:
|
||||
return Response().error(str(e)).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"删除人格失败: {str(e)}\n{traceback.format_exc()}")
|
||||
return Response().error(f"删除人格失败: {str(e)}").__dict__
|
||||
@@ -8,6 +8,7 @@ from quart import request
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.star import star_map
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -27,6 +28,8 @@ class ToolsRoute(Route):
|
||||
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
|
||||
"/tools/mcp/market": ("GET", self.get_mcp_markets),
|
||||
"/tools/mcp/test": ("POST", self.test_mcp_connection),
|
||||
"/tools/list": ("GET", self.get_tool_list),
|
||||
"/tools/toggle-tool": ("POST", self.toggle_tool),
|
||||
}
|
||||
self.register_routes()
|
||||
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
@@ -336,3 +339,40 @@ class ToolsRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__
|
||||
|
||||
async def get_tool_list(self):
|
||||
"""获取所有注册的工具列表"""
|
||||
try:
|
||||
tools = self.tool_mgr.func_list
|
||||
tools_dict = [tool.__dict__() for tool in tools]
|
||||
return Response().ok(data=tools_dict).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"获取工具列表失败: {str(e)}").__dict__
|
||||
|
||||
async def toggle_tool(self):
|
||||
"""启用或停用指定的工具"""
|
||||
try:
|
||||
data = await request.json
|
||||
tool_name = data.get("name")
|
||||
action = data.get("activate") # True or False
|
||||
|
||||
if not tool_name or action is None:
|
||||
return Response().error("缺少必要参数: name 或 action").__dict__
|
||||
|
||||
if action:
|
||||
try:
|
||||
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
|
||||
except ValueError as e:
|
||||
return Response().error(f"启用工具失败: {str(e)}").__dict__
|
||||
else:
|
||||
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
|
||||
|
||||
if ok:
|
||||
return Response().ok(None, "操作成功。").__dict__
|
||||
else:
|
||||
return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__
|
||||
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"操作工具失败: {str(e)}").__dict__
|
||||
|
||||
@@ -60,6 +60,9 @@ class AstrBotDashboard:
|
||||
self.session_management_route = SessionManagementRoute(
|
||||
self.context, db, core_lifecycle
|
||||
)
|
||||
self.persona_route = PersonaRoute(
|
||||
self.context, db, core_lifecycle
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
"/api/plug/<path:subpath>",
|
||||
|
||||
@@ -52,6 +52,7 @@ export class I18nLoader {
|
||||
{ name: 'features/alkaid/index', path: 'features/alkaid/index.json' },
|
||||
{ name: 'features/alkaid/knowledge-base', path: 'features/alkaid/knowledge-base.json' },
|
||||
{ name: 'features/alkaid/memory', path: 'features/alkaid/memory.json' },
|
||||
{ name: 'features/persona', path: 'features/persona.json' },
|
||||
|
||||
// 消息模块
|
||||
{ name: 'messages/errors', path: 'messages/errors.json' },
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
"dashboard": "Dashboard",
|
||||
"platforms": "Platforms",
|
||||
"providers": "Providers",
|
||||
"persona": "Persona",
|
||||
"toolUse": "MCP Tools",
|
||||
"config": "Config",
|
||||
"extension": "Extensions",
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
{
|
||||
"page": {
|
||||
"description": "Manage and configure chat bot personality settings"
|
||||
},
|
||||
"buttons": {
|
||||
"create": "Create Persona",
|
||||
"createFirst": "Create First Persona",
|
||||
"edit": "Edit",
|
||||
"delete": "Delete",
|
||||
"cancel": "Cancel",
|
||||
"save": "Save",
|
||||
"addDialogPair": "Add Dialog Pair"
|
||||
},
|
||||
"labels": {
|
||||
"presetDialogs": "Preset Dialogs ({count} pairs)",
|
||||
"createdAt": "Created At",
|
||||
"updatedAt": "Updated At"
|
||||
},
|
||||
"form": {
|
||||
"personaId": "Persona ID",
|
||||
"systemPrompt": "System Prompt",
|
||||
"presetDialogs": "Preset Dialogs",
|
||||
"presetDialogsHelp": "Add some preset dialogs to help the bot better understand the role settings. The number of dialogs must be even (users and assistants take turns).",
|
||||
"userMessage": "User Message",
|
||||
"assistantMessage": "Assistant Message",
|
||||
"tools": "Tool Selection",
|
||||
"toolsHelp": "Select available tools for this persona. Tools allow the bot to perform specific functions such as searching, calculating, getting information, etc.",
|
||||
"toolsSelection": "Tool Selection Actions",
|
||||
"selectAllTools": "Select All Tools",
|
||||
"clearAllTools": "Clear Selection",
|
||||
"allSelected": "All Selected",
|
||||
"mcpServersQuickSelect": "MCP Servers Quick Select",
|
||||
"searchTools": "Search Tools",
|
||||
"selectedTools": "Selected Tools",
|
||||
"noToolsAvailable": "No tools available",
|
||||
"noToolsFound": "No matching tools found",
|
||||
"loadingTools": "Loading tools...",
|
||||
"allToolsAvailable": "Use all available tools",
|
||||
"noToolsSelected": "No tools selected"
|
||||
},
|
||||
"dialog": {
|
||||
"create": {
|
||||
"title": "Create New Persona"
|
||||
},
|
||||
"edit": {
|
||||
"title": "Edit Persona"
|
||||
}
|
||||
},
|
||||
"empty": {
|
||||
"title": "No Persona Configured",
|
||||
"description": "Create your first persona to start using personalized chatbots"
|
||||
},
|
||||
"validation": {
|
||||
"required": "This field is required",
|
||||
"minLength": "Minimum {min} characters required",
|
||||
"alphanumeric": "Only letters, numbers, underscores and hyphens are allowed",
|
||||
"dialogRequired": "{type} cannot be empty"
|
||||
},
|
||||
"messages": {
|
||||
"loadError": "Failed to load persona list",
|
||||
"saveSuccess": "Saved successfully",
|
||||
"saveError": "Save failed",
|
||||
"deleteConfirm": "Are you sure you want to delete persona \"{id}\"? This action cannot be undone.",
|
||||
"deleteSuccess": "Deleted successfully",
|
||||
"deleteError": "Delete failed"
|
||||
}
|
||||
}
|
||||
@@ -107,6 +107,9 @@
|
||||
"failed": "Import configuration failed: {error}"
|
||||
},
|
||||
"configParseError": "Configuration parse error: {error}",
|
||||
"noAvailableConfig": "No available configuration"
|
||||
"noAvailableConfig": "No available configuration",
|
||||
"toggleToolSuccess": "Tool status toggled successfully!",
|
||||
"toggleToolError": "Failed to toggle tool status: {error}",
|
||||
"testError": "Test connection failed: {error}"
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
"dashboard": "统计",
|
||||
"platforms": "消息平台",
|
||||
"providers": "服务提供商",
|
||||
"persona": "人格管理",
|
||||
"toolUse": "MCP",
|
||||
"config": "配置文件",
|
||||
"extension": "插件管理",
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
{
|
||||
"page": {
|
||||
"description": "管理和配置聊天机器人的人格角色设定"
|
||||
},
|
||||
"buttons": {
|
||||
"create": "创建人格",
|
||||
"createFirst": "创建第一个人格",
|
||||
"edit": "编辑",
|
||||
"delete": "删除",
|
||||
"cancel": "取消",
|
||||
"save": "保存",
|
||||
"addDialogPair": "添加对话对"
|
||||
},
|
||||
"labels": {
|
||||
"presetDialogs": "预设对话 ({count} 对)",
|
||||
"createdAt": "创建时间",
|
||||
"updatedAt": "更新时间"
|
||||
},
|
||||
"form": {
|
||||
"personaId": "人格 ID",
|
||||
"systemPrompt": "系统提示词",
|
||||
"presetDialogs": "预设对话",
|
||||
"presetDialogsHelp": "添加一些预设的对话来帮助机器人更好地理解角色设定。对话数量必须为偶数(用户和助手轮流对话)。",
|
||||
"userMessage": "用户消息",
|
||||
"assistantMessage": "助手消息",
|
||||
"tools": "工具选择",
|
||||
"toolsHelp": "为这个人格选择可用的工具。工具可以让机器人执行特定的功能,如搜索、计算、获取信息等。",
|
||||
"toolsSelection": "工具选择操作",
|
||||
"selectAllTools": "选择所有工具",
|
||||
"clearAllTools": "清空选择",
|
||||
"allSelected": "全选",
|
||||
"mcpServersQuickSelect": "MCP 服务器快速选择",
|
||||
"searchTools": "搜索工具",
|
||||
"selectedTools": "已选择的工具",
|
||||
"noToolsAvailable": "暂无可用工具",
|
||||
"noToolsFound": "未找到匹配的工具",
|
||||
"loadingTools": "正在加载工具...",
|
||||
"allToolsAvailable": "使用所有可用工具",
|
||||
"noToolsSelected": "未选择任何工具"
|
||||
},
|
||||
"dialog": {
|
||||
"create": {
|
||||
"title": "创建新人格"
|
||||
},
|
||||
"edit": {
|
||||
"title": "编辑人格"
|
||||
}
|
||||
},
|
||||
"empty": {
|
||||
"title": "暂无人格配置",
|
||||
"description": "创建您的第一个人格来开始使用个性化的聊天机器人"
|
||||
},
|
||||
"validation": {
|
||||
"required": "此字段为必填项",
|
||||
"minLength": "最少需要 {min} 个字符",
|
||||
"alphanumeric": "只能包含字母、数字、下划线和连字符",
|
||||
"dialogRequired": "{type}不能为空"
|
||||
},
|
||||
"messages": {
|
||||
"loadError": "加载人格列表失败",
|
||||
"saveSuccess": "保存成功",
|
||||
"saveError": "保存失败",
|
||||
"deleteConfirm": "确定要删除人格 \"{id}\" 吗?此操作不可撤销。",
|
||||
"deleteSuccess": "删除成功",
|
||||
"deleteError": "删除失败"
|
||||
}
|
||||
}
|
||||
@@ -107,6 +107,9 @@
|
||||
"failed": "导入配置失败: {error}"
|
||||
},
|
||||
"configParseError": "配置解析错误: {error}",
|
||||
"noAvailableConfig": "无可用配置"
|
||||
"noAvailableConfig": "无可用配置",
|
||||
"toggleToolSuccess": "工具状态切换成功!",
|
||||
"toggleToolError": "工具状态切换失败: {error}",
|
||||
"testError": "测试连接失败: {error}"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,7 @@ import zhCNDashboard from './locales/zh-CN/features/dashboard.json';
|
||||
import zhCNAlkaidIndex from './locales/zh-CN/features/alkaid/index.json';
|
||||
import zhCNAlkaidKnowledgeBase from './locales/zh-CN/features/alkaid/knowledge-base.json';
|
||||
import zhCNAlkaidMemory from './locales/zh-CN/features/alkaid/memory.json';
|
||||
import zhCNPersona from './locales/zh-CN/features/persona.json';
|
||||
|
||||
import zhCNErrors from './locales/zh-CN/messages/errors.json';
|
||||
import zhCNSuccess from './locales/zh-CN/messages/success.json';
|
||||
@@ -54,6 +55,7 @@ import enUSDashboard from './locales/en-US/features/dashboard.json';
|
||||
import enUSAlkaidIndex from './locales/en-US/features/alkaid/index.json';
|
||||
import enUSAlkaidKnowledgeBase from './locales/en-US/features/alkaid/knowledge-base.json';
|
||||
import enUSAlkaidMemory from './locales/en-US/features/alkaid/memory.json';
|
||||
import enUSPersona from './locales/en-US/features/persona.json';
|
||||
|
||||
import enUSErrors from './locales/en-US/messages/errors.json';
|
||||
import enUSSuccess from './locales/en-US/messages/success.json';
|
||||
@@ -88,7 +90,8 @@ export const translations = {
|
||||
index: zhCNAlkaidIndex,
|
||||
'knowledge-base': zhCNAlkaidKnowledgeBase,
|
||||
memory: zhCNAlkaidMemory
|
||||
}
|
||||
},
|
||||
persona: zhCNPersona
|
||||
},
|
||||
messages: {
|
||||
errors: zhCNErrors,
|
||||
@@ -123,7 +126,8 @@ export const translations = {
|
||||
index: enUSAlkaidIndex,
|
||||
'knowledge-base': enUSAlkaidKnowledgeBase,
|
||||
memory: enUSAlkaidMemory
|
||||
}
|
||||
},
|
||||
persona: enUSPersona
|
||||
},
|
||||
messages: {
|
||||
errors: enUSErrors,
|
||||
|
||||
@@ -63,9 +63,13 @@ const sidebarItem: menu[] = [
|
||||
icon: 'mdi-account-group',
|
||||
to: '/session-management'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.persona',
|
||||
icon: 'mdi-heart',
|
||||
to: '/persona'
|
||||
},
|
||||
{
|
||||
title: 'core.navigation.console',
|
||||
|
||||
icon: 'mdi-console',
|
||||
to: '/console'
|
||||
},
|
||||
|
||||
@@ -56,6 +56,11 @@ const MainRoutes = {
|
||||
path: '/session-management',
|
||||
component: () => import('@/views/SessionManagementPage.vue')
|
||||
},
|
||||
{
|
||||
name: 'Persona',
|
||||
path: '/persona',
|
||||
component: () => import('@/views/PersonaPage.vue')
|
||||
},
|
||||
{
|
||||
name: 'Console',
|
||||
path: '/console',
|
||||
|
||||
@@ -0,0 +1,808 @@
|
||||
<template>
|
||||
<div class="persona-page">
|
||||
<v-container fluid class="pa-0">
|
||||
<!-- 页面标题 -->
|
||||
<v-row class="d-flex justify-space-between align-center px-4 py-3 pb-8">
|
||||
<div>
|
||||
<h1 class="text-h1 font-weight-bold mb-2">
|
||||
<v-icon color="black" class="me-2">mdi-heart</v-icon>{{ t('core.navigation.persona') }}
|
||||
</h1>
|
||||
<p class="text-subtitle-1 text-medium-emphasis mb-4">
|
||||
{{ tm('page.description') }}
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<v-btn color="primary" variant="tonal" prepend-icon="mdi-plus" @click="openCreateDialog"
|
||||
rounded="xl" size="x-large">
|
||||
{{ tm('buttons.create') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-row>
|
||||
|
||||
|
||||
<!-- 人格卡片网格 -->
|
||||
<v-row>
|
||||
<v-col v-for="persona in personas" :key="persona.persona_id" cols="12" md="6" lg="4" xl="3">
|
||||
<v-card class="persona-card" elevation="2" rounded="lg" @click="viewPersona(persona)">
|
||||
<v-card-title class="d-flex justify-space-between align-center">
|
||||
<div class="text-truncate ml-2">
|
||||
{{ persona.persona_id }}
|
||||
</div>
|
||||
<v-menu offset-y>
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-btn icon="mdi-dots-vertical" variant="text" size="small" v-bind="props"
|
||||
@click.stop />
|
||||
</template>
|
||||
<v-list density="compact">
|
||||
<v-list-item @click="editPersona(persona)">
|
||||
<v-list-item-title>
|
||||
<v-icon class="mr-2" size="small">mdi-pencil</v-icon>
|
||||
{{ tm('buttons.edit') }}
|
||||
</v-list-item-title>
|
||||
</v-list-item>
|
||||
<v-list-item @click="deletePersona(persona)" class="text-error">
|
||||
<v-list-item-title>
|
||||
<v-icon class="mr-2" size="small">mdi-delete</v-icon>
|
||||
{{ tm('buttons.delete') }}
|
||||
</v-list-item-title>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-menu>
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text>
|
||||
<div class="system-prompt-preview">
|
||||
{{ truncateText(persona.system_prompt, 100) }}
|
||||
</div>
|
||||
|
||||
<div class="mt-3" v-if="persona.begin_dialogs && persona.begin_dialogs.length > 0">
|
||||
<v-chip size="small" color="secondary" variant="tonal" prepend-icon="mdi-chat">
|
||||
{{ tm('labels.presetDialogs', { count: persona.begin_dialogs.length / 2 }) }}
|
||||
</v-chip>
|
||||
</div>
|
||||
|
||||
<div class="mt-3 text-caption text-medium-emphasis">
|
||||
{{ tm('labels.createdAt') }}: {{ formatDate(persona.created_at) }}
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</v-col>
|
||||
|
||||
<!-- 空状态 -->
|
||||
<v-col v-if="personas.length === 0 && !loading" cols="12">
|
||||
<v-card class="text-center pa-8" elevation="0">
|
||||
<v-icon size="64" color="grey-lighten-1" class="mb-4">mdi-account-group</v-icon>
|
||||
<h3 class="text-h5 mb-2">{{ tm('empty.title') }}</h3>
|
||||
<p class="text-body-1 text-medium-emphasis mb-4">{{ tm('empty.description') }}</p>
|
||||
<v-btn color="primary" variant="flat" prepend-icon="mdi-plus" @click="openCreateDialog">
|
||||
{{ tm('buttons.createFirst') }}
|
||||
</v-btn>
|
||||
</v-card>
|
||||
</v-col>
|
||||
</v-row>
|
||||
|
||||
<!-- 加载状态 -->
|
||||
<v-row v-if="loading">
|
||||
<v-col v-for="n in 6" :key="n" cols="12" md="6" lg="4" xl="3">
|
||||
<v-skeleton-loader type="card" rounded="lg"></v-skeleton-loader>
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-container>
|
||||
|
||||
<!-- 创建/编辑人格对话框 -->
|
||||
<v-dialog v-model="showPersonaDialog" max-width="800px" persistent>
|
||||
<v-card>
|
||||
<v-card-title class="text-h5">
|
||||
{{ editingPersona ? tm('dialog.edit.title') : tm('dialog.create.title') }}
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text>
|
||||
<v-form ref="personaForm" v-model="formValid">
|
||||
<v-text-field v-model="personaForm.persona_id" :label="tm('form.personaId')"
|
||||
:rules="personaIdRules" :disabled="editingPersona" variant="outlined" density="comfortable"
|
||||
class="mb-4" />
|
||||
|
||||
<v-textarea v-model="personaForm.system_prompt" :label="tm('form.systemPrompt')"
|
||||
:rules="systemPromptRules" variant="outlined" rows="6" class="mb-4" />
|
||||
|
||||
<v-expansion-panels v-model="expandedPanels" multiple>
|
||||
<!-- 工具选择面板 -->
|
||||
<v-expansion-panel value="tools">
|
||||
<v-expansion-panel-title>
|
||||
<v-icon class="mr-2">mdi-tools</v-icon>
|
||||
{{ tm('form.tools') }}
|
||||
<v-chip v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
|
||||
size="small" color="primary" variant="tonal" class="ml-2">
|
||||
{{ personaForm.tools.length }}
|
||||
</v-chip>
|
||||
</v-expansion-panel-title>
|
||||
|
||||
<v-expansion-panel-text>
|
||||
<div class="mb-3">
|
||||
<p class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.toolsHelp') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<v-radio-group class="mt-2" v-model="toolSelectValue" hide-details="true">
|
||||
<v-radio label="默认使用全部函数工具" value="0"></v-radio>
|
||||
<v-radio label="选择指定函数工具" value="1">
|
||||
</v-radio>
|
||||
</v-radio-group>
|
||||
|
||||
<div v-if="toolSelectValue === '1'" class="mt-3 ml-8">
|
||||
|
||||
<!-- 工具搜索 -->
|
||||
<v-text-field v-model="toolSearch" :label="tm('form.searchTools')"
|
||||
prepend-inner-icon="mdi-magnify" variant="outlined" density="compact"
|
||||
hide-details clearable class="mb-3" />
|
||||
|
||||
|
||||
<!-- MCP 服务器 -->
|
||||
<div v-if="mcpServers.length > 0" class="mb-4">
|
||||
<h4 class="text-subtitle-2 mb-2">{{ tm('form.mcpServersQuickSelect') }}</h4>
|
||||
<div class="d-flex flex-wrap ga-2">
|
||||
<v-chip v-for="server in mcpServers" :key="server.name"
|
||||
:color="isServerSelected(server) ? 'primary' : 'default'"
|
||||
:variant="isServerSelected(server) ? 'flat' : 'outlined'"
|
||||
size="small" clickable @click="toggleMcpServer(server)"
|
||||
:disabled="!server.tools || server.tools.length === 0">
|
||||
<v-icon start size="small">mdi-server</v-icon>
|
||||
{{ server.name }}
|
||||
<v-chip-text v-if="server.tools" class="ml-1">
|
||||
({{ server.tools.length }})
|
||||
</v-chip-text>
|
||||
</v-chip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- 工具选择列表 -->
|
||||
<div v-if="filteredTools.length > 0" class="tools-selection">
|
||||
<v-virtual-scroll :items="filteredTools" height="300" item-height="48">
|
||||
<template v-slot:default="{ item }">
|
||||
<v-list-item :key="item.name" density="comfortable"
|
||||
@click="toggleTool(item.name)">
|
||||
<template v-slot:prepend>
|
||||
<v-checkbox-btn :model-value="isToolSelected(item.name)"
|
||||
@click.stop="toggleTool(item.name)" />
|
||||
</template>
|
||||
|
||||
<v-list-item-title>
|
||||
{{ item.name }}
|
||||
<v-chip v-if="item.mcp_server_name" size="x-small"
|
||||
color="secondary" variant="tonal" class="ml-2">
|
||||
{{ item.mcp_server_name }}
|
||||
</v-chip>
|
||||
</v-list-item-title>
|
||||
|
||||
<v-list-item-subtitle v-if="item.description">
|
||||
{{ truncateText(item.description, 100) }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</template>
|
||||
</v-virtual-scroll>
|
||||
</div>
|
||||
|
||||
<div v-else-if="!loadingTools && availableTools.length === 0"
|
||||
class="text-center pa-4">
|
||||
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-tools</v-icon>
|
||||
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsAvailable')
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div v-else-if="!loadingTools && filteredTools.length === 0"
|
||||
class="text-center pa-4">
|
||||
<v-icon size="48" color="grey-lighten-2" class="mb-2">mdi-magnify</v-icon>
|
||||
<p class="text-body-2 text-medium-emphasis">{{ tm('form.noToolsFound') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 加载状态 -->
|
||||
<div v-if="loadingTools" class="text-center pa-4">
|
||||
<v-progress-circular indeterminate color="primary" />
|
||||
<p class="text-body-2 text-medium-emphasis mt-2">{{ tm('form.loadingTools')
|
||||
}}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<!-- 已选择的工具 -->
|
||||
<div class="mt-4">
|
||||
<h4 class="text-subtitle-2 mb-2">
|
||||
{{ tm('form.selectedTools') }}
|
||||
<span v-if="personaForm.tools === null" class="text-success">
|
||||
({{ tm('form.allSelected') }})
|
||||
</span>
|
||||
<span v-else-if="Array.isArray(personaForm.tools)">
|
||||
({{ personaForm.tools.length }})
|
||||
</span>
|
||||
</h4>
|
||||
<div v-if="Array.isArray(personaForm.tools) && personaForm.tools.length > 0"
|
||||
class="d-flex flex-wrap ga-1" style="max-height: 100px; overflow-y: auto;">
|
||||
<v-chip v-for="toolName in personaForm.tools" :key="toolName"
|
||||
size="small" color="primary" variant="tonal" closable
|
||||
@click:close="removeTool(toolName)">
|
||||
{{ toolName }}
|
||||
</v-chip>
|
||||
</div>
|
||||
<div v-else class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.noToolsSelected') }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
</v-expansion-panel-text>
|
||||
</v-expansion-panel>
|
||||
|
||||
<!-- 预设对话面板 -->
|
||||
<v-expansion-panel value="dialogs">
|
||||
<v-expansion-panel-title>
|
||||
<v-icon class="mr-2">mdi-chat</v-icon>
|
||||
{{ tm('form.presetDialogs') }}
|
||||
<v-chip v-if="personaForm.begin_dialogs.length > 0" size="small" color="primary"
|
||||
variant="tonal" class="ml-2">
|
||||
{{ personaForm.begin_dialogs.length / 2 }}
|
||||
</v-chip>
|
||||
</v-expansion-panel-title>
|
||||
|
||||
<v-expansion-panel-text>
|
||||
<div class="mb-3">
|
||||
<p class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.presetDialogsHelp') }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div v-for="(dialog, index) in personaForm.begin_dialogs" :key="index" class="mb-3">
|
||||
<v-textarea v-model="personaForm.begin_dialogs[index]"
|
||||
:label="index % 2 === 0 ? tm('form.userMessage') : tm('form.assistantMessage')"
|
||||
:rules="getDialogRules(index)" variant="outlined" rows="2"
|
||||
density="comfortable">
|
||||
<template v-slot:append>
|
||||
<v-btn icon="mdi-delete" variant="text" size="small" color="error"
|
||||
@click="removeDialog(index)" />
|
||||
</template>
|
||||
</v-textarea>
|
||||
</div>
|
||||
|
||||
<v-btn variant="outlined" prepend-icon="mdi-plus" @click="addDialogPair" block>
|
||||
{{ tm('buttons.addDialogPair') }}
|
||||
</v-btn>
|
||||
</v-expansion-panel-text>
|
||||
</v-expansion-panel>
|
||||
</v-expansion-panels>
|
||||
</v-form>
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions>
|
||||
<v-spacer />
|
||||
<v-btn color="grey" variant="text" @click="closePersonaDialog">
|
||||
{{ tm('buttons.cancel') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" variant="flat" @click="savePersona" :loading="saving" :disabled="!formValid">
|
||||
{{ tm('buttons.save') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 查看人格详情对话框 -->
|
||||
<v-dialog v-model="showViewDialog" max-width="700px">
|
||||
<v-card v-if="viewingPersona">
|
||||
<v-card-title class="d-flex justify-space-between align-center">
|
||||
<span class="text-h5">{{ viewingPersona.persona_id }}</span>
|
||||
<v-btn icon="mdi-close" variant="text" @click="showViewDialog = false" />
|
||||
</v-card-title>
|
||||
|
||||
<v-card-text>
|
||||
<div class="mb-4">
|
||||
<h4 class="text-h6 mb-2">{{ tm('form.systemPrompt') }}</h4>
|
||||
<div class="system-prompt-content">
|
||||
{{ viewingPersona.system_prompt }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div v-if="viewingPersona.begin_dialogs && viewingPersona.begin_dialogs.length > 0" class="mb-4">
|
||||
<h4 class="text-h6 mb-2">{{ tm('form.presetDialogs') }}</h4>
|
||||
<div v-for="(dialog, index) in viewingPersona.begin_dialogs" :key="index" class="mb-2">
|
||||
<v-chip :color="index % 2 === 0 ? 'primary' : 'secondary'" variant="tonal" size="small"
|
||||
class="mb-1">
|
||||
{{ index % 2 === 0 ? tm('form.userMessage') : tm('form.assistantMessage') }}
|
||||
</v-chip>
|
||||
<div class="dialog-content ml-2">
|
||||
{{ dialog }}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mb-4">
|
||||
<h4 class="text-h6 mb-2">{{ tm('form.tools') }}</h4>
|
||||
<div v-if="viewingPersona.tools === null" class="text-body-2 text-medium-emphasis">
|
||||
<v-chip size="small" color="success" variant="tonal" prepend-icon="mdi-check-all">
|
||||
{{ tm('form.allToolsAvailable') }}
|
||||
</v-chip>
|
||||
</div>
|
||||
<div v-else-if="viewingPersona.tools && viewingPersona.tools.length > 0"
|
||||
class="d-flex flex-wrap ga-1">
|
||||
<v-chip v-for="toolName in viewingPersona.tools" :key="toolName" size="small"
|
||||
color="primary" variant="tonal">
|
||||
{{ toolName }}
|
||||
</v-chip>
|
||||
</div>
|
||||
<div v-else class="text-body-2 text-medium-emphasis">
|
||||
{{ tm('form.noToolsSelected') }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="text-caption text-medium-emphasis">
|
||||
<div>{{ tm('labels.createdAt') }}: {{ formatDate(viewingPersona.created_at) }}</div>
|
||||
<div v-if="viewingPersona.updated_at">{{ tm('labels.updatedAt') }}: {{
|
||||
formatDate(viewingPersona.updated_at) }}</div>
|
||||
</div>
|
||||
</v-card-text>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
|
||||
<!-- 消息提示 -->
|
||||
<v-snackbar :timeout="3000" elevation="24" :color="messageType" v-model="showMessage" location="top">
|
||||
{{ message }}
|
||||
</v-snackbar>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
import axios from 'axios';
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables';
|
||||
|
||||
export default {
|
||||
name: 'PersonaPage',
|
||||
setup() {
|
||||
const { t } = useI18n();
|
||||
const { tm } = useModuleI18n('features/persona');
|
||||
return { t, tm };
|
||||
},
|
||||
data() {
|
||||
return {
|
||||
toolSelectValue: '0', // 默认选择全部工具
|
||||
personas: [],
|
||||
loading: false,
|
||||
saving: false,
|
||||
showPersonaDialog: false,
|
||||
showViewDialog: false,
|
||||
editingPersona: null,
|
||||
viewingPersona: null,
|
||||
expandedPanels: [],
|
||||
formValid: false,
|
||||
personaForm: {
|
||||
persona_id: '',
|
||||
system_prompt: '',
|
||||
begin_dialogs: [],
|
||||
tools: []
|
||||
},
|
||||
showMessage: false,
|
||||
message: '',
|
||||
messageType: 'success',
|
||||
personaIdRules: [
|
||||
v => !!v || this.tm('validation.required'),
|
||||
v => (v && v.length >= 2) || this.tm('validation.minLength', { min: 2 }),
|
||||
v => /^[a-zA-Z0-9_-]+$/.test(v) || this.tm('validation.alphanumeric')
|
||||
],
|
||||
systemPromptRules: [
|
||||
v => !!v || this.tm('validation.required'),
|
||||
v => (v && v.length >= 10) || this.tm('validation.minLength', { min: 10 })
|
||||
],
|
||||
mcpServers: [],
|
||||
availableTools: [],
|
||||
loadingTools: false,
|
||||
toolSearch: ''
|
||||
}
|
||||
},
|
||||
|
||||
computed: {
|
||||
filteredTools() {
|
||||
if (!this.toolSearch) {
|
||||
return this.availableTools;
|
||||
}
|
||||
const search = this.toolSearch.toLowerCase();
|
||||
return this.availableTools.filter(tool =>
|
||||
tool.name.toLowerCase().includes(search) ||
|
||||
(tool.description && tool.description.toLowerCase().includes(search)) ||
|
||||
(tool.mcp_server_name && tool.mcp_server_name.toLowerCase().includes(search))
|
||||
);
|
||||
}
|
||||
},
|
||||
|
||||
watch: {
|
||||
toolSearch() {
|
||||
// 响应式搜索,无需额外处理
|
||||
},
|
||||
|
||||
toolSelectValue(newValue) {
|
||||
if (newValue === '0') {
|
||||
// 选择全部工具
|
||||
this.personaForm.tools = null;
|
||||
} else if (newValue === '1') {
|
||||
// 选择指定工具,如果当前是null,则转换为空数组
|
||||
if (this.personaForm.tools === null) {
|
||||
this.personaForm.tools = [];
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
mounted() {
|
||||
this.loadPersonas();
|
||||
this.loadMcpServers();
|
||||
this.loadTools();
|
||||
},
|
||||
|
||||
methods: {
|
||||
async loadPersonas() {
|
||||
this.loading = true;
|
||||
try {
|
||||
const response = await axios.get('/api/persona/list');
|
||||
if (response.data.status === 'ok') {
|
||||
this.personas = response.data.data;
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.loadError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.loadError'));
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadMcpServers() {
|
||||
try {
|
||||
const response = await axios.get('/api/tools/mcp/servers');
|
||||
if (response.data.status === 'ok') {
|
||||
this.mcpServers = response.data.data;
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.loadError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.loadError'));
|
||||
}
|
||||
},
|
||||
|
||||
async loadTools() {
|
||||
this.loadingTools = true;
|
||||
try {
|
||||
const response = await axios.get('/api/tools/list');
|
||||
if (response.data.status === 'ok') {
|
||||
this.availableTools = response.data.data;
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.loadError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.loadError'));
|
||||
}
|
||||
this.loadingTools = false;
|
||||
},
|
||||
|
||||
openCreateDialog() {
|
||||
this.editingPersona = null;
|
||||
this.personaForm = {
|
||||
persona_id: '',
|
||||
system_prompt: '',
|
||||
begin_dialogs: [],
|
||||
tools: []
|
||||
};
|
||||
this.toolSelectValue = '1'; // 默认选择指定工具
|
||||
this.expandedPanels = [];
|
||||
this.showPersonaDialog = true;
|
||||
},
|
||||
|
||||
editPersona(persona) {
|
||||
this.editingPersona = persona;
|
||||
this.personaForm = {
|
||||
persona_id: persona.persona_id,
|
||||
system_prompt: persona.system_prompt,
|
||||
begin_dialogs: [...(persona.begin_dialogs || [])],
|
||||
tools: persona.tools === null ? null : [...(persona.tools || [])]
|
||||
};
|
||||
// 根据 tools 的值设置 toolSelectValue
|
||||
this.toolSelectValue = persona.tools === null ? '0' : '1';
|
||||
this.expandedPanels = [];
|
||||
this.showPersonaDialog = true;
|
||||
},
|
||||
|
||||
viewPersona(persona) {
|
||||
this.viewingPersona = persona;
|
||||
this.showViewDialog = true;
|
||||
},
|
||||
|
||||
closePersonaDialog() {
|
||||
this.showPersonaDialog = false;
|
||||
this.editingPersona = null;
|
||||
this.personaForm = {
|
||||
persona_id: '',
|
||||
system_prompt: '',
|
||||
begin_dialogs: [],
|
||||
tools: []
|
||||
};
|
||||
this.toolSelectValue = '1'; // 重置为默认值
|
||||
},
|
||||
|
||||
async savePersona() {
|
||||
if (!this.formValid) return;
|
||||
|
||||
// 验证预设对话不能为空
|
||||
if (this.personaForm.begin_dialogs.length > 0) {
|
||||
for (let i = 0; i < this.personaForm.begin_dialogs.length; i++) {
|
||||
if (!this.personaForm.begin_dialogs[i] || this.personaForm.begin_dialogs[i].trim() === '') {
|
||||
const dialogType = i % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
|
||||
this.showError(this.tm('validation.dialogRequired', { type: dialogType }));
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
this.saving = true;
|
||||
try {
|
||||
const url = this.editingPersona ? '/api/persona/update' : '/api/persona/create';
|
||||
const response = await axios.post(url, this.personaForm);
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(response.data.message || this.tm('messages.saveSuccess'));
|
||||
this.closePersonaDialog();
|
||||
await this.loadPersonas();
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.saveError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.saveError'));
|
||||
}
|
||||
this.saving = false;
|
||||
},
|
||||
|
||||
async deletePersona(persona) {
|
||||
if (!confirm(this.tm('messages.deleteConfirm', { id: persona.persona_id }))) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await axios.post('/api/persona/delete', {
|
||||
persona_id: persona.persona_id
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(response.data.message || this.tm('messages.deleteSuccess'));
|
||||
await this.loadPersonas();
|
||||
} else {
|
||||
this.showError(response.data.message || this.tm('messages.deleteError'));
|
||||
}
|
||||
} catch (error) {
|
||||
this.showError(error.response?.data?.message || this.tm('messages.deleteError'));
|
||||
}
|
||||
},
|
||||
|
||||
addDialogPair() {
|
||||
this.personaForm.begin_dialogs.push('', '');
|
||||
// 自动展开预设对话面板
|
||||
if (!this.expandedPanels.includes('dialogs')) {
|
||||
this.expandedPanels.push('dialogs');
|
||||
}
|
||||
},
|
||||
|
||||
removeDialog(index) {
|
||||
// 如果是偶数索引(用户消息),删除用户消息和对应的助手消息
|
||||
if (index % 2 === 0 && index + 1 < this.personaForm.begin_dialogs.length) {
|
||||
this.personaForm.begin_dialogs.splice(index, 2);
|
||||
}
|
||||
// 如果是奇数索引(助手消息),删除助手消息和对应的用户消息
|
||||
else if (index % 2 === 1 && index - 1 >= 0) {
|
||||
this.personaForm.begin_dialogs.splice(index - 1, 2);
|
||||
}
|
||||
},
|
||||
|
||||
toggleMcpServer(server) {
|
||||
if (!server.tools || server.tools.length === 0) return;
|
||||
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 从全选状态转换为去除该服务器工具的状态
|
||||
this.personaForm.tools = this.availableTools.map(tool => tool.name)
|
||||
.filter(toolName => !server.tools.includes(toolName));
|
||||
this.toolSelectValue = '1'; // 切换到指定工具模式
|
||||
return;
|
||||
}
|
||||
|
||||
// 确保tools是数组
|
||||
if (!Array.isArray(this.personaForm.tools)) {
|
||||
this.personaForm.tools = [];
|
||||
this.toolSelectValue = '1';
|
||||
}
|
||||
|
||||
// 检查是否所有服务器的工具都已选中
|
||||
const serverTools = server.tools;
|
||||
const allSelected = serverTools.every(toolName => this.personaForm.tools.includes(toolName));
|
||||
|
||||
if (allSelected) {
|
||||
// 移除所有服务器工具
|
||||
this.personaForm.tools = this.personaForm.tools.filter(
|
||||
toolName => !serverTools.includes(toolName)
|
||||
);
|
||||
} else {
|
||||
// 添加所有服务器工具
|
||||
serverTools.forEach(toolName => {
|
||||
if (!this.personaForm.tools.includes(toolName)) {
|
||||
this.personaForm.tools.push(toolName);
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
toggleTool(toolName) {
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 如果是全选状态,点击某个工具表示要取消选择该工具
|
||||
// 所以创建一个包含所有其他工具的数组
|
||||
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
|
||||
this.toolSelectValue = '1'; // 切换到指定工具模式
|
||||
} else if (Array.isArray(this.personaForm.tools)) {
|
||||
const index = this.personaForm.tools.indexOf(toolName);
|
||||
if (index !== -1) {
|
||||
// 如果工具已选择,移除工具
|
||||
this.personaForm.tools.splice(index, 1);
|
||||
} else {
|
||||
// 如果工具未选择,添加工具
|
||||
this.personaForm.tools.push(toolName);
|
||||
}
|
||||
} else {
|
||||
// 如果tools不是数组也不是null,初始化为数组
|
||||
this.personaForm.tools = [toolName];
|
||||
this.toolSelectValue = '1';
|
||||
}
|
||||
},
|
||||
|
||||
toggleAllTools() {
|
||||
// 如果当前是全选状态,则清空选择
|
||||
if (this.isAllToolsSelected()) {
|
||||
this.personaForm.tools = [];
|
||||
} else {
|
||||
// 否则设置为全选(null表示所有工具)
|
||||
this.personaForm.tools = null;
|
||||
}
|
||||
},
|
||||
|
||||
clearAllTools() {
|
||||
// 清空所有工具选择
|
||||
this.personaForm.tools = [];
|
||||
},
|
||||
|
||||
isAllToolsSelected() {
|
||||
// 检查是否为全选状态(tools为null)
|
||||
return this.personaForm.tools === null;
|
||||
},
|
||||
|
||||
isNoToolsSelected() {
|
||||
// 检查是否没有选择任何工具
|
||||
return Array.isArray(this.personaForm.tools) && this.personaForm.tools.length === 0;
|
||||
},
|
||||
|
||||
removeTool(toolName) {
|
||||
// 如果当前是全选状态,需要先转换为具体的工具列表
|
||||
if (this.personaForm.tools === null) {
|
||||
// 创建一个包含所有工具的数组,然后移除指定工具
|
||||
this.personaForm.tools = this.availableTools.map(tool => tool.name).filter(name => name !== toolName);
|
||||
this.toolSelectValue = '1'; // 切换到指定工具模式
|
||||
} else if (Array.isArray(this.personaForm.tools)) {
|
||||
const index = this.personaForm.tools.indexOf(toolName);
|
||||
if (index !== -1) {
|
||||
this.personaForm.tools.splice(index, 1);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
truncateText(text, maxLength) {
|
||||
if (!text) return '';
|
||||
return text.length > maxLength ? text.substring(0, maxLength) + '...' : text;
|
||||
},
|
||||
|
||||
formatDate(dateString) {
|
||||
if (!dateString) return '';
|
||||
return new Date(dateString).toLocaleString();
|
||||
},
|
||||
|
||||
showSuccess(message) {
|
||||
this.message = message;
|
||||
this.messageType = 'success';
|
||||
this.showMessage = true;
|
||||
},
|
||||
|
||||
showError(message) {
|
||||
this.message = message;
|
||||
this.messageType = 'error';
|
||||
this.showMessage = true;
|
||||
},
|
||||
|
||||
getDialogRules(index) {
|
||||
const dialogType = index % 2 === 0 ? this.tm('form.userMessage') : this.tm('form.assistantMessage');
|
||||
return [
|
||||
v => !!v || this.tm('validation.dialogRequired', { type: dialogType }),
|
||||
v => (v && v.trim().length > 0) || this.tm('validation.dialogRequired', { type: dialogType })
|
||||
];
|
||||
},
|
||||
|
||||
isToolSelected(toolName) {
|
||||
// 如果是全选状态,所有工具都被选中
|
||||
if (this.personaForm.tools === null) {
|
||||
return true;
|
||||
}
|
||||
return Array.isArray(this.personaForm.tools) && this.personaForm.tools.includes(toolName);
|
||||
},
|
||||
|
||||
isServerSelected(server) {
|
||||
if (!server.tools || server.tools.length === 0) return false;
|
||||
|
||||
// 如果是全选状态,所有服务器都被选中
|
||||
if (this.personaForm.tools === null) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// 检查服务器的所有工具是否都已选中
|
||||
return Array.isArray(this.personaForm.tools) &&
|
||||
server.tools.every(toolName => this.personaForm.tools.includes(toolName));
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.persona-page {
|
||||
padding: 20px;
|
||||
padding-top: 8px;
|
||||
}
|
||||
|
||||
.persona-card {
|
||||
transition: all 0.3s ease;
|
||||
height: 100%;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.persona-card:hover {
|
||||
box-shadow: 0 8px 25px 0 rgba(0, 0, 0, 0.15);
|
||||
}
|
||||
|
||||
.system-prompt-preview {
|
||||
font-size: 14px;
|
||||
line-height: 1.4;
|
||||
color: rgba(var(--v-theme-on-surface), 0.7);
|
||||
overflow: hidden;
|
||||
display: -webkit-box;
|
||||
-webkit-line-clamp: 3;
|
||||
line-clamp: 3;
|
||||
-webkit-box-orient: vertical;
|
||||
}
|
||||
|
||||
.system-prompt-content {
|
||||
background-color: rgba(var(--v-theme-surface-variant), 0.3);
|
||||
padding: 12px;
|
||||
border-radius: 8px;
|
||||
font-family: 'Roboto Mono', monospace;
|
||||
font-size: 14px;
|
||||
line-height: 1.5;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.dialog-content {
|
||||
background-color: rgba(var(--v-theme-surface-variant), 0.3);
|
||||
padding: 8px 12px;
|
||||
border-radius: 8px;
|
||||
font-size: 14px;
|
||||
line-height: 1.4;
|
||||
margin-bottom: 8px;
|
||||
white-space: pre-wrap;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.tools-selection {
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.v-virtual-scroll {
|
||||
padding-bottom: 16px;
|
||||
}
|
||||
</style>
|
||||
@@ -405,24 +405,36 @@
|
||||
<v-text-field v-model="toolSearch" prepend-inner-icon="mdi-magnify" :label="tm('functionTools.search')"
|
||||
variant="outlined" density="compact" class="mb-4" hide-details clearable></v-text-field>
|
||||
|
||||
<small>复选框代表该工具是否被启用。</small>
|
||||
|
||||
<v-expansion-panels v-model="openedPanel" multiple style="max-height: 500px; overflow-y: auto;">
|
||||
<v-expansion-panel v-for="(tool, index) in filteredTools" :key="index" :value="index"
|
||||
class="mb-2 tool-panel" rounded="lg">
|
||||
<v-expansion-panel-title>
|
||||
<v-row no-gutters align="center">
|
||||
<v-col cols="1">
|
||||
<v-checkbox
|
||||
v-model="tool.active"
|
||||
color="primary"
|
||||
hide-details
|
||||
density="compact"
|
||||
@click.stop
|
||||
@change="toggleToolStatus(tool)"
|
||||
></v-checkbox>
|
||||
</v-col>
|
||||
<v-col cols="3">
|
||||
<div class="d-flex align-center">
|
||||
<v-icon color="primary" class="me-2" size="small">
|
||||
{{ tool.function.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
|
||||
{{ tool.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }}
|
||||
</v-icon>
|
||||
<span class="text-body-1 text-high-emphasis font-weight-medium text-truncate"
|
||||
:title="tool.function.name">
|
||||
{{ formatToolName(tool.function.name) }}
|
||||
:title="tool.name">
|
||||
{{ formatToolName(tool.name) }}
|
||||
</span>
|
||||
</div>
|
||||
</v-col>
|
||||
<v-col cols="9" class="text-grey">
|
||||
{{ tool.function.description }}
|
||||
<v-col cols="8" class="text-grey">
|
||||
{{ tool.description }}
|
||||
</v-col>
|
||||
</v-row>
|
||||
</v-expansion-panel-title>
|
||||
@@ -434,9 +446,9 @@
|
||||
<v-icon color="primary" size="small" class="me-1">mdi-information</v-icon>
|
||||
{{ tm('functionTools.description') }}
|
||||
</p>
|
||||
<p class="text-body-2 ml-6 mb-4">{{ tool.function.description }}</p>
|
||||
<p class="text-body-2 ml-6 mb-4">{{ tool.description }}</p>
|
||||
|
||||
<template v-if="tool.function.parameters && tool.function.parameters.properties">
|
||||
<template v-if="tool.parameters && tool.parameters.properties">
|
||||
<p class="text-body-1 font-weight-medium mb-3">
|
||||
<v-icon color="primary" size="small" class="me-1">mdi-code-json</v-icon>
|
||||
{{ tm('functionTools.parameters') }}
|
||||
@@ -451,7 +463,7 @@
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="(param, paramName) in tool.function.parameters.properties" :key="paramName">
|
||||
<tr v-for="(param, paramName) in tool.parameters.properties" :key="paramName">
|
||||
<td class="font-weight-medium">{{ paramName }}</td>
|
||||
<td>
|
||||
<v-chip size="x-small" color="primary" text class="text-caption">
|
||||
@@ -562,8 +574,8 @@ export default {
|
||||
|
||||
const searchTerm = this.toolSearch.toLowerCase();
|
||||
return this.tools.filter(tool =>
|
||||
tool.function.name.toLowerCase().includes(searchTerm) ||
|
||||
tool.function.description.toLowerCase().includes(searchTerm)
|
||||
tool.name.toLowerCase().includes(searchTerm) ||
|
||||
tool.description.toLowerCase().includes(searchTerm)
|
||||
);
|
||||
},
|
||||
|
||||
@@ -658,7 +670,7 @@ export default {
|
||||
},
|
||||
|
||||
getTools() {
|
||||
axios.get('/api/config/llmtools')
|
||||
axios.get('/api/tools/list')
|
||||
.then(response => {
|
||||
this.tools = response.data.data || [];
|
||||
})
|
||||
@@ -976,6 +988,28 @@ export default {
|
||||
} catch (e) {
|
||||
this.showError(this.tm('messages.importError.failed', { error: e.message }));
|
||||
}
|
||||
},
|
||||
|
||||
// 切换工具状态
|
||||
async toggleToolStatus(tool) {
|
||||
try {
|
||||
const response = await axios.post('/api/tools/toggle-tool', {
|
||||
name: tool.name,
|
||||
activate: tool.active
|
||||
});
|
||||
|
||||
if (response.data.status === 'ok') {
|
||||
this.showSuccess(response.data.message || this.tm('messages.toggleToolSuccess'));
|
||||
} else {
|
||||
// 如果失败,恢复原状态
|
||||
tool.active = !tool.active;
|
||||
this.showError(response.data.message || this.tm('messages.toggleToolError'));
|
||||
}
|
||||
} catch (error) {
|
||||
// 如果失败,恢复原状态
|
||||
tool.active = !tool.active;
|
||||
this.showError(this.tm('messages.toggleToolError', { error: error.response?.data?.message || error.message }));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
+21
-41
@@ -26,6 +26,7 @@ from .long_term_memory import LongTermMemory
|
||||
from astrbot.core import logger
|
||||
from astrbot.api.message_components import Plain, Image, Reply
|
||||
from astrbot.core.star.session_llm_manager import SessionServiceManager
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from typing import Union
|
||||
from enum import Enum
|
||||
|
||||
@@ -128,7 +129,6 @@ class Main(star.Star):
|
||||
/reset: 重置 LLM 会话
|
||||
/history: 当前对话的对话记录
|
||||
/persona: 人格情景(op)
|
||||
/tool ls: 函数工具
|
||||
/key: API Key(op)
|
||||
/websearch: 网页搜索
|
||||
{notice}"""
|
||||
@@ -157,50 +157,22 @@ class Main(star.Star):
|
||||
@tool.command("ls")
|
||||
async def tool_ls(self, event: AstrMessageEvent):
|
||||
"""查看函数工具列表"""
|
||||
tm = self.context.get_llm_tool_manager()
|
||||
msg = "函数工具:\n"
|
||||
for tool in tm.func_list:
|
||||
active = " (启用)" if tool.active else "(停用)"
|
||||
msg += f"- {tool.name}: {tool.description} {active}\n"
|
||||
|
||||
msg += "\n使用 /tool on/off <工具名> 激活或者停用函数工具。/tool off_all 停用所有函数工具。"
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
|
||||
|
||||
@tool.command("on")
|
||||
async def tool_on(self, event: AstrMessageEvent, tool_name: str):
|
||||
"""启用一个函数工具"""
|
||||
if self.context.activate_llm_tool(tool_name):
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"激活工具 {tool_name} 成功。")
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"激活工具 {tool_name} 失败,未找到此工具。"
|
||||
)
|
||||
)
|
||||
event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
|
||||
|
||||
@tool.command("off")
|
||||
async def tool_off(self, event: AstrMessageEvent, tool_name: str):
|
||||
"""停用一个函数工具"""
|
||||
if self.context.deactivate_llm_tool(tool_name):
|
||||
event.set_result(
|
||||
MessageEventResult().message(f"停用工具 {tool_name} 成功。")
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"停用工具 {tool_name} 失败,未找到此工具。"
|
||||
)
|
||||
)
|
||||
event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
|
||||
|
||||
@tool.command("off_all")
|
||||
async def tool_all_off(self, event: AstrMessageEvent):
|
||||
"""停用所有函数工具"""
|
||||
tm = self.context.get_llm_tool_manager()
|
||||
for tool in tm.func_list:
|
||||
self.context.deactivate_llm_tool(tool.name)
|
||||
event.set_result(MessageEventResult().message("停用所有工具成功。"))
|
||||
event.set_result(MessageEventResult().message("tool 指令已经被移除。"))
|
||||
|
||||
@filter.command_group("plugin")
|
||||
def plugin(self):
|
||||
@@ -1264,26 +1236,34 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
if req.conversation:
|
||||
persona_id = req.conversation.persona_id
|
||||
if not persona_id and persona_id != "[%None]": # [%None] 为用户取消人格
|
||||
persona_id = self.context.provider_manager.selected_default_persona[
|
||||
persona_id = self.context.persona_manager.selected_default_persona_v3[
|
||||
"name"
|
||||
]
|
||||
persona = next(
|
||||
builtins.filter(
|
||||
lambda persona: persona["name"] == persona_id,
|
||||
self.context.provider_manager.personas,
|
||||
self.context.persona_manager.personas_v3,
|
||||
),
|
||||
None,
|
||||
)
|
||||
if persona:
|
||||
if prompt := persona["prompt"]:
|
||||
req.system_prompt += prompt
|
||||
if mood_dialogs := persona["_mood_imitation_dialogs_processed"]:
|
||||
req.system_prompt += "\nHere are few shots of dialogs, you need to imitate the tone of 'B' in the following dialogs to respond:\n"
|
||||
req.system_prompt += mood_dialogs
|
||||
if (
|
||||
begin_dialogs := persona["_begin_dialogs_processed"]
|
||||
) and not req.contexts:
|
||||
if begin_dialogs := persona["_begin_dialogs_processed"]:
|
||||
req.contexts[:0] = begin_dialogs
|
||||
# tools select
|
||||
tmgr = self.context.get_llm_tool_manager()
|
||||
if (persona and persona.get("tools") is None) or not persona:
|
||||
# select all
|
||||
toolset = tmgr.get_full_tool_set()
|
||||
else:
|
||||
toolset = ToolSet()
|
||||
for tool_name in persona["tools"]:
|
||||
tool = tmgr.get_func(tool_name)
|
||||
if tool:
|
||||
toolset.add_tool(tool)
|
||||
req.func_tool = toolset
|
||||
logger.debug(f"Tool set for persona {persona_id}: {toolset.names()}")
|
||||
|
||||
if quote:
|
||||
sender_info = ""
|
||||
|
||||
Reference in New Issue
Block a user