Compare commits

..

2 Commits

4 changed files with 35 additions and 120 deletions
+16 -10
View File
@@ -19,7 +19,6 @@ from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.astr_agent_run_util import AgentRunner
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.astr_main_agent_resources import (
CHATUI_EXTRA_PROMPT,
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
@@ -259,6 +258,8 @@ async def _ensure_persona_and_skills(
return
# get persona ID
# 1. from session service config - highest priority
persona_id = (
await sp.get_async(
scope="umo",
@@ -269,14 +270,15 @@ async def _ensure_persona_and_skills(
).get("persona_id")
if not persona_id:
persona_id = req.conversation.persona_id or cfg.get("default_personality")
if persona_id is None or persona_id != "[%None]":
default_persona = plugin_context.persona_manager.selected_default_persona_v3
if default_persona:
persona_id = default_persona["name"]
if event.get_platform_name() == "webchat":
persona_id = "_chatui_default_"
req.system_prompt += CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT
# 2. from conversation setting - second priority
persona_id = req.conversation.persona_id
if persona_id == "[%None]":
# explicitly set to no persona
pass
elif persona_id is None:
# 3. from config default persona setting - last priority
persona_id = cfg.get("default_personality")
persona = next(
builtins.filter(
@@ -291,6 +293,11 @@ async def _ensure_persona_and_skills(
req.system_prompt += f"\n# Persona Instructions\n\n{prompt}\n"
if begin_dialogs := copy.deepcopy(persona.get("_begin_dialogs_processed")):
req.contexts[:0] = begin_dialogs
else:
# special handling for webchat persona
if event.get_platform_name() == "webchat" and persona_id != "[%None]":
persona_id = "_chatui_default_"
req.system_prompt += CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT
# Inject skills prompt
skills_cfg = cfg.get("skills", {})
@@ -931,7 +938,6 @@ async def build_main_agent(
if event.get_platform_name() == "webchat":
asyncio.create_task(_handle_webchat(event, req, provider))
req.system_prompt += f"\n{CHATUI_EXTRA_PROMPT}\n"
if req.func_tool and req.func_tool.tools:
tool_prompt = (
@@ -78,9 +78,6 @@ CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT = (
"You listen more than you speak, respect uncertainty, avoid forcing quick conclusions or grand narratives, "
"and prefer clear, restrained language over unnecessary emotional embellishment. At your core, you value "
"empathy, clarity, autonomy, and meaning, favoring steady, sustainable progress over judgment or dramatic leaps."
)
CHATUI_EXTRA_PROMPT = (
'When you answered, you need to add a follow up question / summarization but do not add "Follow up" words. '
"Such as, user asked you to generate codes, you can add: Do you need me to run these codes for you?"
)
+17 -105
View File
@@ -4,7 +4,6 @@ import asyncio
import copy
import json
import os
import urllib.parse
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
@@ -213,93 +212,15 @@ class FunctionToolManager:
open(mcp_json_file, encoding="utf-8"),
)["mcpServers"]
tasks: dict[str, asyncio.Task] = {}
ready_futures: dict[str, asyncio.Future] = {}
for name, cfg in mcp_server_json_obj.items():
for name in mcp_server_json_obj:
cfg = mcp_server_json_obj[name]
if cfg.get("active", True):
event = asyncio.Event()
ready_future = asyncio.get_running_loop().create_future()
task = asyncio.create_task(
self._init_mcp_client_task_wrapper(
name,
cfg,
event,
ready_future,
),
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event),
)
tasks[name] = task
ready_futures[name] = ready_future
self.mcp_client_event[name] = event
if ready_futures:
logger.info(f"等待 {len(ready_futures)} 个 MCP 服务初始化...")
_, pending_futures = await asyncio.wait(
ready_futures.values(),
timeout=20.0,
)
pending_services = {
name
for name, ready_future in ready_futures.items()
if ready_future in pending_futures
}
if pending_services:
logger.warning(
"MCP 服务初始化超时(20秒),部分服务可能未完全加载。"
"建议检查 MCP 服务器配置和网络连接。"
)
for name in pending_services:
task = tasks[name]
task.cancel()
await asyncio.gather(
*(tasks[name] for name in pending_services),
return_exceptions=True,
)
success_count = 0
failed_services: list[str] = []
for name, ready_future in ready_futures.items():
if name in pending_services:
logger.error(f"MCP 服务 {name} 初始化超时")
failed_services.append(name)
self.mcp_client_event.pop(name, None)
continue
if ready_future.cancelled():
logger.error(f"MCP 服务 {name} 初始化已取消")
failed_services.append(name)
self.mcp_client_event.pop(name, None)
continue
exc = ready_future.exception()
if exc is not None:
logger.error(f"MCP 服务 {name} 初始化失败: {exc}")
# 仅在 debug 级别输出完整配置,避免在生产日志中泄露敏感信息
cfg = mcp_server_json_obj.get(name, {})
if "command" in cfg:
logger.debug(f" 命令: {cfg['command']}")
if "args" in cfg:
logger.debug(f" 参数: {cfg['args']}")
elif "url" in cfg:
parsed = urllib.parse.urlparse(cfg["url"])
logger.debug(f" 主机: {parsed.scheme}://{parsed.netloc}")
failed_services.append(name)
self.mcp_client_event.pop(name, None)
else:
success_count += 1
if failed_services:
logger.warning(
f"以下 MCP 服务初始化失败: {', '.join(failed_services)}"
f"请检查配置文件 mcp_server.json 和服务器可用性。"
)
logger.info(f"MCP 服务初始化完成: {success_count}/{len(tasks)} 成功")
async def _init_mcp_client_task_wrapper(
self,
name: str,
@@ -308,29 +229,20 @@ class FunctionToolManager:
ready_future: asyncio.Future | None = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
initialized = False
try:
await self._init_mcp_client(name, cfg)
initialized = True
tools = await self.mcp_client_dict[name].list_tools_and_save()
if ready_future and not ready_future.done():
ready_future.set_result(True)
# tell the caller we are ready
ready_future.set_result(tools)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
except asyncio.CancelledError:
if ready_future and not ready_future.done():
ready_future.set_exception(
asyncio.TimeoutError("MCP 客户端初始化超时"),
)
raise
except Exception as e:
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
if ready_future and not ready_future.done():
ready_future.set_exception(e)
if not initialized:
# 初始化阶段失败,记录错误并向上抛出让 task.exception() 捕获
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
raise
# 初始化已成功,此处异常来自 event.wait() 被取消,属于正常终止流程
finally:
# 无论如何都能清理
await self._terminate_mcp_client(name)
async def _init_mcp_client(self, name: str, config: dict) -> None:
@@ -428,22 +340,22 @@ class FunctionToolManager:
if not event:
event = asyncio.Event()
if not ready_future:
ready_future = asyncio.get_running_loop().create_future()
ready_future = asyncio.Future()
if name in self.mcp_client_dict:
return
init_task = asyncio.create_task(
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
)
try:
await asyncio.wait_for(ready_future, timeout=timeout)
except asyncio.TimeoutError:
init_task.cancel()
await asyncio.gather(init_task, return_exceptions=True)
self.mcp_client_event.pop(name, None)
raise
else:
finally:
self.mcp_client_event[name] = event
if ready_future.done() and ready_future.exception():
exc = ready_future.exception()
if exc is not None:
raise exc
async def disable_mcp_server(
self,
name: str | None = None,
+2 -2
View File
@@ -274,8 +274,8 @@ class ProviderManager:
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接(等待完成以确保工具可用)
await self.llm_tools.init_mcp_clients()
# 初始化 MCP Client 连接
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
def dynamic_import_provider(self, type: str):
"""动态导入提供商适配器模块