From 3dea60366aee59d1a4f980ce7316936cac83e891 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 15 Mar 2025 11:54:09 +0000 Subject: [PATCH] :balloon: auto fixes by pre-commit hooks --- .../process_stage/method/llm_request.py | 13 +++-- astrbot/core/provider/func_tool_manager.py | 50 ++++++++++++------- .../core/provider/sources/openai_source.py | 5 +- 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index c243583a9..2050f37d6 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -176,13 +176,18 @@ class LLMRequestSubStage(Stage): llm_response.tools_call_name, llm_response.tools_call_args ): try: - if func_tool_name.startswith('mcp:'): - _, mcp_server_name, mcp_func_name = func_tool_name.split(':') + if func_tool_name.startswith("mcp:"): + _, mcp_server_name, mcp_func_name = func_tool_name.split( + ":" + ) logger.info( - f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}") + f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}" + ) client = req.func_tool.mcp_client_dict[mcp_server_name] - res = await client.session.call_tool(mcp_func_name, func_tool_args) + res = await client.session.call_tool( + mcp_func_name, func_tool_args + ) if res: # TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。 res_event = event.plain_result(res.content[0].text) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 2c7381a16..5bfabf0b5 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -15,6 +15,7 @@ from anthropic import Anthropic from ... import logger + @dataclass class FuncTool: """ @@ -59,28 +60,31 @@ class MCPClient: Args: server_script_path: Path to the server script (.py or .js) """ - is_python = server_script_path.endswith('.py') - is_js = server_script_path.endswith('.js') + is_python = server_script_path.endswith(".py") + is_js = server_script_path.endswith(".js") if not (is_python or is_js): raise ValueError("Server script must be a .py or .js file") command = "python" if is_python else "node" server_params = StdioServerParameters( - command=command, - args=[server_script_path], - env=None + command=command, args=[server_script_path], env=None ) - stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params)) + stdio_transport = await self.exit_stack.enter_async_context( + stdio_client(server_params) + ) self.stdio, self.write = stdio_transport - self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write)) + self.session = await self.exit_stack.enter_async_context( + ClientSession(self.stdio, self.write) + ) await self.session.initialize() + class FuncCall: def __init__(self) -> None: self.func_list: List[FuncTool] = [] - self.mcp_client_dict: Dict[str: MCPClient] = dict() + self.mcp_client_dict: Dict[str:MCPClient] = dict() def empty(self) -> bool: return len(self.func_list) == 0 @@ -154,14 +158,20 @@ class FuncCall: mcp_json_file = os.path.join(project_root, "mcp_server.json") if not os.path.exists(mcp_json_file): # 配置文件不存在错误处理 - logger.warning(f"mcp server config file {mcp_json_file} not found. skip init mcp client list.") + logger.warning( + f"mcp server config file {mcp_json_file} not found. skip init mcp client list." + ) return mcp_server_json_obj = json.load(open(mcp_json_file, "r", encoding="utf-8")) - for mcp_server_name, mcp_server_script_path in mcp_server_json_obj["mcpServers"].items(): + for mcp_server_name, mcp_server_script_path in mcp_server_json_obj[ + "mcpServers" + ].items(): if not os.path.exists(mcp_server_script_path["script_path"]): - logger.error(f"mcp server import err: Server script {mcp_server_script_path["script_path"]} not found.") + logger.error( + f"mcp server import err: Server script {mcp_server_script_path['script_path']} not found." + ) continue mcp_client = MCPClient() mcp_client.name = mcp_server_name @@ -186,7 +196,7 @@ class FuncCall: "name": f.name, "parameters": f.parameters, "description": f.description, - } + }, } ) @@ -194,14 +204,16 @@ class FuncCall: for name, client in self.mcp_client_dict.items(): responses = await client.session.list_tools() for tool in responses.tools: - _l.append({ - "type": "function", - "function": { - "name": f"mcp:{name}:{tool.name}", - "parameters": tool.inputSchema, - "description": tool.description, + _l.append( + { + "type": "function", + "function": { + "name": f"mcp:{name}:{tool.name}", + "parameters": tool.inputSchema, + "description": tool.description, + }, } - }) + ) return _l def get_func_desc_anthropic_style(self) -> list: diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 5da2f9263..342c2febb 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -124,7 +124,10 @@ class ProviderOpenAIOfficial(Provider): for tool in tools.func_list: if tool.name == tool_call.function.name: args = json.loads(tool_call.function.arguments) - if tool_call.function.name.startswith("mcp:") and tool_call.function.name.split(':')[1] in tools.mcp_client_dict: + if ( + tool_call.function.name.startswith("mcp:") + and tool_call.function.name.split(":")[1] in tools.mcp_client_dict + ): args = json.loads(tool_call.function.arguments) args_ls.append(args) func_name_ls.append(tool_call.function.name)