diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py
index 50b436043..54ad1e63b 100644
--- a/astrbot/core/pipeline/respond/stage.py
+++ b/astrbot/core/pipeline/respond/stage.py
@@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.utils.path_util import path_Mapping
+from astrbot.core.utils.session_lock import session_lock_manager
@register_stage
@@ -177,25 +178,26 @@ class RespondStage(Stage):
result.chain.remove(comp)
break
- for rcomp in record_comps:
- i = await self._calc_comp_interval(rcomp)
- await asyncio.sleep(i)
- try:
- await event.send(MessageChain([rcomp]))
- except Exception as e:
- logger.error(f"发送消息失败: {e} chain: {result.chain}")
- break
-
- # 分段回复
- for comp in non_record_comps:
- i = await self._calc_comp_interval(comp)
- await asyncio.sleep(i)
- try:
- await event.send(MessageChain([*decorated_comps, comp]))
- decorated_comps = [] # 清空已发送的装饰组件
- except Exception as e:
- logger.error(f"发送消息失败: {e} chain: {result.chain}")
- break
+ # leverage lock to guarentee the order of message sending among different events
+ async with session_lock_manager.acquire_lock(event.unified_msg_origin):
+ for rcomp in record_comps:
+ i = await self._calc_comp_interval(rcomp)
+ await asyncio.sleep(i)
+ try:
+ await event.send(MessageChain([rcomp]))
+ except Exception as e:
+ logger.error(f"发送消息失败: {e} chain: {result.chain}")
+ break
+ # 分段回复
+ for comp in non_record_comps:
+ i = await self._calc_comp_interval(comp)
+ await asyncio.sleep(i)
+ try:
+ await event.send(MessageChain([*decorated_comps, comp]))
+ decorated_comps = [] # 清空已发送的装饰组件
+ except Exception as e:
+ logger.error(f"发送消息失败: {e} chain: {result.chain}")
+ break
else:
for rcomp in record_comps:
try:
diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py
index cea2e4f38..07a0fbd8f 100644
--- a/astrbot/core/provider/func_tool_manager.py
+++ b/astrbot/core/provider/func_tool_manager.py
@@ -39,6 +39,72 @@ SUPPORTED_TYPES = [
] # json schema 支持的数据类型
+def _prepare_config(config: dict) -> dict:
+ """准备配置,处理嵌套格式"""
+ if "mcpServers" in config and config["mcpServers"]:
+ first_key = next(iter(config["mcpServers"]))
+ config = config["mcpServers"][first_key]
+ config.pop("active", None)
+ return config
+
+
+async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
+ """快速测试 MCP 服务器可达性"""
+ import aiohttp
+
+ cfg = _prepare_config(config.copy())
+
+ url = cfg["url"]
+ headers = cfg.get("headers", {})
+ timeout = cfg.get("timeout", 10)
+
+ try:
+ async with aiohttp.ClientSession() as session:
+ if cfg.get("transport") == "streamable_http":
+ test_payload = {
+ "jsonrpc": "2.0",
+ "method": "initialize",
+ "id": 0,
+ "params": {
+ "protocolVersion": "2024-11-05",
+ "capabilities": {},
+ "clientInfo": {"name": "test-client", "version": "1.2.3"},
+ },
+ }
+ async with session.post(
+ url,
+ headers={
+ **headers,
+ "Content-Type": "application/json",
+ "Accept": "application/json, text/event-stream",
+ },
+ json=test_payload,
+ timeout=aiohttp.ClientTimeout(total=timeout),
+ ) as response:
+ if response.status == 200:
+ return True, ""
+ else:
+ return False, f"HTTP {response.status}: {response.reason}"
+ else:
+ async with session.get(
+ url,
+ headers={
+ **headers,
+ "Accept": "application/json, text/event-stream",
+ },
+ timeout=aiohttp.ClientTimeout(total=timeout),
+ ) as response:
+ if response.status == 200:
+ return True, ""
+ else:
+ return False, f"HTTP {response.status}: {response.reason}"
+
+ except asyncio.TimeoutError:
+ return False, f"连接超时: {timeout}秒"
+ except Exception as e:
+ return False, f"{e!s}"
+
+
@dataclass
class FuncTool:
"""
@@ -80,12 +146,10 @@ class FuncTool:
if not self.mcp_client or not self.mcp_client.session:
raise Exception(f"MCP client for {self.name} is not available")
# 使用name属性而不是额外的mcp_tool_name
- if ":" in self.name:
- # 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
- actual_tool_name = self.name.split(":")[-1]
- return await self.mcp_client.session.call_tool(actual_tool_name, args)
- else:
- return await self.mcp_client.session.call_tool(self.name, args)
+ actual_tool_name = (
+ self.name.split(":")[-1] if ":" in self.name else self.name
+ )
+ return await self.mcp_client.session.call_tool(actual_tool_name, args)
else:
raise Exception(f"Unknown function origin: {self.origin}")
@@ -100,6 +164,7 @@ class MCPClient:
self.active: bool = True
self.tools: List[mcp.Tool] = []
self.server_errlogs: List[str] = []
+ self.running_event = asyncio.Event()
async def connect_to_server(self, mcp_server_config: dict, name: str):
"""连接到 MCP 服务器
@@ -112,17 +177,19 @@ class MCPClient:
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
- cfg = mcp_server_config.copy()
- if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
- key_0 = list(cfg["mcpServers"].keys())[0]
- cfg = cfg["mcpServers"][key_0]
- cfg.pop("active", None) # Remove active flag from config
+ cfg = _prepare_config(mcp_server_config.copy())
+
+ def logging_callback(msg: str):
+ # 处理 MCP 服务的错误日志
+ print(f"MCP Server {name} Error: {msg}")
+ self.server_errlogs.append(msg)
if "url" in cfg:
- is_sse = True
- if cfg.get("transport") == "streamable_http":
- is_sse = False
- if is_sse:
+ success, error_msg = await _quick_test_mcp_connection(cfg)
+ if not success:
+ raise Exception(error_msg)
+
+ if cfg.get("transport") != "streamable_http":
# SSE transport method
self._streams_context = sse_client(
url=cfg["url"],
@@ -130,11 +197,18 @@ class MCPClient:
timeout=cfg.get("timeout", 5),
sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5),
)
- streams = await self.exit_stack.enter_async_context(self._streams_context)
+ streams = await self.exit_stack.enter_async_context(
+ self._streams_context
+ )
# Create a new client session
+ read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
- mcp.ClientSession(*streams)
+ mcp.ClientSession(
+ *streams,
+ read_timeout_seconds=read_timeout,
+ logging_callback=logging_callback, # type: ignore
+ )
)
else:
timeout = timedelta(seconds=cfg.get("timeout", 30))
@@ -148,11 +222,19 @@ class MCPClient:
sse_read_timeout=sse_read_timeout,
terminate_on_close=cfg.get("terminate_on_close", True),
)
- read_s, write_s, _ = await self.exit_stack.enter_async_context(self._streams_context)
+ read_s, write_s, _ = await self.exit_stack.enter_async_context(
+ self._streams_context
+ )
# Create a new client session
+ read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20))
self.session = await self.exit_stack.enter_async_context(
- mcp.ClientSession(read_stream=read_s, write_stream=write_s)
+ mcp.ClientSession(
+ read_stream=read_s,
+ write_stream=write_s,
+ read_timeout_seconds=read_timeout,
+ logging_callback=logging_callback, # type: ignore
+ )
)
else:
@@ -172,7 +254,7 @@ class MCPClient:
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
- ),
+ ), # type: ignore
),
)
@@ -180,19 +262,18 @@ class MCPClient:
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(*stdio_transport)
)
-
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
- logger.debug(f"MCP server {self.name} list tools response: {response}")
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
+ self.running_event.set() # Set the running event to indicate cleanup is done
class FuncCall:
@@ -201,8 +282,6 @@ class FuncCall:
"""内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表"""
- self.mcp_service_queue = asyncio.Queue()
- """用于外部控制 MCP 服务的启停"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
def empty(self) -> bool:
@@ -258,7 +337,7 @@ class FuncCall:
return f
return None
- async def _init_mcp_clients(self) -> None:
+ async def init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
{
@@ -300,115 +379,64 @@ class FuncCall:
)
self.mcp_client_event[name] = event
- async def mcp_service_selector(self):
- """为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
-
- 使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
-
- {"type": "init"} 初始化所有MCP客户端
-
- {"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
-
- {"type": "terminate"} 终止所有MCP客户端
-
- {"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
- """
- while True:
- data = await self.mcp_service_queue.get()
- if data["type"] == "init":
- if "name" in data:
- event = asyncio.Event()
- asyncio.create_task(
- self._init_mcp_client_task_wrapper(
- data["name"], data["cfg"], event
- )
- )
- self.mcp_client_event[data["name"]] = event
- else:
- await self._init_mcp_clients()
- elif data["type"] == "terminate":
- if "name" in data:
- # await self._terminate_mcp_client(data["name"])
- if data["name"] in self.mcp_client_event:
- self.mcp_client_event[data["name"]].set()
- self.mcp_client_event.pop(data["name"], None)
- self.func_list = [
- f
- for f in self.func_list
- if not (
- f.origin == "mcp" and f.mcp_server_name == data["name"]
- )
- ]
- else:
- for name in self.mcp_client_dict.keys():
- # await self._terminate_mcp_client(name)
- # self.mcp_client_event[name].set()
- if name in self.mcp_client_event:
- self.mcp_client_event[name].set()
- self.mcp_client_event.pop(name, None)
- self.func_list = [f for f in self.func_list if f.origin != "mcp"]
-
async def _init_mcp_client_task_wrapper(
- self, name: str, cfg: dict, event: asyncio.Event
+ self,
+ name: str,
+ cfg: dict,
+ event: asyncio.Event,
+ ready_future: asyncio.Future = None,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
await self._init_mcp_client(name, cfg)
+ tools = await self.mcp_client_dict[name].list_tools_and_save()
+ if ready_future and not ready_future.done():
+ # tell the caller we are ready
+ ready_future.set_result(tools)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
except Exception as e:
- import traceback
-
- traceback.print_exc()
- logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
+ logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
+ if ready_future and not ready_future.done():
+ ready_future.set_exception(e)
finally:
# 无论如何都能清理
await self._terminate_mcp_client(name)
async def _init_mcp_client(self, name: str, config: dict) -> None:
"""初始化单个MCP客户端"""
- try:
- # 先清理之前的客户端,如果存在
- if name in self.mcp_client_dict:
- await self._terminate_mcp_client(name)
+ # 先清理之前的客户端,如果存在
+ if name in self.mcp_client_dict:
+ await self._terminate_mcp_client(name)
- mcp_client = MCPClient()
- mcp_client.name = name
- self.mcp_client_dict[name] = mcp_client
- await mcp_client.connect_to_server(config, name)
- tools_res = await mcp_client.list_tools_and_save()
- tool_names = [tool.name for tool in tools_res.tools]
+ mcp_client = MCPClient()
+ mcp_client.name = name
+ self.mcp_client_dict[name] = mcp_client
+ await mcp_client.connect_to_server(config, name)
+ tools_res = await mcp_client.list_tools_and_save()
+ logger.debug(f"MCP server {name} list tools response: {tools_res}")
+ tool_names = [tool.name for tool in tools_res.tools]
- # 移除该MCP服务之前的工具(如有)
- self.func_list = [
- f
- for f in self.func_list
- if not (f.origin == "mcp" and f.mcp_server_name == name)
- ]
+ # 移除该MCP服务之前的工具(如有)
+ self.func_list = [
+ f
+ for f in self.func_list
+ if not (f.origin == "mcp" and f.mcp_server_name == name)
+ ]
- # 将 MCP 工具转换为 FuncTool 并添加到 func_list
- for tool in mcp_client.tools:
- func_tool = FuncTool(
- name=tool.name,
- parameters=tool.inputSchema,
- description=tool.description,
- origin="mcp",
- mcp_server_name=name,
- mcp_client=mcp_client,
- )
- self.func_list.append(func_tool)
+ # 将 MCP 工具转换为 FuncTool 并添加到 func_list
+ for tool in mcp_client.tools:
+ func_tool = FuncTool(
+ name=tool.name,
+ parameters=tool.inputSchema,
+ description=tool.description,
+ origin="mcp",
+ mcp_server_name=name,
+ mcp_client=mcp_client,
+ )
+ self.func_list.append(func_tool)
- logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
- return
- except Exception as e:
- import traceback
-
- logger.error(traceback.format_exc())
- logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
- # 发生错误时确保客户端被清理
- if name in self.mcp_client_dict:
- await self._terminate_mcp_client(name)
- return
+ logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
@@ -418,7 +446,7 @@ class FuncCall:
await self.mcp_client_dict[name].cleanup()
self.mcp_client_dict.pop(name)
except Exception as e:
- logger.info(f"清空 MCP 客户端资源 {name}: {e}。")
+ logger.error(f"清空 MCP 客户端资源 {name}: {e}。")
# 移除关联的FuncTool
self.func_list = [
f
@@ -427,6 +455,103 @@ class FuncCall:
]
logger.info(f"已关闭 MCP 服务 {name}")
+ @staticmethod
+ async def test_mcp_server_connection(config: dict) -> list[str]:
+ if "url" in config:
+ success, error_msg = await _quick_test_mcp_connection(config)
+ if not success:
+ raise Exception(error_msg)
+
+ mcp_client = MCPClient()
+ try:
+ logger.debug(f"testing MCP server connection with config: {config}")
+ await mcp_client.connect_to_server(config, "test")
+ tools_res = await mcp_client.list_tools_and_save()
+ tool_names = [tool.name for tool in tools_res.tools]
+ finally:
+ logger.debug("Cleaning up MCP client after testing connection.")
+ await mcp_client.cleanup()
+ return tool_names
+
+ async def enable_mcp_server(
+ self,
+ name: str,
+ config: dict,
+ event: asyncio.Event | None = None,
+ ready_future: asyncio.Future | None = None,
+ timeout: int = 30,
+ ) -> None:
+ """Enable_mcp_server a new MCP server to the manager and initialize it.
+
+ Args:
+ name (str): The name of the MCP server.
+ config (dict): Configuration for the MCP server.
+ event (asyncio.Event): Event to signal when the MCP client is ready.
+ ready_future (asyncio.Future): Future to signal when the MCP client is ready.
+ timeout (int): Timeout for the initialization.
+ Raises:
+ TimeoutError: If the initialization does not complete within the specified timeout.
+ Exception: If there is an error during initialization.
+ """
+ if not event:
+ event = asyncio.Event()
+ if not ready_future:
+ ready_future = asyncio.Future()
+ if name in self.mcp_client_dict:
+ return
+ asyncio.create_task(
+ self._init_mcp_client_task_wrapper(name, config, event, ready_future)
+ )
+ try:
+ await asyncio.wait_for(ready_future, timeout=timeout)
+ finally:
+ self.mcp_client_event[name] = event
+
+ if ready_future.done() and ready_future.exception():
+ exc = ready_future.exception()
+ if exc is not None:
+ raise exc
+
+ async def disable_mcp_server(
+ self, name: str | None = None, timeout: float = 10
+ ) -> None:
+ """Disable an MCP server by its name.
+
+ Args:
+ name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
+ timeout (int): Timeout.
+ """
+ if name:
+ if name not in self.mcp_client_event:
+ return
+ client = self.mcp_client_dict.get(name)
+ self.mcp_client_event[name].set()
+ if not client:
+ return
+ client_running_event = client.running_event
+ try:
+ await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
+ finally:
+ self.mcp_client_event.pop(name, None)
+ self.func_list = [
+ f
+ for f in self.func_list
+ if f.origin != "mcp" or f.mcp_server_name != name
+ ]
+ else:
+ running_events = [
+ client.running_event.wait() for client in self.mcp_client_dict.values()
+ ]
+ for key, event in self.mcp_client_event.items():
+ event.set()
+ # waiting for all clients to finish
+ try:
+ await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
+ finally:
+ self.mcp_client_event.clear()
+ self.mcp_client_dict.clear()
+ self.func_list = [f for f in self.func_list if f.origin != "mcp"]
+
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index df21e6a12..370c5322b 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -169,10 +169,7 @@ class ProviderManager:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接
- asyncio.create_task(
- self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
- )
- self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
+ asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
@@ -422,7 +419,7 @@ class ProviderManager:
self.curr_tts_provider_inst = None
if getattr(self.inst_map[provider_id], "terminate", None):
- await self.inst_map[provider_id].terminate() # type: ignore
+ await self.inst_map[provider_id].terminate() # type: ignore
logger.info(
f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})"
@@ -432,6 +429,8 @@ class ProviderManager:
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
- await provider_inst.terminate() # type: ignore
- # 清理 MCP Client 连接
- await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
+ await provider_inst.terminate() # type: ignore
+ try:
+ await self.llm_tools.disable_mcp_server()
+ except Exception:
+ logger.error("Error while disabling MCP servers", exc_info=True)
diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py
index 25c291659..86318f8b7 100644
--- a/astrbot/core/star/__init__.py
+++ b/astrbot/core/star/__init__.py
@@ -10,7 +10,7 @@ from astrbot.core.star.star_tools import StarTools
class Star(CommandParserMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
- def __init__(self, context: Context):
+ def __init__(self, context: Context, config: dict | None = None):
StarTools.initialize(context)
self.context = context
@@ -41,9 +41,17 @@ class Star(CommandParserMixin):
tmpl, data, return_url=return_url, options=options
)
+ async def initialize(self):
+ """当插件被激活时会调用这个方法"""
+ pass
+
async def terminate(self):
"""当插件被禁用、重载插件时会调用这个方法"""
pass
+ def __del__(self):
+ """[Deprecated] 当插件被禁用、重载插件时会调用这个方法"""
+ pass
+
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py
index d44388238..2fe9dd7f3 100644
--- a/astrbot/core/star/star.py
+++ b/astrbot/core/star/star.py
@@ -2,6 +2,7 @@ from __future__ import annotations
from dataclasses import dataclass, field
from types import ModuleType
+from typing import TYPE_CHECKING
from astrbot.core.config import AstrBotConfig
@@ -9,6 +10,9 @@ star_registry: list[StarMetadata] = []
star_map: dict[str, StarMetadata] = {}
"""key 是模块路径,__module__"""
+if TYPE_CHECKING:
+ from . import Star
+
@dataclass
class StarMetadata:
@@ -29,12 +33,12 @@ class StarMetadata:
repo: str | None = None
"""插件仓库地址"""
- star_cls_type: type | None = None
+ star_cls_type: type[Star] | None = None
"""插件的类对象的类型"""
module_path: str | None = None
"""插件的模块路径"""
- star_cls: object | None = None
+ star_cls: Star | None = None
"""插件的类对象"""
module: ModuleType | None = None
"""插件的模块对象"""
diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py
index 4a6d4d902..b64b4aa85 100644
--- a/astrbot/core/star/star_manager.py
+++ b/astrbot/core/star/star_manager.py
@@ -163,7 +163,7 @@ class PluginManager:
plugins.extend(_p)
return plugins
- async def _check_plugin_dept_update(self, target_plugin: str = None):
+ async def _check_plugin_dept_update(self, target_plugin: str | None = None):
"""检查插件的依赖
如果 target_plugin 为 None,则检查所有插件的依赖
"""
@@ -187,7 +187,7 @@ class PluginManager:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
@staticmethod
- def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata:
+ def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
"""先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。
Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。
@@ -253,8 +253,8 @@ class PluginManager:
def _purge_modules(
self,
- module_patterns: list[str] = None,
- root_dir_name: str = None,
+ module_patterns: list[str] | None = None,
+ root_dir_name: str | None = None,
is_reserved: bool = False,
):
"""从 sys.modules 中移除指定的模块
@@ -314,8 +314,8 @@ class PluginManager:
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
-
- await self._unbind_plugin(smd.name, smd.module_path)
+ if smd.name and smd.module_path:
+ await self._unbind_plugin(smd.name, smd.module_path)
star_handlers_registry.clear()
star_map.clear()
@@ -331,8 +331,8 @@ class PluginManager:
logger.warning(
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
-
- await self._unbind_plugin(smd.name, specified_module_path)
+ if smd.name:
+ await self._unbind_plugin(smd.name, specified_module_path)
result = await self.load(specified_module_path)
@@ -460,8 +460,7 @@ class PluginManager:
metadata.config = plugin_config
if path not in inactivated_plugins:
# 只有没有禁用插件时才实例化插件类
- if plugin_config:
- # metadata.config = plugin_config
+ if plugin_config and metadata.star_cls_type:
try:
metadata.star_cls = metadata.star_cls_type(
context=self.context, config=plugin_config
@@ -470,7 +469,7 @@ class PluginManager:
metadata.star_cls = metadata.star_cls_type(
context=self.context
)
- else:
+ elif metadata.star_cls_type:
metadata.star_cls = metadata.star_cls_type(
context=self.context
)
@@ -487,6 +486,10 @@ class PluginManager:
)
metadata.update_platform_compatibility(plugin_enable_config)
+ assert metadata.module_path is not None, (
+ f"插件 {metadata.name} 的模块路径为空。"
+ )
+
# 绑定 handler
related_handlers = (
star_handlers_registry.get_handlers_by_module_name(
@@ -495,7 +498,8 @@ class PluginManager:
)
for handler in related_handlers:
handler.handler = functools.partial(
- handler.handler, metadata.star_cls
+ handler.handler,
+ metadata.star_cls, # type: ignore
)
# 绑定 llm_tool handler
for func_tool in llm_tools.func_list:
@@ -505,7 +509,8 @@ class PluginManager:
):
func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial(
- func_tool.handler, metadata.star_cls
+ func_tool.handler,
+ metadata.star_cls, # type: ignore
)
if func_tool.name in inactivated_llm_tools:
func_tool.active = False
@@ -532,13 +537,12 @@ class PluginManager:
obj = getattr(module, classes[0])(
context=self.context
) # 实例化插件类
- else:
- logger.info(f"插件 {metadata.name} 已被禁用。")
- metadata = None
metadata = self._load_plugin_metadata(
plugin_path=plugin_dir_path, plugin_obj=obj
)
+ if not metadata:
+ raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。")
metadata.star_cls = obj
metadata.config = plugin_config
metadata.module = module
@@ -553,6 +557,10 @@ class PluginManager:
if metadata.module_path in inactivated_plugins:
metadata.activated = False
+ assert metadata.module_path is not None, (
+ f"插件 {metadata.name} 的模块路径为空。"
+ )
+
full_names = []
for handler in star_handlers_registry.get_handlers_by_module_name(
metadata.module_path
@@ -592,7 +600,7 @@ class PluginManager:
metadata.star_handler_full_names = full_names
# 执行 initialize() 方法
- if hasattr(metadata.star_cls, "initialize"):
+ if hasattr(metadata.star_cls, "initialize") and metadata.star_cls:
await metadata.star_cls.initialize()
except BaseException as e:
@@ -734,6 +742,9 @@ class PluginManager:
]:
del star_handlers_registry.star_handlers_map[k]
+ if plugin is None:
+ return
+
self._purge_modules(
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
)
@@ -795,6 +806,9 @@ class PluginManager:
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
return
+ if star_metadata.star_cls is None:
+ return
+
if hasattr(star_metadata.star_cls, "__del__"):
asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__
diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py
index 2cd8fd9c2..2b34c2a14 100644
--- a/astrbot/core/utils/io.py
+++ b/astrbot/core/utils/io.py
@@ -30,7 +30,7 @@ def on_error(func, path, exc_info):
raise exc_info[1]
-def remove_dir(file_path) -> bool:
+def remove_dir(file_path: str) -> bool:
if not os.path.exists(file_path):
return True
shutil.rmtree(file_path, onerror=on_error)
diff --git a/astrbot/core/utils/session_lock.py b/astrbot/core/utils/session_lock.py
new file mode 100644
index 000000000..912d91e53
--- /dev/null
+++ b/astrbot/core/utils/session_lock.py
@@ -0,0 +1,29 @@
+import asyncio
+from collections import defaultdict
+from contextlib import asynccontextmanager
+
+
+class SessionLockManager:
+ def __init__(self):
+ self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
+ self._lock_count: dict[str, int] = defaultdict(int)
+ self._access_lock = asyncio.Lock()
+
+ @asynccontextmanager
+ async def acquire_lock(self, session_id: str):
+ async with self._access_lock:
+ lock = self._locks[session_id]
+ self._lock_count[session_id] += 1
+
+ try:
+ async with lock:
+ yield
+ finally:
+ async with self._access_lock:
+ self._lock_count[session_id] -= 1
+ if self._lock_count[session_id] == 0:
+ self._locks.pop(session_id, None)
+ self._lock_count.pop(session_id, None)
+
+
+session_lock_manager = SessionLockManager()
diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py
index 7a503583b..42018d19e 100644
--- a/astrbot/core/utils/shared_preferences.py
+++ b/astrbot/core/utils/shared_preferences.py
@@ -1,7 +1,9 @@
import json
import os
+from typing import TypeVar
from .astrbot_path import get_astrbot_data_path
+_VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
@@ -24,7 +26,7 @@ class SharedPreferences:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
- def get(self, key, default=None):
+ def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
def put(self, key, value):
diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py
index 79397290e..2a8389396 100644
--- a/astrbot/dashboard/routes/stat.py
+++ b/astrbot/dashboard/routes/stat.py
@@ -2,6 +2,7 @@ import traceback
import psutil
import time
import threading
+import aiohttp
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
@@ -25,6 +26,7 @@ class StatRoute(Route):
"/stat/version": ("GET", self.get_version),
"/stat/start-time": ("GET", self.get_start_time),
"/stat/restart-core": ("POST", self.restart_core),
+ "/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection),
}
self.db_helper = db_helper
self.register_routes()
@@ -45,11 +47,7 @@ class StatRoute(Route):
"""将总秒数转换为时分秒组件"""
minutes, seconds = divmod(total_seconds, 60)
hours, minutes = divmod(minutes, 60)
- return {
- "hours": hours,
- "minutes": minutes,
- "seconds": seconds
- }
+ return {"hours": hours, "minutes": minutes, "seconds": seconds}
def is_default_cred(self):
username = self.config["dashboard"]["username"]
@@ -144,3 +142,40 @@ class StatRoute(Route):
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(e.__str__()).__dict__
+
+ async def test_ghproxy_connection(self):
+ """
+ 测试 GitHub 代理连接是否可用。
+ """
+ try:
+ data = await request.get_json()
+ proxy_url: str = data.get("proxy_url")
+
+ if not proxy_url:
+ return Response().error("proxy_url is required").__dict__
+
+ proxy_url = proxy_url.rstrip("/")
+
+ test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version"
+ start_time = time.time()
+
+ async with aiohttp.ClientSession() as session:
+ async with session.get(
+ test_url, timeout=aiohttp.ClientTimeout(total=10)
+ ) as response:
+ if response.status == 200:
+ end_time = time.time()
+ _ = await response.text()
+ ret = {
+ "latency": round((end_time - start_time) * 1000, 2),
+ }
+ return Response().ok(data=ret).__dict__
+ else:
+ return (
+ Response()
+ .error(f"Failed. Status code: {response.status}")
+ .__dict__
+ )
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return Response().error(f"Error: {str(e)}").__dict__
diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py
index d38014c71..5dad2576b 100644
--- a/astrbot/dashboard/routes/tools.py
+++ b/astrbot/dashboard/routes/tools.py
@@ -26,6 +26,7 @@ class ToolsRoute(Route):
"/tools/mcp/update": ("POST", self.update_mcp_server),
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
"/tools/mcp/market": ("GET", self.get_mcp_markets),
+ "/tools/mcp/test": ("POST", self.test_mcp_connection),
}
self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
@@ -132,12 +133,19 @@ class ToolsRoute(Route):
config["mcpServers"][name] = server_config
if self.save_mcp_config(config):
- # 动态初始化新MCP客户端
- await self.tool_mgr.mcp_service_queue.put({
- "type": "init",
- "name": name,
- "cfg": config["mcpServers"][name],
- })
+ try:
+ await self.tool_mgr.enable_mcp_server(
+ name, server_config, timeout=30
+ )
+ except TimeoutError:
+ return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return (
+ Response()
+ .error(f"启用 MCP 服务器 {name} 失败: {str(e)}")
+ .__dict__
+ )
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
else:
return Response().error("保存配置失败").__dict__
@@ -193,31 +201,55 @@ class ToolsRoute(Route):
if self.save_mcp_config(config):
# 处理MCP客户端状态变化
if active:
- # 如果要激活服务器或者配置已更改
if name in self.tool_mgr.mcp_client_dict or not only_update_active:
- await self.tool_mgr.mcp_service_queue.put({
- "type": "terminate",
- "name": name,
- })
- await self.tool_mgr.mcp_service_queue.put({
- "type": "init",
- "name": name,
- "cfg": config["mcpServers"][name],
- })
- else:
- # 客户端不存在,初始化
- await self.tool_mgr.mcp_service_queue.put({
- "type": "init",
- "name": name,
- "cfg": config["mcpServers"][name],
- })
+ try:
+ await self.tool_mgr.disable_mcp_server(name, timeout=10)
+ except TimeoutError as e:
+ return (
+ Response()
+ .error(f"启用前停用 MCP 服务器时 {name} 超时: {str(e)}")
+ .__dict__
+ )
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return (
+ Response()
+ .error(f"启用前停用 MCP 服务器时 {name} 失败: {str(e)}")
+ .__dict__
+ )
+ try:
+ await self.tool_mgr.enable_mcp_server(
+ name, config["mcpServers"][name], timeout=30
+ )
+ except TimeoutError:
+ return (
+ Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
+ )
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return (
+ Response()
+ .error(f"启用 MCP 服务器 {name} 失败: {str(e)}")
+ .__dict__
+ )
else:
# 如果要停用服务器
if name in self.tool_mgr.mcp_client_dict:
- self.tool_mgr.mcp_service_queue.put_nowait({
- "type": "terminate",
- "name": name,
- })
+ try:
+ await self.tool_mgr.disable_mcp_server(name, timeout=10)
+ except TimeoutError:
+ return (
+ Response()
+ .error(f"停用 MCP 服务器 {name} 超时。")
+ .__dict__
+ )
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return (
+ Response()
+ .error(f"停用 MCP 服务器 {name} 失败: {str(e)}")
+ .__dict__
+ )
return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
else:
@@ -239,17 +271,23 @@ class ToolsRoute(Route):
if name not in config["mcpServers"]:
return Response().error(f"服务器 {name} 不存在").__dict__
- # 删除服务器配置
del config["mcpServers"][name]
if self.save_mcp_config(config):
- # 关闭并删除MCP客户端
if name in self.tool_mgr.mcp_client_dict:
- self.tool_mgr.mcp_service_queue.put_nowait({
- "type": "terminate",
- "name": name,
- })
-
+ try:
+ await self.tool_mgr.disable_mcp_server(name, timeout=10)
+ except TimeoutError:
+ return (
+ Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__
+ )
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return (
+ Response()
+ .error(f"停用 MCP 服务器 {name} 失败: {str(e)}")
+ .__dict__
+ )
return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
else:
return Response().error("保存配置失败").__dict__
@@ -281,3 +319,20 @@ class ToolsRoute(Route):
except Exception as _:
logger.error(traceback.format_exc())
return Response().error("获取市场数据失败").__dict__
+
+ async def test_mcp_connection(self):
+ """
+ 测试 MCP 服务器连接
+ """
+ try:
+ server_data = await request.json
+ config = server_data.get("mcp_server_config", None)
+
+ tools_name = await self.tool_mgr.test_mcp_server_connection(config)
+ return (
+ Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
+ )
+
+ except Exception as e:
+ logger.error(traceback.format_exc())
+ return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__
diff --git a/dashboard/src/components/shared/ItemCard.vue b/dashboard/src/components/shared/ItemCard.vue
index ff790cb7b..6152c531f 100644
--- a/dashboard/src/components/shared/ItemCard.vue
+++ b/dashboard/src/components/shared/ItemCard.vue
@@ -9,6 +9,8 @@
hide-details
density="compact"
:model-value="getItemEnabled()"
+ :loading="loading"
+ :disabled="loading"
v-bind="props"
@update:model-value="toggleEnabled"
>
@@ -77,6 +79,10 @@ export default {
bglogo: {
type: String,
default: null
+ },
+ loading: {
+ type: Boolean,
+ default: false
}
},
emits: ['toggle-enabled', 'delete', 'edit'],
diff --git a/dashboard/src/components/shared/ProxySelector.vue b/dashboard/src/components/shared/ProxySelector.vue
new file mode 100644
index 000000000..d45a0f520
--- /dev/null
+++ b/dashboard/src/components/shared/ProxySelector.vue
@@ -0,0 +1,152 @@
+
+ GitHub 加速
+
{{ extension_data.message }}
-{{ tm('dialogs.error.checkConsole') }}
-{{ extension_data.message }}
+{{ tm('dialogs.error.checkConsole') }}
+{{ tm('mcpServers.empty') }}
+{{ tm('mcpServers.empty') }}
-{{ tm('functionTools.empty') }}
+
-
{{ tool.function.description }}
- - -
-
{{ tm('functionTools.noParameters') }}
-{{ tm('functionTools.empty') }}
+
+
{{ tool.function.description }}
+ + +
+
{{ tm('functionTools.noParameters') }}
+