diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 50b436043..54ad1e63b 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -13,6 +13,7 @@ from astrbot.core.message.message_event_result import BaseMessageComponent from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map from astrbot.core.utils.path_util import path_Mapping +from astrbot.core.utils.session_lock import session_lock_manager @register_stage @@ -177,25 +178,26 @@ class RespondStage(Stage): result.chain.remove(comp) break - for rcomp in record_comps: - i = await self._calc_comp_interval(rcomp) - await asyncio.sleep(i) - try: - await event.send(MessageChain([rcomp])) - except Exception as e: - logger.error(f"发送消息失败: {e} chain: {result.chain}") - break - - # 分段回复 - for comp in non_record_comps: - i = await self._calc_comp_interval(comp) - await asyncio.sleep(i) - try: - await event.send(MessageChain([*decorated_comps, comp])) - decorated_comps = [] # 清空已发送的装饰组件 - except Exception as e: - logger.error(f"发送消息失败: {e} chain: {result.chain}") - break + # leverage lock to guarentee the order of message sending among different events + async with session_lock_manager.acquire_lock(event.unified_msg_origin): + for rcomp in record_comps: + i = await self._calc_comp_interval(rcomp) + await asyncio.sleep(i) + try: + await event.send(MessageChain([rcomp])) + except Exception as e: + logger.error(f"发送消息失败: {e} chain: {result.chain}") + break + # 分段回复 + for comp in non_record_comps: + i = await self._calc_comp_interval(comp) + await asyncio.sleep(i) + try: + await event.send(MessageChain([*decorated_comps, comp])) + decorated_comps = [] # 清空已发送的装饰组件 + except Exception as e: + logger.error(f"发送消息失败: {e} chain: {result.chain}") + break else: for rcomp in record_comps: try: diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index cea2e4f38..07a0fbd8f 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -39,6 +39,72 @@ SUPPORTED_TYPES = [ ] # json schema 支持的数据类型 +def _prepare_config(config: dict) -> dict: + """准备配置,处理嵌套格式""" + if "mcpServers" in config and config["mcpServers"]: + first_key = next(iter(config["mcpServers"])) + config = config["mcpServers"][first_key] + config.pop("active", None) + return config + + +async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: + """快速测试 MCP 服务器可达性""" + import aiohttp + + cfg = _prepare_config(config.copy()) + + url = cfg["url"] + headers = cfg.get("headers", {}) + timeout = cfg.get("timeout", 10) + + try: + async with aiohttp.ClientSession() as session: + if cfg.get("transport") == "streamable_http": + test_payload = { + "jsonrpc": "2.0", + "method": "initialize", + "id": 0, + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.2.3"}, + }, + } + async with session.post( + url, + headers={ + **headers, + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + }, + json=test_payload, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response: + if response.status == 200: + return True, "" + else: + return False, f"HTTP {response.status}: {response.reason}" + else: + async with session.get( + url, + headers={ + **headers, + "Accept": "application/json, text/event-stream", + }, + timeout=aiohttp.ClientTimeout(total=timeout), + ) as response: + if response.status == 200: + return True, "" + else: + return False, f"HTTP {response.status}: {response.reason}" + + except asyncio.TimeoutError: + return False, f"连接超时: {timeout}秒" + except Exception as e: + return False, f"{e!s}" + + @dataclass class FuncTool: """ @@ -80,12 +146,10 @@ class FuncTool: if not self.mcp_client or not self.mcp_client.session: raise Exception(f"MCP client for {self.name} is not available") # 使用name属性而不是额外的mcp_tool_name - if ":" in self.name: - # 如果名字是格式为 mcp:server:tool_name,提取实际的工具名 - actual_tool_name = self.name.split(":")[-1] - return await self.mcp_client.session.call_tool(actual_tool_name, args) - else: - return await self.mcp_client.session.call_tool(self.name, args) + actual_tool_name = ( + self.name.split(":")[-1] if ":" in self.name else self.name + ) + return await self.mcp_client.session.call_tool(actual_tool_name, args) else: raise Exception(f"Unknown function origin: {self.origin}") @@ -100,6 +164,7 @@ class MCPClient: self.active: bool = True self.tools: List[mcp.Tool] = [] self.server_errlogs: List[str] = [] + self.running_event = asyncio.Event() async def connect_to_server(self, mcp_server_config: dict, name: str): """连接到 MCP 服务器 @@ -112,17 +177,19 @@ class MCPClient: Args: mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server """ - cfg = mcp_server_config.copy() - if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0: - key_0 = list(cfg["mcpServers"].keys())[0] - cfg = cfg["mcpServers"][key_0] - cfg.pop("active", None) # Remove active flag from config + cfg = _prepare_config(mcp_server_config.copy()) + + def logging_callback(msg: str): + # 处理 MCP 服务的错误日志 + print(f"MCP Server {name} Error: {msg}") + self.server_errlogs.append(msg) if "url" in cfg: - is_sse = True - if cfg.get("transport") == "streamable_http": - is_sse = False - if is_sse: + success, error_msg = await _quick_test_mcp_connection(cfg) + if not success: + raise Exception(error_msg) + + if cfg.get("transport") != "streamable_http": # SSE transport method self._streams_context = sse_client( url=cfg["url"], @@ -130,11 +197,18 @@ class MCPClient: timeout=cfg.get("timeout", 5), sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), ) - streams = await self.exit_stack.enter_async_context(self._streams_context) + streams = await self.exit_stack.enter_async_context( + self._streams_context + ) # Create a new client session + read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20)) self.session = await self.exit_stack.enter_async_context( - mcp.ClientSession(*streams) + mcp.ClientSession( + *streams, + read_timeout_seconds=read_timeout, + logging_callback=logging_callback, # type: ignore + ) ) else: timeout = timedelta(seconds=cfg.get("timeout", 30)) @@ -148,11 +222,19 @@ class MCPClient: sse_read_timeout=sse_read_timeout, terminate_on_close=cfg.get("terminate_on_close", True), ) - read_s, write_s, _ = await self.exit_stack.enter_async_context(self._streams_context) + read_s, write_s, _ = await self.exit_stack.enter_async_context( + self._streams_context + ) # Create a new client session + read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 20)) self.session = await self.exit_stack.enter_async_context( - mcp.ClientSession(read_stream=read_s, write_stream=write_s) + mcp.ClientSession( + read_stream=read_s, + write_stream=write_s, + read_timeout_seconds=read_timeout, + logging_callback=logging_callback, # type: ignore + ) ) else: @@ -172,7 +254,7 @@ class MCPClient: logger=logger, identifier=f"MCPServer-{name}", callback=callback, - ), + ), # type: ignore ), ) @@ -180,19 +262,18 @@ class MCPClient: self.session = await self.exit_stack.enter_async_context( mcp.ClientSession(*stdio_transport) ) - await self.session.initialize() async def list_tools_and_save(self) -> mcp.ListToolsResult: """List all tools from the server and save them to self.tools""" response = await self.session.list_tools() - logger.debug(f"MCP server {self.name} list tools response: {response}") self.tools = response.tools return response async def cleanup(self): """Clean up resources""" await self.exit_stack.aclose() + self.running_event.set() # Set the running event to indicate cleanup is done class FuncCall: @@ -201,8 +282,6 @@ class FuncCall: """内部加载的 func tools""" self.mcp_client_dict: Dict[str, MCPClient] = {} """MCP 服务列表""" - self.mcp_service_queue = asyncio.Queue() - """用于外部控制 MCP 服务的启停""" self.mcp_client_event: Dict[str, asyncio.Event] = {} def empty(self) -> bool: @@ -258,7 +337,7 @@ class FuncCall: return f return None - async def _init_mcp_clients(self) -> None: + async def init_mcp_clients(self) -> None: """从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下: ``` { @@ -300,115 +379,64 @@ class FuncCall: ) self.mcp_client_event[name] = event - async def mcp_service_selector(self): - """为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制 - - 使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下: - - {"type": "init"} 初始化所有MCP客户端 - - {"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端 - - {"type": "terminate"} 终止所有MCP客户端 - - {"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端 - """ - while True: - data = await self.mcp_service_queue.get() - if data["type"] == "init": - if "name" in data: - event = asyncio.Event() - asyncio.create_task( - self._init_mcp_client_task_wrapper( - data["name"], data["cfg"], event - ) - ) - self.mcp_client_event[data["name"]] = event - else: - await self._init_mcp_clients() - elif data["type"] == "terminate": - if "name" in data: - # await self._terminate_mcp_client(data["name"]) - if data["name"] in self.mcp_client_event: - self.mcp_client_event[data["name"]].set() - self.mcp_client_event.pop(data["name"], None) - self.func_list = [ - f - for f in self.func_list - if not ( - f.origin == "mcp" and f.mcp_server_name == data["name"] - ) - ] - else: - for name in self.mcp_client_dict.keys(): - # await self._terminate_mcp_client(name) - # self.mcp_client_event[name].set() - if name in self.mcp_client_event: - self.mcp_client_event[name].set() - self.mcp_client_event.pop(name, None) - self.func_list = [f for f in self.func_list if f.origin != "mcp"] - async def _init_mcp_client_task_wrapper( - self, name: str, cfg: dict, event: asyncio.Event + self, + name: str, + cfg: dict, + event: asyncio.Event, + ready_future: asyncio.Future = None, ) -> None: """初始化 MCP 客户端的包装函数,用于捕获异常""" try: await self._init_mcp_client(name, cfg) + tools = await self.mcp_client_dict[name].list_tools_and_save() + if ready_future and not ready_future.done(): + # tell the caller we are ready + ready_future.set_result(tools) await event.wait() logger.info(f"收到 MCP 客户端 {name} 终止信号") except Exception as e: - import traceback - - traceback.print_exc() - logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") + logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True) + if ready_future and not ready_future.done(): + ready_future.set_exception(e) finally: # 无论如何都能清理 await self._terminate_mcp_client(name) async def _init_mcp_client(self, name: str, config: dict) -> None: """初始化单个MCP客户端""" - try: - # 先清理之前的客户端,如果存在 - if name in self.mcp_client_dict: - await self._terminate_mcp_client(name) + # 先清理之前的客户端,如果存在 + if name in self.mcp_client_dict: + await self._terminate_mcp_client(name) - mcp_client = MCPClient() - mcp_client.name = name - self.mcp_client_dict[name] = mcp_client - await mcp_client.connect_to_server(config, name) - tools_res = await mcp_client.list_tools_and_save() - tool_names = [tool.name for tool in tools_res.tools] + mcp_client = MCPClient() + mcp_client.name = name + self.mcp_client_dict[name] = mcp_client + await mcp_client.connect_to_server(config, name) + tools_res = await mcp_client.list_tools_and_save() + logger.debug(f"MCP server {name} list tools response: {tools_res}") + tool_names = [tool.name for tool in tools_res.tools] - # 移除该MCP服务之前的工具(如有) - self.func_list = [ - f - for f in self.func_list - if not (f.origin == "mcp" and f.mcp_server_name == name) - ] + # 移除该MCP服务之前的工具(如有) + self.func_list = [ + f + for f in self.func_list + if not (f.origin == "mcp" and f.mcp_server_name == name) + ] - # 将 MCP 工具转换为 FuncTool 并添加到 func_list - for tool in mcp_client.tools: - func_tool = FuncTool( - name=tool.name, - parameters=tool.inputSchema, - description=tool.description, - origin="mcp", - mcp_server_name=name, - mcp_client=mcp_client, - ) - self.func_list.append(func_tool) + # 将 MCP 工具转换为 FuncTool 并添加到 func_list + for tool in mcp_client.tools: + func_tool = FuncTool( + name=tool.name, + parameters=tool.inputSchema, + description=tool.description, + origin="mcp", + mcp_server_name=name, + mcp_client=mcp_client, + ) + self.func_list.append(func_tool) - logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}") - return - except Exception as e: - import traceback - - logger.error(traceback.format_exc()) - logger.error(f"初始化 MCP 客户端 {name} 失败: {e}") - # 发生错误时确保客户端被清理 - if name in self.mcp_client_dict: - await self._terminate_mcp_client(name) - return + logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}") async def _terminate_mcp_client(self, name: str) -> None: """关闭并清理MCP客户端""" @@ -418,7 +446,7 @@ class FuncCall: await self.mcp_client_dict[name].cleanup() self.mcp_client_dict.pop(name) except Exception as e: - logger.info(f"清空 MCP 客户端资源 {name}: {e}。") + logger.error(f"清空 MCP 客户端资源 {name}: {e}。") # 移除关联的FuncTool self.func_list = [ f @@ -427,6 +455,103 @@ class FuncCall: ] logger.info(f"已关闭 MCP 服务 {name}") + @staticmethod + async def test_mcp_server_connection(config: dict) -> list[str]: + if "url" in config: + success, error_msg = await _quick_test_mcp_connection(config) + if not success: + raise Exception(error_msg) + + mcp_client = MCPClient() + try: + logger.debug(f"testing MCP server connection with config: {config}") + await mcp_client.connect_to_server(config, "test") + tools_res = await mcp_client.list_tools_and_save() + tool_names = [tool.name for tool in tools_res.tools] + finally: + logger.debug("Cleaning up MCP client after testing connection.") + await mcp_client.cleanup() + return tool_names + + async def enable_mcp_server( + self, + name: str, + config: dict, + event: asyncio.Event | None = None, + ready_future: asyncio.Future | None = None, + timeout: int = 30, + ) -> None: + """Enable_mcp_server a new MCP server to the manager and initialize it. + + Args: + name (str): The name of the MCP server. + config (dict): Configuration for the MCP server. + event (asyncio.Event): Event to signal when the MCP client is ready. + ready_future (asyncio.Future): Future to signal when the MCP client is ready. + timeout (int): Timeout for the initialization. + Raises: + TimeoutError: If the initialization does not complete within the specified timeout. + Exception: If there is an error during initialization. + """ + if not event: + event = asyncio.Event() + if not ready_future: + ready_future = asyncio.Future() + if name in self.mcp_client_dict: + return + asyncio.create_task( + self._init_mcp_client_task_wrapper(name, config, event, ready_future) + ) + try: + await asyncio.wait_for(ready_future, timeout=timeout) + finally: + self.mcp_client_event[name] = event + + if ready_future.done() and ready_future.exception(): + exc = ready_future.exception() + if exc is not None: + raise exc + + async def disable_mcp_server( + self, name: str | None = None, timeout: float = 10 + ) -> None: + """Disable an MCP server by its name. + + Args: + name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled. + timeout (int): Timeout. + """ + if name: + if name not in self.mcp_client_event: + return + client = self.mcp_client_dict.get(name) + self.mcp_client_event[name].set() + if not client: + return + client_running_event = client.running_event + try: + await asyncio.wait_for(client_running_event.wait(), timeout=timeout) + finally: + self.mcp_client_event.pop(name, None) + self.func_list = [ + f + for f in self.func_list + if f.origin != "mcp" or f.mcp_server_name != name + ] + else: + running_events = [ + client.running_event.wait() for client in self.mcp_client_dict.values() + ] + for key, event in self.mcp_client_event.items(): + event.set() + # waiting for all clients to finish + try: + await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout) + finally: + self.mcp_client_event.clear() + self.mcp_client_dict.clear() + self.func_list = [f for f in self.func_list if f.origin != "mcp"] + def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list: """ 获得 OpenAI API 风格的**已经激活**的工具描述 diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index df21e6a12..370c5322b 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -169,10 +169,7 @@ class ProviderManager: self.curr_tts_provider_inst = self.tts_provider_insts[0] # 初始化 MCP Client 连接 - asyncio.create_task( - self.llm_tools.mcp_service_selector(), name="mcp-service-handler" - ) - self.llm_tools.mcp_service_queue.put_nowait({"type": "init"}) + asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients") async def load_provider(self, provider_config: dict): if not provider_config["enable"]: @@ -422,7 +419,7 @@ class ProviderManager: self.curr_tts_provider_inst = None if getattr(self.inst_map[provider_id], "terminate", None): - await self.inst_map[provider_id].terminate() # type: ignore + await self.inst_map[provider_id].terminate() # type: ignore logger.info( f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})" @@ -432,6 +429,8 @@ class ProviderManager: async def terminate(self): for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): - await provider_inst.terminate() # type: ignore - # 清理 MCP Client 连接 - await self.llm_tools.mcp_service_queue.put({"type": "terminate"}) + await provider_inst.terminate() # type: ignore + try: + await self.llm_tools.disable_mcp_server() + except Exception: + logger.error("Error while disabling MCP servers", exc_info=True) diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index 25c291659..86318f8b7 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -10,7 +10,7 @@ from astrbot.core.star.star_tools import StarTools class Star(CommandParserMixin): """所有插件(Star)的父类,所有插件都应该继承于这个类""" - def __init__(self, context: Context): + def __init__(self, context: Context, config: dict | None = None): StarTools.initialize(context) self.context = context @@ -41,9 +41,17 @@ class Star(CommandParserMixin): tmpl, data, return_url=return_url, options=options ) + async def initialize(self): + """当插件被激活时会调用这个方法""" + pass + async def terminate(self): """当插件被禁用、重载插件时会调用这个方法""" pass + def __del__(self): + """[Deprecated] 当插件被禁用、重载插件时会调用这个方法""" + pass + __all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"] diff --git a/astrbot/core/star/star.py b/astrbot/core/star/star.py index d44388238..2fe9dd7f3 100644 --- a/astrbot/core/star/star.py +++ b/astrbot/core/star/star.py @@ -2,6 +2,7 @@ from __future__ import annotations from dataclasses import dataclass, field from types import ModuleType +from typing import TYPE_CHECKING from astrbot.core.config import AstrBotConfig @@ -9,6 +10,9 @@ star_registry: list[StarMetadata] = [] star_map: dict[str, StarMetadata] = {} """key 是模块路径,__module__""" +if TYPE_CHECKING: + from . import Star + @dataclass class StarMetadata: @@ -29,12 +33,12 @@ class StarMetadata: repo: str | None = None """插件仓库地址""" - star_cls_type: type | None = None + star_cls_type: type[Star] | None = None """插件的类对象的类型""" module_path: str | None = None """插件的模块路径""" - star_cls: object | None = None + star_cls: Star | None = None """插件的类对象""" module: ModuleType | None = None """插件的模块对象""" diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 4a6d4d902..b64b4aa85 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -163,7 +163,7 @@ class PluginManager: plugins.extend(_p) return plugins - async def _check_plugin_dept_update(self, target_plugin: str = None): + async def _check_plugin_dept_update(self, target_plugin: str | None = None): """检查插件的依赖 如果 target_plugin 为 None,则检查所有插件的依赖 """ @@ -187,7 +187,7 @@ class PluginManager: logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}") @staticmethod - def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata: + def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None: """先寻找 metadata.yaml 文件,如果不存在,则使用插件对象的 info() 函数获取元数据。 Notes: 旧版本 AstrBot 插件可能使用的是 info() 函数来获取元数据。 @@ -253,8 +253,8 @@ class PluginManager: def _purge_modules( self, - module_patterns: list[str] = None, - root_dir_name: str = None, + module_patterns: list[str] | None = None, + root_dir_name: str | None = None, is_reserved: bool = False, ): """从 sys.modules 中移除指定的模块 @@ -314,8 +314,8 @@ class PluginManager: logger.warning( f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" ) - - await self._unbind_plugin(smd.name, smd.module_path) + if smd.name and smd.module_path: + await self._unbind_plugin(smd.name, smd.module_path) star_handlers_registry.clear() star_map.clear() @@ -331,8 +331,8 @@ class PluginManager: logger.warning( f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。" ) - - await self._unbind_plugin(smd.name, specified_module_path) + if smd.name: + await self._unbind_plugin(smd.name, specified_module_path) result = await self.load(specified_module_path) @@ -460,8 +460,7 @@ class PluginManager: metadata.config = plugin_config if path not in inactivated_plugins: # 只有没有禁用插件时才实例化插件类 - if plugin_config: - # metadata.config = plugin_config + if plugin_config and metadata.star_cls_type: try: metadata.star_cls = metadata.star_cls_type( context=self.context, config=plugin_config @@ -470,7 +469,7 @@ class PluginManager: metadata.star_cls = metadata.star_cls_type( context=self.context ) - else: + elif metadata.star_cls_type: metadata.star_cls = metadata.star_cls_type( context=self.context ) @@ -487,6 +486,10 @@ class PluginManager: ) metadata.update_platform_compatibility(plugin_enable_config) + assert metadata.module_path is not None, ( + f"插件 {metadata.name} 的模块路径为空。" + ) + # 绑定 handler related_handlers = ( star_handlers_registry.get_handlers_by_module_name( @@ -495,7 +498,8 @@ class PluginManager: ) for handler in related_handlers: handler.handler = functools.partial( - handler.handler, metadata.star_cls + handler.handler, + metadata.star_cls, # type: ignore ) # 绑定 llm_tool handler for func_tool in llm_tools.func_list: @@ -505,7 +509,8 @@ class PluginManager: ): func_tool.handler_module_path = metadata.module_path func_tool.handler = functools.partial( - func_tool.handler, metadata.star_cls + func_tool.handler, + metadata.star_cls, # type: ignore ) if func_tool.name in inactivated_llm_tools: func_tool.active = False @@ -532,13 +537,12 @@ class PluginManager: obj = getattr(module, classes[0])( context=self.context ) # 实例化插件类 - else: - logger.info(f"插件 {metadata.name} 已被禁用。") - metadata = None metadata = self._load_plugin_metadata( plugin_path=plugin_dir_path, plugin_obj=obj ) + if not metadata: + raise Exception(f"无法找到插件 {plugin_dir_path} 的元数据。") metadata.star_cls = obj metadata.config = plugin_config metadata.module = module @@ -553,6 +557,10 @@ class PluginManager: if metadata.module_path in inactivated_plugins: metadata.activated = False + assert metadata.module_path is not None, ( + f"插件 {metadata.name} 的模块路径为空。" + ) + full_names = [] for handler in star_handlers_registry.get_handlers_by_module_name( metadata.module_path @@ -592,7 +600,7 @@ class PluginManager: metadata.star_handler_full_names = full_names # 执行 initialize() 方法 - if hasattr(metadata.star_cls, "initialize"): + if hasattr(metadata.star_cls, "initialize") and metadata.star_cls: await metadata.star_cls.initialize() except BaseException as e: @@ -734,6 +742,9 @@ class PluginManager: ]: del star_handlers_registry.star_handlers_map[k] + if plugin is None: + return + self._purge_modules( root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved ) @@ -795,6 +806,9 @@ class PluginManager: logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。") return + if star_metadata.star_cls is None: + return + if hasattr(star_metadata.star_cls, "__del__"): asyncio.get_event_loop().run_in_executor( None, star_metadata.star_cls.__del__ diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 2cd8fd9c2..2b34c2a14 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -30,7 +30,7 @@ def on_error(func, path, exc_info): raise exc_info[1] -def remove_dir(file_path) -> bool: +def remove_dir(file_path: str) -> bool: if not os.path.exists(file_path): return True shutil.rmtree(file_path, onerror=on_error) diff --git a/astrbot/core/utils/session_lock.py b/astrbot/core/utils/session_lock.py new file mode 100644 index 000000000..912d91e53 --- /dev/null +++ b/astrbot/core/utils/session_lock.py @@ -0,0 +1,29 @@ +import asyncio +from collections import defaultdict +from contextlib import asynccontextmanager + + +class SessionLockManager: + def __init__(self): + self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock) + self._lock_count: dict[str, int] = defaultdict(int) + self._access_lock = asyncio.Lock() + + @asynccontextmanager + async def acquire_lock(self, session_id: str): + async with self._access_lock: + lock = self._locks[session_id] + self._lock_count[session_id] += 1 + + try: + async with lock: + yield + finally: + async with self._access_lock: + self._lock_count[session_id] -= 1 + if self._lock_count[session_id] == 0: + self._locks.pop(session_id, None) + self._lock_count.pop(session_id, None) + + +session_lock_manager = SessionLockManager() diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index 7a503583b..42018d19e 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -1,7 +1,9 @@ import json import os +from typing import TypeVar from .astrbot_path import get_astrbot_data_path +_VT = TypeVar("_VT") class SharedPreferences: def __init__(self, path=None): @@ -24,7 +26,7 @@ class SharedPreferences: json.dump(self._data, f, indent=4, ensure_ascii=False) f.flush() - def get(self, key, default=None): + def get(self, key, default: _VT = None) -> _VT: return self._data.get(key, default) def put(self, key, value): diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 79397290e..2a8389396 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -2,6 +2,7 @@ import traceback import psutil import time import threading +import aiohttp from .route import Route, Response, RouteContext from astrbot.core import logger from quart import request @@ -25,6 +26,7 @@ class StatRoute(Route): "/stat/version": ("GET", self.get_version), "/stat/start-time": ("GET", self.get_start_time), "/stat/restart-core": ("POST", self.restart_core), + "/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection), } self.db_helper = db_helper self.register_routes() @@ -45,11 +47,7 @@ class StatRoute(Route): """将总秒数转换为时分秒组件""" minutes, seconds = divmod(total_seconds, 60) hours, minutes = divmod(minutes, 60) - return { - "hours": hours, - "minutes": minutes, - "seconds": seconds - } + return {"hours": hours, "minutes": minutes, "seconds": seconds} def is_default_cred(self): username = self.config["dashboard"]["username"] @@ -144,3 +142,40 @@ class StatRoute(Route): except Exception as e: logger.error(traceback.format_exc()) return Response().error(e.__str__()).__dict__ + + async def test_ghproxy_connection(self): + """ + 测试 GitHub 代理连接是否可用。 + """ + try: + data = await request.get_json() + proxy_url: str = data.get("proxy_url") + + if not proxy_url: + return Response().error("proxy_url is required").__dict__ + + proxy_url = proxy_url.rstrip("/") + + test_url = f"{proxy_url}/https://github.com/AstrBotDevs/AstrBot/raw/refs/heads/master/.python-version" + start_time = time.time() + + async with aiohttp.ClientSession() as session: + async with session.get( + test_url, timeout=aiohttp.ClientTimeout(total=10) + ) as response: + if response.status == 200: + end_time = time.time() + _ = await response.text() + ret = { + "latency": round((end_time - start_time) * 1000, 2), + } + return Response().ok(data=ret).__dict__ + else: + return ( + Response() + .error(f"Failed. Status code: {response.status}") + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"Error: {str(e)}").__dict__ diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index d38014c71..5dad2576b 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -26,6 +26,7 @@ class ToolsRoute(Route): "/tools/mcp/update": ("POST", self.update_mcp_server), "/tools/mcp/delete": ("POST", self.delete_mcp_server), "/tools/mcp/market": ("GET", self.get_mcp_markets), + "/tools/mcp/test": ("POST", self.test_mcp_connection), } self.register_routes() self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools @@ -132,12 +133,19 @@ class ToolsRoute(Route): config["mcpServers"][name] = server_config if self.save_mcp_config(config): - # 动态初始化新MCP客户端 - await self.tool_mgr.mcp_service_queue.put({ - "type": "init", - "name": name, - "cfg": config["mcpServers"][name], - }) + try: + await self.tool_mgr.enable_mcp_server( + name, server_config, timeout=30 + ) + except TimeoutError: + return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__ + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"启用 MCP 服务器 {name} 失败: {str(e)}") + .__dict__ + ) return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__ else: return Response().error("保存配置失败").__dict__ @@ -193,31 +201,55 @@ class ToolsRoute(Route): if self.save_mcp_config(config): # 处理MCP客户端状态变化 if active: - # 如果要激活服务器或者配置已更改 if name in self.tool_mgr.mcp_client_dict or not only_update_active: - await self.tool_mgr.mcp_service_queue.put({ - "type": "terminate", - "name": name, - }) - await self.tool_mgr.mcp_service_queue.put({ - "type": "init", - "name": name, - "cfg": config["mcpServers"][name], - }) - else: - # 客户端不存在,初始化 - await self.tool_mgr.mcp_service_queue.put({ - "type": "init", - "name": name, - "cfg": config["mcpServers"][name], - }) + try: + await self.tool_mgr.disable_mcp_server(name, timeout=10) + except TimeoutError as e: + return ( + Response() + .error(f"启用前停用 MCP 服务器时 {name} 超时: {str(e)}") + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"启用前停用 MCP 服务器时 {name} 失败: {str(e)}") + .__dict__ + ) + try: + await self.tool_mgr.enable_mcp_server( + name, config["mcpServers"][name], timeout=30 + ) + except TimeoutError: + return ( + Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"启用 MCP 服务器 {name} 失败: {str(e)}") + .__dict__ + ) else: # 如果要停用服务器 if name in self.tool_mgr.mcp_client_dict: - self.tool_mgr.mcp_service_queue.put_nowait({ - "type": "terminate", - "name": name, - }) + try: + await self.tool_mgr.disable_mcp_server(name, timeout=10) + except TimeoutError: + return ( + Response() + .error(f"停用 MCP 服务器 {name} 超时。") + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"停用 MCP 服务器 {name} 失败: {str(e)}") + .__dict__ + ) return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__ else: @@ -239,17 +271,23 @@ class ToolsRoute(Route): if name not in config["mcpServers"]: return Response().error(f"服务器 {name} 不存在").__dict__ - # 删除服务器配置 del config["mcpServers"][name] if self.save_mcp_config(config): - # 关闭并删除MCP客户端 if name in self.tool_mgr.mcp_client_dict: - self.tool_mgr.mcp_service_queue.put_nowait({ - "type": "terminate", - "name": name, - }) - + try: + await self.tool_mgr.disable_mcp_server(name, timeout=10) + except TimeoutError: + return ( + Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return ( + Response() + .error(f"停用 MCP 服务器 {name} 失败: {str(e)}") + .__dict__ + ) return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__ else: return Response().error("保存配置失败").__dict__ @@ -281,3 +319,20 @@ class ToolsRoute(Route): except Exception as _: logger.error(traceback.format_exc()) return Response().error("获取市场数据失败").__dict__ + + async def test_mcp_connection(self): + """ + 测试 MCP 服务器连接 + """ + try: + server_data = await request.json + config = server_data.get("mcp_server_config", None) + + tools_name = await self.tool_mgr.test_mcp_server_connection(config) + return ( + Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__ + ) + + except Exception as e: + logger.error(traceback.format_exc()) + return Response().error(f"测试 MCP 连接失败: {str(e)}").__dict__ diff --git a/dashboard/src/components/shared/ItemCard.vue b/dashboard/src/components/shared/ItemCard.vue index ff790cb7b..6152c531f 100644 --- a/dashboard/src/components/shared/ItemCard.vue +++ b/dashboard/src/components/shared/ItemCard.vue @@ -9,6 +9,8 @@ hide-details density="compact" :model-value="getItemEnabled()" + :loading="loading" + :disabled="loading" v-bind="props" @update:model-value="toggleEnabled" > @@ -77,6 +79,10 @@ export default { bglogo: { type: String, default: null + }, + loading: { + type: Boolean, + default: false } }, emits: ['toggle-enabled', 'delete', 'edit'], diff --git a/dashboard/src/components/shared/ProxySelector.vue b/dashboard/src/components/shared/ProxySelector.vue new file mode 100644 index 000000000..d45a0f520 --- /dev/null +++ b/dashboard/src/components/shared/ProxySelector.vue @@ -0,0 +1,152 @@ + + + + + + \ No newline at end of file diff --git a/dashboard/src/i18n/locales/en-US/features/tool-use.json b/dashboard/src/i18n/locales/en-US/features/tool-use.json index fad67a0d5..96c4760e8 100644 --- a/dashboard/src/i18n/locales/en-US/features/tool-use.json +++ b/dashboard/src/i18n/locales/en-US/features/tool-use.json @@ -15,7 +15,9 @@ "buttons": { "refresh": "Refresh", "add": "Add Server", - "useTemplate": "Use Template" + "useTemplateStdio": "Stdio Template", + "useTemplateStreamableHttp": "Streamable HTTP Template", + "useTemplateSse": "SSE Template" }, "empty": "No MCP servers available, click Add Server to add one", "status": { @@ -28,8 +30,7 @@ "functionTools": { "title": "Function Tools", "buttons": { - "expand": "Expand", - "collapse": "Collapse" + "view": "View Tools" }, "search": "Search function tools", "empty": "No function tools available", @@ -68,10 +69,6 @@ "enable": "Enable Server", "config": "Server Configuration" }, - "configNotes": { - "note1": "1. Some MCP servers may require filling in `API_KEY` or `TOKEN` information in env according to their requirements, please check if filled.", - "note2": "2. When url parameter is specified in configuration: if `transport` parameter is also specified as `streamable_http`, Streamable HTTP is used, otherwise SSE connection is used." - }, "errors": { "configEmpty": "Configuration cannot be empty", "jsonFormat": "JSON format error: {error}", @@ -79,7 +76,8 @@ }, "buttons": { "cancel": "Cancel", - "save": "Save" + "save": "Save", + "testConnection": "Test Connection" } }, "serverDetail": { diff --git a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json index f44a16d59..61b8691bc 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json +++ b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json @@ -15,7 +15,9 @@ "buttons": { "refresh": "刷新", "add": "新增服务器", - "useTemplate": "使用模板" + "useTemplateStdio": "Stdio 模板", + "useTemplateStreamableHttp": "Streamable HTTP 模板", + "useTemplateSse": "SSE 模板" }, "empty": "暂无 MCP 服务器,点击 新增服务器 添加", "status": { @@ -28,8 +30,7 @@ "functionTools": { "title": "函数工具", "buttons": { - "expand": "展开", - "collapse": "收起" + "view": "查看工具" }, "search": "搜索函数工具", "empty": "没有可用的函数工具", @@ -68,10 +69,6 @@ "enable": "启用服务器", "config": "服务器配置" }, - "configNotes": { - "note1": "1. 某些 MCP 服务器可能需要按照其要求在 env 中填充 `API_KEY` 或 `TOKEN` 等信息,请注意检查是否填写。", - "note2": "2. 当配置中指定 url 参数时:如果还同时指定 `transport` 参数的值为 `streamable_http`,则使用 Steamable HTTP,否则使用 SSE 连接。" - }, "errors": { "configEmpty": "配置不能为空", "jsonFormat": "JSON 格式错误: {error}", @@ -79,7 +76,8 @@ }, "buttons": { "cancel": "取消", - "save": "保存" + "save": "保存", + "testConnection": "测试连接" } }, "serverDetail": { diff --git a/dashboard/src/theme/DarkTheme.ts b/dashboard/src/theme/DarkTheme.ts index 9899fcfff..177bee39c 100644 --- a/dashboard/src/theme/DarkTheme.ts +++ b/dashboard/src/theme/DarkTheme.ts @@ -36,12 +36,13 @@ const PurpleThemeDark: ThemeTypes = { gray100: '#cccccccc', primary200: '#90caf9', secondary200: '#b39ddb', - background: '#111111', + background: '#1d1d1d', overlay: '#111111aa', codeBg: '#282833', preBg: 'rgb(23, 23, 23)', code: '#ffffffdd', chatMessageBubble: '#2d2e30', + mcpCardBg: '#2a2a2a', } }; diff --git a/dashboard/src/theme/LightTheme.ts b/dashboard/src/theme/LightTheme.ts index 03630523f..b8fdec259 100644 --- a/dashboard/src/theme/LightTheme.ts +++ b/dashboard/src/theme/LightTheme.ts @@ -27,7 +27,7 @@ const PurpleTheme: ThemeTypes = { borderLight: '#d0d0d0', border: '#d0d0d0', inputBorder: '#787878', - containerBg: '#f7f1f6', + containerBg: '#f9fafcf4', surface: '#fff', 'on-surface-variant': '#fff', facebook: '#4267b2', @@ -36,12 +36,13 @@ const PurpleTheme: ThemeTypes = { gray100: '#fafafacc', primary200: '#90caf9', secondary200: '#b39ddb', - background: '#f9fafcf4', + background: '#ffffff', overlay: '#ffffffaa', codeBg: '#ececec', preBg: 'rgb(249, 249, 249)', code: 'rgb(13, 13, 13)', chatMessageBubble: '#e7ebf4', + mcpCardBg: '#f7f2f9', } }; diff --git a/dashboard/src/types/themeTypes/ThemeType.ts b/dashboard/src/types/themeTypes/ThemeType.ts index b18ee3dc5..8d2760044 100644 --- a/dashboard/src/types/themeTypes/ThemeType.ts +++ b/dashboard/src/types/themeTypes/ThemeType.ts @@ -37,5 +37,6 @@ export type ThemeTypes = { preBg?: string; code?: string; chatMessageBubble?: string; + mcpCardBg?: string; }; }; diff --git a/dashboard/src/views/ChatPage.vue b/dashboard/src/views/ChatPage.vue index 099cf09c0..ad75cbd3f 100644 --- a/dashboard/src/views/ChatPage.vue +++ b/dashboard/src/views/ChatPage.vue @@ -226,6 +226,9 @@
+ + @@ -668,34 +671,44 @@ export default { }; }, + async processAndUploadImage(file) { + const formData = new FormData(); + formData.append('file', file); + + try { + const response = await axios.post('/api/chat/post_image', formData, { + headers: { + 'Content-Type': 'multipart/form-data' + } + }); + + const img = response.data.data.filename; + this.stagedImagesName.push(img); // Store just the filename + this.stagedImagesUrl.push(URL.createObjectURL(file)); // Create a blob URL for immediate display + + } catch (err) { + console.error('Error uploading image:', err); + } + }, + async handlePaste(event) { console.log('Pasting image...'); const items = event.clipboardData.items; for (let i = 0; i < items.length; i++) { if (items[i].type.indexOf('image') !== -1) { const file = items[i].getAsFile(); - const formData = new FormData(); - formData.append('file', file); - - try { - const response = await axios.post('/api/chat/post_image', formData, { - headers: { - 'Content-Type': 'multipart/form-data' - } - }); - - const img = response.data.data.filename; - this.stagedImagesName.push(img); // Store just the filename - this.stagedImagesUrl.push(URL.createObjectURL(file)); // Create a blob URL for immediate display - - } catch (err) { - console.error('Error uploading image:', err); - } + this.processAndUploadImage(file); } } }, removeImage(index) { + // Revoke the blob URL to prevent memory leaks + const urlToRevoke = this.stagedImagesUrl[index]; + if (urlToRevoke && urlToRevoke.startsWith('blob:')) { + URL.revokeObjectURL(urlToRevoke); + } + this.stagedImagesName.splice(index, 1); this.stagedImagesUrl.splice(index, 1); }, @@ -703,6 +716,21 @@ export default { clearMessage() { this.prompt = ''; }, + + triggerImageInput() { + this.$refs.imageInput.click(); + }, + + handleFileSelect(event) { + const files = event.target.files; + if (files) { + for (const file of files) { + this.processAndUploadImage(file); + } + } + // Reset the input value to allow selecting the same file again + event.target.value = ''; + }, getConversations() { axios.get('/api/chat/conversations').then(response => { this.conversations = response.data.data; @@ -846,33 +874,42 @@ export default { // URL is already updated in newConversation method } + // 保存当前要发送的数据到临时变量 + const promptToSend = this.prompt.trim(); + const imageNamesToSend = [...this.stagedImagesName]; + const audioNameToSend = this.stagedAudioUrl; + + // 立即清空输入和附件预览 + this.prompt = ''; + this.stagedImagesName = []; + this.stagedImagesUrl = []; + this.stagedAudioUrl = ""; + // Create a message object with actual URLs for display const userMessage = { type: 'user', - message: this.prompt.trim(), // 使用 trim() 去除前后空格 + message: promptToSend, image_url: [], audio_url: null }; // Convert image filenames to blob URLs for display - if (this.stagedImagesName.length > 0) { - for (let i = 0; i < this.stagedImagesName.length; i++) { - // If it's just a filename, get the blob URL - if (!this.stagedImagesName[i].startsWith('blob:')) { - const imgUrl = await this.getMediaFile(this.stagedImagesName[i]); - userMessage.image_url.push(imgUrl); - } else { - userMessage.image_url.push(this.stagedImagesName[i]); + if (imageNamesToSend.length > 0) { + const imagePromises = imageNamesToSend.map(name => { + if (!name.startsWith('blob:')) { + return this.getMediaFile(name); } - } + return Promise.resolve(name); + }); + userMessage.image_url = await Promise.all(imagePromises); } // Convert audio filename to blob URL for display - if (this.stagedAudioUrl) { - if (!this.stagedAudioUrl.startsWith('blob:')) { - userMessage.audio_url = await this.getMediaFile(this.stagedAudioUrl); + if (audioNameToSend) { + if (!audioNameToSend.startsWith('blob:')) { + userMessage.audio_url = await this.getMediaFile(audioNameToSend); } else { - userMessage.audio_url = this.stagedAudioUrl; + userMessage.audio_url = audioNameToSend; } } @@ -885,8 +922,6 @@ export default { const selection = this.$refs.providerModelSelector?.getCurrentSelection(); const selectedProviderId = selection?.providerId || ''; const selectedModelName = selection?.modelName || ''; - let prompt = this.prompt.trim(); - this.prompt = ''; // 清空输入框 try { const response = await fetch('/api/chat/send', { @@ -896,10 +931,10 @@ export default { 'Authorization': 'Bearer ' + localStorage.getItem('token') }, body: JSON.stringify({ - message: prompt, + message: promptToSend, conversation_id: this.currCid, - image_url: this.stagedImagesName, - audio_url: this.stagedAudioUrl ? [this.stagedAudioUrl] : [], + image_url: imageNamesToSend, + audio_url: audioNameToSend ? [audioNameToSend] : [], selected_provider: selectedProviderId, selected_model: selectedModelName }) @@ -1003,11 +1038,7 @@ export default { } } - // Clear input after successful send - this.prompt = ''; - this.stagedImagesName = []; - this.stagedImagesUrl = []; - this.stagedAudioUrl = ""; + // Input and attachments are already cleared this.loadingChat = false; // get the latest conversations diff --git a/dashboard/src/views/ExtensionPage.vue b/dashboard/src/views/ExtensionPage.vue index 892034add..a3dbc9c37 100644 --- a/dashboard/src/views/ExtensionPage.vue +++ b/dashboard/src/views/ExtensionPage.vue @@ -3,6 +3,7 @@ import ExtensionCard from '@/components/shared/ExtensionCard.vue'; import AstrBotConfig from '@/components/shared/AstrBotConfig.vue'; import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue'; import ReadmeDialog from '@/components/shared/ReadmeDialog.vue'; +import ProxySelector from '@/components/shared/ProxySelector.vue'; import axios from 'axios'; import { useCommonStore } from '@/stores/common'; import { useI18n, useModuleI18n } from '@/i18n/composables'; @@ -29,12 +30,12 @@ const extension_config = reactive({ config: {} }); const pluginMarketData = ref([]); - const loadingDialog = reactive({ - show: false, - title: "", - statusCode: 0, // 0: loading, 1: success, 2: error, - result: "" - }); +const loadingDialog = reactive({ + show: false, + title: "", + statusCode: 0, // 0: loading, 1: success, 2: error, + result: "" +}); const showPluginInfoDialog = ref(false); const selectedPlugin = ref({}); const curr_namespace = ref(""); @@ -184,8 +185,8 @@ const checkUpdate = () => { if (matchedPlugin) { extension.online_version = matchedPlugin.version; - extension.has_update = extension.version !== matchedPlugin.version && - matchedPlugin.version !== tm('status.unknown'); + extension.has_update = extension.version !== matchedPlugin.version && + matchedPlugin.version !== tm('status.unknown'); } else { extension.has_update = false; } @@ -622,27 +623,12 @@ onMounted(async () => { - + - + @@ -678,33 +664,32 @@ onMounted(async () => { mdi-plus {{ tm('buttons.install') }} - - - - - - + + + + + + @@ -726,7 +711,8 @@ onMounted(async () => {
{{ item.name }}
- {{ tm('status.system') }} + {{ tm('status.system') + }}
@@ -847,8 +833,8 @@ onMounted(async () => { - +
@@ -865,8 +851,8 @@ onMounted(async () => {

{{ tm('market.allPlugins') }}

- +
@@ -904,7 +890,8 @@ onMounted(async () => { \ No newline at end of file diff --git a/dashboard/src/views/ToolUsePage.vue b/dashboard/src/views/ToolUsePage.vue index cc70c415b..6060088a4 100644 --- a/dashboard/src/views/ToolUsePage.vue +++ b/dashboard/src/views/ToolUsePage.vue @@ -20,9 +20,16 @@

- - {{ tm('mcpServers.buttons.add') }} - +
+ + {{ tm('functionTools.buttons.view') }}({{ tools.length }}) + + + {{ tm('mcpServers.buttons.add') }} + +
@@ -44,169 +51,79 @@ - - - mdi-server - {{ tm('mcpServers.title') }} - - - {{ tm('mcpServers.buttons.refresh') }} - - - {{ tm('mcpServers.buttons.add') }} - - - +
+ mdi-server-off +

{{ tm('mcpServers.empty') }}

+
- -
- mdi-server-off -

{{ tm('mcpServers.empty') }}

-
+ + + + + + + - - - -

- mdi-information - {{ tm('functionTools.description') }} -

-

{{ tool.function.description }}

- - -
- mdi-code-brackets -

{{ tm('functionTools.noParameters') }}

-
-
-
-
- - -
- - - - @@ -216,9 +133,9 @@ mdi-store {{ tm('marketplace.title') }} - + {{ tm('marketplace.buttons.refresh') }} @@ -256,7 +173,8 @@
mdi-tools - {{ tm('marketplace.status.availableTools', { count: server.tools ? server.tools.length : 0 }) }} + {{ tm('marketplace.status.availableTools', { count: server.tools ? server.tools.length : 0 }) + }}
@@ -310,31 +228,25 @@ - - - +
{{ tm('dialogs.addServer.fields.config') }} - - -
- {{ tm('tooltip.serverConfig') }} -
-
- - {{ tm('mcpServers.buttons.useTemplate') }} + + {{ tm('mcpServers.buttons.useTemplateStdio') }} + + + {{ tm('mcpServers.buttons.useTemplateStreamableHttp') }} + + + {{ tm('mcpServers.buttons.useTemplateSse') }}
- {{ tm('dialogs.addServer.configNotes.note1') }} -
- {{ tm('dialogs.addServer.configNotes.note2') }} -
+
+ {{ addServerDialogMessage }} +
- + {{ tm('dialogs.addServer.buttons.cancel') }} + + {{ tm('dialogs.addServer.buttons.testConnection') }} + {{ tm('dialogs.addServer.buttons.save') }} @@ -469,6 +386,106 @@ + + + + + {{ tm('functionTools.title') }} + {{ tools.length }} + + + +
+
+ mdi-api-off +

{{ tm('functionTools.empty') }}

+
+ +
+ + + + + + + +
+ + {{ tool.function.name.includes(':') ? 'mdi-server-network' : 'mdi-function-variant' }} + + + {{ formatToolName(tool.function.name) }} + +
+
+ + {{ tool.function.description }} + +
+
+ + + + +

+ mdi-information + {{ tm('functionTools.description') }} +

+

{{ tool.function.description }}

+ + +
+ mdi-code-brackets +

{{ tm('functionTools.noParameters') }}

+
+
+
+
+
+
+
+
+
+
+ + + + + {{ tm('dialogs.serverDetail.buttons.close') }} + + +
+
+ @@ -504,8 +521,12 @@ export default { tools: [], showMcpServerDialog: false, showServerDetailDialog: false, + addServerDialogMessage: "", + showToolsDialog: false, showTools: true, loading: false, + loadingGettingServers: false, + mcpServerUpdateLoaders: {}, // record loading state for each server update isEditMode: false, serverConfigJson: '', jsonError: null, @@ -575,10 +596,10 @@ export default { if (!this.marketplaceSearch.trim()) { return this.marketplaceServers; } - + const searchTerm = this.marketplaceSearch.toLowerCase(); - return this.marketplaceServers.filter(server => - server.name.toLowerCase().includes(searchTerm) || + return this.marketplaceServers.filter(server => + server.name.toLowerCase().includes(searchTerm) || (server.name_h && server.name_h.toLowerCase().includes(searchTerm)) || (server.description && server.description.toLowerCase().includes(searchTerm)) ); @@ -618,17 +639,21 @@ export default { }, getServers() { - this.loading = true + this.loadingGettingServers = true; axios.get('/api/tools/mcp/servers') .then(response => { this.mcpServers = response.data.data || []; + this.mcpServers.forEach(server => { + // Ensure each server has a loader state + if (!this.mcpServerUpdateLoaders[server.name]) { + this.mcpServerUpdateLoaders[server.name] = false; + } + }); }) .catch(error => { this.showError(this.tm('messages.getServersError', { error: error.message })); }).finally(() => { - setTimeout(() => { - this.loading = false; - }, 500); + this.loadingGettingServers = false; }); }, @@ -658,14 +683,28 @@ export default { } }, - setConfigTemplate() { - // 设置一个基本的配置模板 - const template = { - command: "python", - args: ["-m", "your_module"], - // 可以添加其他 MCP 支持的配置项 - }; - + setConfigTemplate(type = 'stdio') { + let template = {}; + if (type === 'streamable_http') { + template = { + transport: "streamable_http", + url: "your mcp server url", + headers: {}, + timeout: 30, + }; + } else if (type === 'sse') { + template = { + transport: "sse", + url: "your mcp server url", + headers: {}, + timeout: 30, + }; + } else { + template = { + command: "python", + args: ["-m", "your_module"], + }; + } this.serverConfigJson = JSON.stringify(template, null, 2); }, @@ -693,6 +732,7 @@ export default { .then(response => { this.loading = false; this.showMcpServerDialog = false; + this.addServerDialogMessage = ""; this.getServers(); this.getTools(); this.showSuccess(response.data.message || this.tm('messages.saveSuccess')); @@ -753,6 +793,7 @@ export default { updateServerStatus(server) { // 切换服务器状态 + this.mcpServerUpdateLoaders[server.name] = true; server.active = !server.active; axios.post('/api/tools/mcp/update', server) .then(response => { @@ -761,16 +802,48 @@ export default { }) .catch(error => { this.showError(this.tm('messages.updateError', { error: error.response?.data?.message || error.message })); - // 回滚状态 server.active = !server.active; + }) + .finally(() => { + this.mcpServerUpdateLoaders[server.name] = false; }); }, closeServerDialog() { this.showMcpServerDialog = false; + this.addServerDialogMessage = ''; this.resetForm(); }, + testServerConnection() { + if (!this.validateJson()) { + return; + } + + this.loading = true; + + let configObj; + try { + configObj = JSON.parse(this.serverConfigJson); + } catch (e) { + this.loading = false; + this.showError(this.tm('dialogs.addServer.errors.jsonParse', { error: e.message })); + return; + } + + axios.post('/api/tools/mcp/test', { + "mcp_server_config": configObj, + }) + .then(response => { + this.loading = false; + this.addServerDialogMessage = `${response.data.message} (tools: ${response.data.data})`; + }) + .catch(error => { + this.loading = false; + this.showError(this.tm('messages.testError', { error: error.response?.data?.message || error.message })); + }); + }, + resetForm() { this.currentServer = { name: '', @@ -939,7 +1012,7 @@ export default { .monaco-container { border: 1px solid rgba(0, 0, 0, 0.1); - border-radius: 4px; + border-radius: 8px; height: 300px; margin-top: 4px; overflow: hidden; diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index fcb34e250..404d65f85 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -1242,6 +1242,10 @@ UID: {user_id} 此 ID 可用于设置管理员。 logger.error(traceback.format_exc()) logger.error(f"主动回复失败: {e}") + @filter.on_decorating_result() + async def decorate_result(self, event: AstrMessageEvent): + logger.debug("Decorating result for event: %s", event) + @filter.on_llm_request() async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest): """在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""