🎈 auto fixes by pre-commit hooks

This commit is contained in:
pre-commit-ci[bot]
2025-03-15 11:54:09 +00:00
parent d4d9a1df4c
commit 3dea60366a
3 changed files with 44 additions and 24 deletions
@@ -176,13 +176,18 @@ class LLMRequestSubStage(Stage):
llm_response.tools_call_name, llm_response.tools_call_args
):
try:
if func_tool_name.startswith('mcp:'):
_, mcp_server_name, mcp_func_name = func_tool_name.split(':')
if func_tool_name.startswith("mcp:"):
_, mcp_server_name, mcp_func_name = func_tool_name.split(
":"
)
logger.info(
f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}")
f"从mcp服务 {mcp_server_name} 调用工具函数:{mcp_func_name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[mcp_server_name]
res = await client.session.call_tool(mcp_func_name, func_tool_args)
res = await client.session.call_tool(
mcp_func_name, func_tool_args
)
if res:
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
res_event = event.plain_result(res.content[0].text)
+31 -19
View File
@@ -15,6 +15,7 @@ from anthropic import Anthropic
from ... import logger
@dataclass
class FuncTool:
"""
@@ -59,28 +60,31 @@ class MCPClient:
Args:
server_script_path: Path to the server script (.py or .js)
"""
is_python = server_script_path.endswith('.py')
is_js = server_script_path.endswith('.js')
is_python = server_script_path.endswith(".py")
is_js = server_script_path.endswith(".js")
if not (is_python or is_js):
raise ValueError("Server script must be a .py or .js file")
command = "python" if is_python else "node"
server_params = StdioServerParameters(
command=command,
args=[server_script_path],
env=None
command=command, args=[server_script_path], env=None
)
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
stdio_transport = await self.exit_stack.enter_async_context(
stdio_client(server_params)
)
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
self.session = await self.exit_stack.enter_async_context(
ClientSession(self.stdio, self.write)
)
await self.session.initialize()
class FuncCall:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
self.mcp_client_dict: Dict[str: MCPClient] = dict()
self.mcp_client_dict: Dict[str:MCPClient] = dict()
def empty(self) -> bool:
return len(self.func_list) == 0
@@ -154,14 +158,20 @@ class FuncCall:
mcp_json_file = os.path.join(project_root, "mcp_server.json")
if not os.path.exists(mcp_json_file):
# 配置文件不存在错误处理
logger.warning(f"mcp server config file {mcp_json_file} not found. skip init mcp client list.")
logger.warning(
f"mcp server config file {mcp_json_file} not found. skip init mcp client list."
)
return
mcp_server_json_obj = json.load(open(mcp_json_file, "r", encoding="utf-8"))
for mcp_server_name, mcp_server_script_path in mcp_server_json_obj["mcpServers"].items():
for mcp_server_name, mcp_server_script_path in mcp_server_json_obj[
"mcpServers"
].items():
if not os.path.exists(mcp_server_script_path["script_path"]):
logger.error(f"mcp server import err: Server script {mcp_server_script_path["script_path"]} not found.")
logger.error(
f"mcp server import err: Server script {mcp_server_script_path['script_path']} not found."
)
continue
mcp_client = MCPClient()
mcp_client.name = mcp_server_name
@@ -186,7 +196,7 @@ class FuncCall:
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
},
}
)
@@ -194,14 +204,16 @@ class FuncCall:
for name, client in self.mcp_client_dict.items():
responses = await client.session.list_tools()
for tool in responses.tools:
_l.append({
"type": "function",
"function": {
"name": f"mcp:{name}:{tool.name}",
"parameters": tool.inputSchema,
"description": tool.description,
_l.append(
{
"type": "function",
"function": {
"name": f"mcp:{name}:{tool.name}",
"parameters": tool.inputSchema,
"description": tool.description,
},
}
})
)
return _l
def get_func_desc_anthropic_style(self) -> list:
@@ -124,7 +124,10 @@ class ProviderOpenAIOfficial(Provider):
for tool in tools.func_list:
if tool.name == tool_call.function.name:
args = json.loads(tool_call.function.arguments)
if tool_call.function.name.startswith("mcp:") and tool_call.function.name.split(':')[1] in tools.mcp_client_dict:
if (
tool_call.function.name.startswith("mcp:")
and tool_call.function.name.split(":")[1] in tools.mcp_client_dict
):
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)