Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 23f8d194ab | |||
| 65cceb2f21 | |||
| c4c356887b |
@@ -54,6 +54,14 @@ async def run_agent(
|
||||
return
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
|
||||
astr_event.trace.record(
|
||||
"agent_tool_result",
|
||||
tool_result=msg_chain.get_plain_text(
|
||||
with_other_comps_mark=True
|
||||
),
|
||||
)
|
||||
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
await astr_event.send(msg_chain)
|
||||
@@ -67,12 +75,22 @@ async def run_agent(
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
|
||||
tool_info = None
|
||||
|
||||
if resp.data["chain"].chain:
|
||||
json_comp = resp.data["chain"].chain[0]
|
||||
if isinstance(json_comp, Json):
|
||||
tool_info = json_comp.data
|
||||
astr_event.trace.record(
|
||||
"agent_tool_call",
|
||||
tool_name=tool_info if tool_info else "unknown",
|
||||
)
|
||||
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(resp.data["chain"])
|
||||
elif show_tool_use:
|
||||
json_comp = resp.data["chain"].chain[0]
|
||||
if isinstance(json_comp, Json):
|
||||
m = f"🔨 调用工具: {json_comp.data.get('name')}"
|
||||
if tool_info:
|
||||
m = f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
|
||||
else:
|
||||
m = "🔨 调用工具..."
|
||||
chain = MessageChain(type="tool_call").message(m)
|
||||
|
||||
@@ -202,6 +202,7 @@ DEFAULT_CONFIG = {
|
||||
"log_file_enable": False,
|
||||
"log_file_path": "logs/astrbot.log",
|
||||
"log_file_max_mb": 20,
|
||||
"trace_enable": False,
|
||||
"trace_log_enable": False,
|
||||
"trace_log_path": "logs/astrbot.trace.log",
|
||||
"trace_log_max_mb": 20,
|
||||
|
||||
@@ -54,7 +54,6 @@ class EventBus:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
|
||||
"""
|
||||
event.trace.record("event_dispatch", config_name=conf_name)
|
||||
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||
if event.get_sender_name():
|
||||
logger.info(
|
||||
|
||||
@@ -9,6 +9,7 @@ from astrbot.core.message.components import (
|
||||
AtAll,
|
||||
BaseMessageComponent,
|
||||
Image,
|
||||
Json,
|
||||
Plain,
|
||||
)
|
||||
|
||||
@@ -117,9 +118,26 @@ class MessageChain:
|
||||
self.use_t2i_ = use_t2i
|
||||
return self
|
||||
|
||||
def get_plain_text(self) -> str:
|
||||
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
|
||||
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
||||
def get_plain_text(self, with_other_comps_mark: bool = False) -> str:
|
||||
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。
|
||||
|
||||
Args:
|
||||
with_other_comps_mark (bool): 是否在纯文本中标记其他组件的位置
|
||||
"""
|
||||
if not with_other_comps_mark:
|
||||
return " ".join(
|
||||
[comp.text for comp in self.chain if isinstance(comp, Plain)]
|
||||
)
|
||||
else:
|
||||
texts = []
|
||||
for comp in self.chain:
|
||||
if isinstance(comp, Plain):
|
||||
texts.append(comp.text)
|
||||
elif isinstance(comp, Json):
|
||||
texts.append(f"{comp.data}")
|
||||
else:
|
||||
texts.append(f"[{comp.__class__.__name__}]")
|
||||
return " ".join(texts)
|
||||
|
||||
def squash_plain(self):
|
||||
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
|
||||
|
||||
@@ -85,6 +85,4 @@ class PipelineScheduler:
|
||||
if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent):
|
||||
await event.send(None)
|
||||
|
||||
event.trace.record("event_end")
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
@@ -73,9 +73,6 @@ class AstrMessageEvent(abc.ABC):
|
||||
self.span = self.trace
|
||||
"""事件级 TraceSpan(别名: span)"""
|
||||
|
||||
self.trace.record("umo", umo=self.unified_msg_origin)
|
||||
self.trace.record("event_created", created_at=self.created_at)
|
||||
|
||||
self._has_send_oper = False
|
||||
"""在此次事件中是否有过至少一次发送消息的操作"""
|
||||
self.call_llm = False
|
||||
|
||||
@@ -4,7 +4,6 @@ import asyncio
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import urllib.parse
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
@@ -213,93 +212,15 @@ class FunctionToolManager:
|
||||
open(mcp_json_file, encoding="utf-8"),
|
||||
)["mcpServers"]
|
||||
|
||||
tasks: dict[str, asyncio.Task] = {}
|
||||
ready_futures: dict[str, asyncio.Future] = {}
|
||||
|
||||
for name, cfg in mcp_server_json_obj.items():
|
||||
for name in mcp_server_json_obj:
|
||||
cfg = mcp_server_json_obj[name]
|
||||
if cfg.get("active", True):
|
||||
event = asyncio.Event()
|
||||
ready_future = asyncio.get_running_loop().create_future()
|
||||
task = asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(
|
||||
name,
|
||||
cfg,
|
||||
event,
|
||||
ready_future,
|
||||
),
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, cfg, event),
|
||||
)
|
||||
tasks[name] = task
|
||||
ready_futures[name] = ready_future
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
if ready_futures:
|
||||
logger.info(f"等待 {len(ready_futures)} 个 MCP 服务初始化...")
|
||||
|
||||
_, pending_futures = await asyncio.wait(
|
||||
ready_futures.values(),
|
||||
timeout=20.0,
|
||||
)
|
||||
|
||||
pending_services = {
|
||||
name
|
||||
for name, ready_future in ready_futures.items()
|
||||
if ready_future in pending_futures
|
||||
}
|
||||
|
||||
if pending_services:
|
||||
logger.warning(
|
||||
"MCP 服务初始化超时(20秒),部分服务可能未完全加载。"
|
||||
"建议检查 MCP 服务器配置和网络连接。"
|
||||
)
|
||||
for name in pending_services:
|
||||
task = tasks[name]
|
||||
task.cancel()
|
||||
await asyncio.gather(
|
||||
*(tasks[name] for name in pending_services),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
success_count = 0
|
||||
failed_services: list[str] = []
|
||||
|
||||
for name, ready_future in ready_futures.items():
|
||||
if name in pending_services:
|
||||
logger.error(f"MCP 服务 {name} 初始化超时")
|
||||
failed_services.append(name)
|
||||
self.mcp_client_event.pop(name, None)
|
||||
continue
|
||||
|
||||
if ready_future.cancelled():
|
||||
logger.error(f"MCP 服务 {name} 初始化已取消")
|
||||
failed_services.append(name)
|
||||
self.mcp_client_event.pop(name, None)
|
||||
continue
|
||||
|
||||
exc = ready_future.exception()
|
||||
if exc is not None:
|
||||
logger.error(f"MCP 服务 {name} 初始化失败: {exc}")
|
||||
# 仅在 debug 级别输出完整配置,避免在生产日志中泄露敏感信息
|
||||
cfg = mcp_server_json_obj.get(name, {})
|
||||
if "command" in cfg:
|
||||
logger.debug(f" 命令: {cfg['command']}")
|
||||
if "args" in cfg:
|
||||
logger.debug(f" 参数: {cfg['args']}")
|
||||
elif "url" in cfg:
|
||||
parsed = urllib.parse.urlparse(cfg["url"])
|
||||
logger.debug(f" 主机: {parsed.scheme}://{parsed.netloc}")
|
||||
failed_services.append(name)
|
||||
self.mcp_client_event.pop(name, None)
|
||||
else:
|
||||
success_count += 1
|
||||
|
||||
if failed_services:
|
||||
logger.warning(
|
||||
f"以下 MCP 服务初始化失败: {', '.join(failed_services)}。"
|
||||
f"请检查配置文件 mcp_server.json 和服务器可用性。"
|
||||
)
|
||||
|
||||
logger.info(f"MCP 服务初始化完成: {success_count}/{len(tasks)} 成功")
|
||||
|
||||
async def _init_mcp_client_task_wrapper(
|
||||
self,
|
||||
name: str,
|
||||
@@ -308,29 +229,20 @@ class FunctionToolManager:
|
||||
ready_future: asyncio.Future | None = None,
|
||||
) -> None:
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
initialized = False
|
||||
try:
|
||||
await self._init_mcp_client(name, cfg)
|
||||
initialized = True
|
||||
tools = await self.mcp_client_dict[name].list_tools_and_save()
|
||||
if ready_future and not ready_future.done():
|
||||
ready_future.set_result(True)
|
||||
# tell the caller we are ready
|
||||
ready_future.set_result(tools)
|
||||
await event.wait()
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
except asyncio.CancelledError:
|
||||
if ready_future and not ready_future.done():
|
||||
ready_future.set_exception(
|
||||
asyncio.TimeoutError("MCP 客户端初始化超时"),
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||
if ready_future and not ready_future.done():
|
||||
ready_future.set_exception(e)
|
||||
if not initialized:
|
||||
# 初始化阶段失败,记录错误并向上抛出让 task.exception() 捕获
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
|
||||
raise
|
||||
# 初始化已成功,此处异常来自 event.wait() 被取消,属于正常终止流程
|
||||
finally:
|
||||
# 无论如何都能清理
|
||||
await self._terminate_mcp_client(name)
|
||||
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||
@@ -428,22 +340,22 @@ class FunctionToolManager:
|
||||
if not event:
|
||||
event = asyncio.Event()
|
||||
if not ready_future:
|
||||
ready_future = asyncio.get_running_loop().create_future()
|
||||
ready_future = asyncio.Future()
|
||||
if name in self.mcp_client_dict:
|
||||
return
|
||||
init_task = asyncio.create_task(
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
|
||||
)
|
||||
try:
|
||||
await asyncio.wait_for(ready_future, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
init_task.cancel()
|
||||
await asyncio.gather(init_task, return_exceptions=True)
|
||||
self.mcp_client_event.pop(name, None)
|
||||
raise
|
||||
else:
|
||||
finally:
|
||||
self.mcp_client_event[name] = event
|
||||
|
||||
if ready_future.done() and ready_future.exception():
|
||||
exc = ready_future.exception()
|
||||
if exc is not None:
|
||||
raise exc
|
||||
|
||||
async def disable_mcp_server(
|
||||
self,
|
||||
name: str | None = None,
|
||||
|
||||
@@ -274,8 +274,8 @@ class ProviderManager:
|
||||
if not self.curr_tts_provider_inst and self.tts_provider_insts:
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
|
||||
# 初始化 MCP Client 连接(等待完成以确保工具可用)
|
||||
await self.llm_tools.init_mcp_clients()
|
||||
# 初始化 MCP Client 连接
|
||||
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
|
||||
|
||||
def dynamic_import_provider(self, type: str):
|
||||
"""动态导入提供商适配器模块
|
||||
|
||||
@@ -50,6 +50,10 @@ class TraceSpan:
|
||||
self.started_at = time.time()
|
||||
|
||||
def record(self, action: str, **fields: Any) -> None:
|
||||
# Check if trace recording is enabled
|
||||
if not astrbot_config.get("trace_enable", True):
|
||||
return
|
||||
|
||||
payload = {
|
||||
"type": "trace",
|
||||
"level": "TRACE",
|
||||
|
||||
@@ -31,6 +31,16 @@ class LogRoute(Route):
|
||||
view_func=self.log_history,
|
||||
methods=["GET"],
|
||||
)
|
||||
self.app.add_url_rule(
|
||||
"/api/trace/settings",
|
||||
view_func=self.get_trace_settings,
|
||||
methods=["GET"],
|
||||
)
|
||||
self.app.add_url_rule(
|
||||
"/api/trace/settings",
|
||||
view_func=self.update_trace_settings,
|
||||
methods=["POST"],
|
||||
)
|
||||
|
||||
async def _replay_cached_logs(
|
||||
self, last_event_id: str
|
||||
@@ -106,3 +116,29 @@ class LogRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(f"获取日志历史失败: {e}")
|
||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||
|
||||
async def get_trace_settings(self):
|
||||
"""获取 Trace 设置"""
|
||||
try:
|
||||
trace_enable = self.config.get("trace_enable", True)
|
||||
return Response().ok(data={"trace_enable": trace_enable}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Trace 设置失败: {e}")
|
||||
return Response().error(f"获取 Trace 设置失败: {e}").__dict__
|
||||
|
||||
async def update_trace_settings(self):
|
||||
"""更新 Trace 设置"""
|
||||
try:
|
||||
data = await request.json
|
||||
if data is None:
|
||||
return Response().error("请求数据为空").__dict__
|
||||
|
||||
trace_enable = data.get("trace_enable")
|
||||
if trace_enable is not None:
|
||||
self.config["trace_enable"] = bool(trace_enable)
|
||||
self.config.save_config()
|
||||
|
||||
return Response().ok(message="Trace 设置已更新").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"更新 Trace 设置失败: {e}")
|
||||
return Response().error(f"更新 Trace 设置失败: {e}").__dict__
|
||||
|
||||
@@ -3,5 +3,8 @@
|
||||
"autoScroll": {
|
||||
"enabled": "Auto-scroll: On",
|
||||
"disabled": "Auto-scroll: Off"
|
||||
}
|
||||
},
|
||||
"hint": "Currently only recording partial model call paths from AstrBot main Agent. More coverage will be added.",
|
||||
"recording": "Recording",
|
||||
"paused": "Paused"
|
||||
}
|
||||
|
||||
@@ -3,5 +3,8 @@
|
||||
"autoScroll": {
|
||||
"enabled": "自动滚动:开",
|
||||
"disabled": "自动滚动:关"
|
||||
}
|
||||
},
|
||||
"hint": "当前仅记录部分 AstrBot 主 Agent 的模型调用路径,后续会不断完善。",
|
||||
"recording": "记录中",
|
||||
"paused": "已暂停"
|
||||
}
|
||||
|
||||
@@ -1,13 +1,72 @@
|
||||
<script setup>
|
||||
import TraceDisplayer from '@/components/shared/TraceDisplayer.vue';
|
||||
import { useModuleI18n } from '@/i18n/composables';
|
||||
import { ref, onMounted } from 'vue';
|
||||
import axios from 'axios';
|
||||
|
||||
const { tm } = useModuleI18n('features/trace');
|
||||
|
||||
const traceEnabled = ref(true);
|
||||
const loading = ref(false);
|
||||
const traceDisplayerKey = ref(0);
|
||||
|
||||
const fetchTraceSettings = async () => {
|
||||
try {
|
||||
const res = await axios.get('/api/trace/settings');
|
||||
if (res.data?.status === 'ok') {
|
||||
traceEnabled.value = res.data.data?.trace_enable ?? true;
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('Failed to fetch trace settings:', err);
|
||||
}
|
||||
};
|
||||
|
||||
const updateTraceSettings = async () => {
|
||||
loading.value = true;
|
||||
try {
|
||||
await axios.post('/api/trace/settings', {
|
||||
trace_enable: traceEnabled.value
|
||||
});
|
||||
// Refresh the TraceDisplayer component to reconnect SSE
|
||||
traceDisplayerKey.value += 1;
|
||||
} catch (err) {
|
||||
console.error('Failed to update trace settings:', err);
|
||||
} finally {
|
||||
loading.value = false;
|
||||
}
|
||||
};
|
||||
|
||||
onMounted(() => {
|
||||
fetchTraceSettings();
|
||||
});
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div style="height: 100%;">
|
||||
<TraceDisplayer />
|
||||
<div style="height: 100%; display: flex; flex-direction: column;">
|
||||
<div class="trace-header">
|
||||
<div class="trace-info">
|
||||
<v-icon size="small" color="info" class="mr-2">mdi-information-outline</v-icon>
|
||||
<span class="trace-hint">{{ tm('hint') }}</span>
|
||||
</div>
|
||||
<div class="trace-controls">
|
||||
<v-switch
|
||||
v-model="traceEnabled"
|
||||
:loading="loading"
|
||||
:disabled="loading"
|
||||
color="primary"
|
||||
hide-details
|
||||
density="compact"
|
||||
@update:model-value="updateTraceSettings"
|
||||
>
|
||||
<template #label>
|
||||
<span class="switch-label">{{ traceEnabled ? tm('recording') : tm('paused') }}</span>
|
||||
</template>
|
||||
</v-switch>
|
||||
</div>
|
||||
</div>
|
||||
<div style="flex: 1; min-height: 0;">
|
||||
<TraceDisplayer :key="traceDisplayerKey" />
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -19,3 +78,38 @@ export default {
|
||||
}
|
||||
};
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.trace-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 12px 16px;
|
||||
background: rgba(59, 130, 246, 0.05);
|
||||
border-bottom: 1px solid rgba(59, 130, 246, 0.1);
|
||||
border-radius: 8px 8px 0 0;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.trace-info {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.trace-hint {
|
||||
font-size: 13px;
|
||||
color: #6b7280;
|
||||
}
|
||||
|
||||
.trace-controls {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.switch-label {
|
||||
font-size: 13px;
|
||||
color: #4b5563;
|
||||
white-space: nowrap;
|
||||
}
|
||||
</style>
|
||||
|
||||
Reference in New Issue
Block a user