From 8a0f865af17f2d92c795493f88c9afc8a457763e Mon Sep 17 00:00:00 2001 From: Soulter <37870767+Soulter@users.noreply.github.com> Date: Thu, 18 Dec 2025 17:11:09 +0800 Subject: [PATCH] feat: enhance tool call handling and agent stats tracking and UI integration for tool calls render (#4101) * feat: enhance tool call handling and UI integration for tool calls render - Added support for tool call messages in the agent runner and webchat event handling. - Implemented JSON message component for structured tool call data. - Updated chat route to save tool call information in message history. - Enhanced frontend to display tool call details in a collapsible format, including status and results. - Introduced elapsed time tracking for ongoing tool calls in the chat interface. * fix: improve message handling in agent run utility and tool loop runner - Refactored message sending logic in `astr_agent_run_util.py` to use `msg_chain` directly for better clarity. - Added a check in `tool_loop_agent_runner.py` to ensure `tool_call_result_blocks` is not empty before yielding the last tool call result, preventing potential errors. * refactor: enhance message structure and UI for chat components - Updated message handling in `MessageList.vue` to support structured message parts, including plain text, images, audio, and files. - Improved the `Chat.vue` component styles for better visual consistency. - Refactored message parsing logic in `useMessages.ts` to accommodate new message formats and ensure proper rendering of embedded content. - Removed deprecated tool call handling from the message structure, streamlining the message display process. * chore: ruff format * feat: implement agent statistics tracking and display in chat - Added `AgentStats` and `TokenUsage` data classes to track agent performance metrics. - Enhanced `ToolLoopAgentRunner` to collect and update agent statistics during execution. - Integrated agent statistics sending to webchat for real-time updates. - Updated chat route to save and display agent statistics in message history. - Improved frontend components to visualize agent statistics, including token usage and duration metrics. * fix: improve message handling in Telegram event and agent run utility - Updated message sending logic in `astr_agent_run_util.py` to send the correct message chain for tool calls. - Enhanced `tg_event.py` to edit messages during streaming breaks, improving message management and user experience. - Added error handling for message editing failures to ensure robustness. * chore: ruff format --- astrbot/core/agent/response.py | 23 +- .../agent/runners/tool_loop_agent_runner.py | 68 +- astrbot/core/astr_agent_run_util.py | 25 +- astrbot/core/message/components.py | 9 +- .../platform/sources/telegram/tg_event.py | 9 + .../platform/sources/webchat/webchat_event.py | 40 +- astrbot/core/provider/entities.py | 41 ++ .../core/provider/sources/anthropic_source.py | 43 +- .../core/provider/sources/gemini_source.py | 21 +- .../core/provider/sources/openai_source.py | 19 +- astrbot/core/star/context.py | 4 + astrbot/dashboard/routes/chat.py | 65 +- dashboard/src/components/chat/Chat.vue | 4 + dashboard/src/components/chat/MessageList.vue | 681 +++++++++++++++--- dashboard/src/composables/useMessages.ts | 422 ++++++----- .../src/i18n/locales/en-US/features/chat.json | 8 + .../src/i18n/locales/zh-CN/features/chat.json | 8 + 17 files changed, 1155 insertions(+), 335 deletions(-) diff --git a/astrbot/core/agent/response.py b/astrbot/core/agent/response.py index 3f3430c87..9e61fa8c7 100644 --- a/astrbot/core/agent/response.py +++ b/astrbot/core/agent/response.py @@ -1,7 +1,8 @@ import typing as T -from dataclasses import dataclass +from dataclasses import dataclass, field from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.provider.entities import TokenUsage class AgentResponseData(T.TypedDict): @@ -12,3 +13,23 @@ class AgentResponseData(T.TypedDict): class AgentResponse: type: str data: AgentResponseData + + +@dataclass +class AgentStats: + token_usage: TokenUsage = field(default_factory=TokenUsage) + start_time: float = 0.0 + end_time: float = 0.0 + time_to_first_token: float = 0.0 + + @property + def duration(self) -> float: + return self.end_time - self.start_time + + def to_dict(self) -> dict: + return { + "token_usage": self.token_usage.__dict__, + "start_time": self.start_time, + "end_time": self.end_time, + "time_to_first_token": self.time_to_first_token, + } diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 450e4dbcb..069de144f 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -1,4 +1,5 @@ import sys +import time import traceback import typing as T @@ -12,6 +13,7 @@ from mcp.types import ( ) from astrbot import logger +from astrbot.core.message.components import Json from astrbot.core.message.message_event_result import ( MessageChain, ) @@ -24,7 +26,7 @@ from astrbot.core.provider.provider import Provider from ..hooks import BaseAgentRunHooks from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment -from ..response import AgentResponseData +from ..response import AgentResponseData, AgentStats from ..run_context import ContextWrapper, TContext from ..tool_executor import BaseFunctionToolExecutor from .base import AgentResponse, AgentState, BaseAgentRunner @@ -69,6 +71,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): ) self.run_context.messages = messages + self.stats = AgentStats() + self.stats.start_time = time.time() + async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: """Yields chunks *and* a final LLMResponse.""" if self.streaming: @@ -98,6 +103,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): async for llm_response in self._iter_llm_responses(): if llm_response.is_chunk: + # update ttft + if self.stats.time_to_first_token == 0: + self.stats.time_to_first_token = time.time() - self.stats.start_time + if llm_response.result_chain: yield AgentResponse( type="streaming_delta", @@ -121,6 +130,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): ) continue llm_resp_result = llm_response + + if not llm_response.is_chunk and llm_response.usage: + # only count the token usage of the final response for computation purpose + self.stats.token_usage += llm_response.usage break # got final response if not llm_resp_result: @@ -132,6 +145,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): if llm_resp.role == "err": # 如果 LLM 响应错误,转换到错误状态 self.final_llm_resp = llm_resp + self.stats.end_time = time.time() self._transition_state(AgentState.ERROR) yield AgentResponse( type="err", @@ -146,6 +160,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): # 如果没有工具调用,转换到完成状态 self.final_llm_resp = llm_resp self._transition_state(AgentState.DONE) + self.stats.end_time = time.time() # record the final assistant message self.run_context.messages.append( Message( @@ -175,23 +190,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): # 如果有工具调用,还需处理工具调用 if llm_resp.tools_call_name: tool_call_result_blocks = [] - for tool_call_name in llm_resp.tools_call_name: - yield AgentResponse( - type="tool_call", - data=AgentResponseData( - chain=MessageChain(type="tool_call").message( - f"🔨 调用工具: {tool_call_name}" - ), - ), - ) async for result in self._handle_function_tools(self.req, llm_resp): if isinstance(result, list): tool_call_result_blocks = result elif isinstance(result, MessageChain): if result.type is None: - result.type = "tool_call_result" + # should not happen + continue + if result.type == "tool_direct_result": + ar_type = "tool_call_result" + else: + ar_type = result.type yield AgentResponse( - type="tool_call_result", + type=ar_type, data=AgentResponseData(chain=result), ) # 将结果添加到上下文中 @@ -234,6 +245,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): llm_response.tools_call_args, llm_response.tools_call_ids, ): + yield MessageChain( + type="tool_call", + chain=[ + Json( + data={ + "id": func_tool_id, + "name": func_tool_name, + "args": func_tool_args, + "ts": time.time(), + } + ) + ], + ) try: if not req.func_tool: return @@ -307,7 +331,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): content=res.content[0].text, ), ) - yield MessageChain().message(res.content[0].text) elif isinstance(res.content[0], ImageContent): tool_call_result_blocks.append( ToolCallMessageSegment( @@ -329,7 +352,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): content=resource.text, ), ) - yield MessageChain().message(resource.text) elif ( isinstance(resource, BlobResourceContents) and resource.mimeType @@ -353,7 +375,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): content="返回的数据类型不受支持", ), ) - yield MessageChain().message("返回的数据类型不受支持。") + + # yield the last tool call result + if tool_call_result_blocks: + last_tcr_content = str(tool_call_result_blocks[-1].content) + yield MessageChain( + type="tool_call_result", + chain=[ + Json( + data={ + "id": func_tool_id, + "ts": time.time(), + "result": last_tcr_content, + } + ) + ], + ) elif resp is None: # Tool 直接请求发送消息给用户 @@ -363,6 +400,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。" ) self._transition_state(AgentState.DONE) + self.stats.end_time = time.time() else: # 不应该出现其他类型 logger.warning( diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index d94d96a82..5421a14c0 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator from astrbot.core import logger from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.message.components import Json from astrbot.core.message.message_event_result import ( MessageChain, MessageEventResult, @@ -33,16 +34,27 @@ async def run_agent( msg_chain = resp.data["chain"] if msg_chain.type == "tool_direct_result": # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容 - await astr_event.send(resp.data["chain"]) + await astr_event.send(msg_chain) continue + if astr_event.get_platform_id() == "webchat": + await astr_event.send(msg_chain) # 对于其他情况,暂时先不处理 continue elif resp.type == "tool_call": if agent_runner.streaming: # 用来标记流式响应需要分节 yield MessageChain(chain=[], type="break") - if show_tool_use: + + if astr_event.get_platform_name() == "webchat": await astr_event.send(resp.data["chain"]) + elif show_tool_use: + json_comp = resp.data["chain"].chain[0] + if isinstance(json_comp, Json): + m = f"🔨 调用工具: {json_comp.data.get('name')}" + else: + m = "🔨 调用工具..." + chain = MessageChain(type="tool_call").message(m) + await astr_event.send(chain) continue if stream_to_general and resp.type == "streaming_delta": @@ -69,6 +81,15 @@ async def run_agent( continue yield resp.data["chain"] # MessageChain if agent_runner.done(): + # send agent stats to webchat + if astr_event.get_platform_name() == "webchat": + await astr_event.send( + MessageChain( + type="agent_stats", + chain=[Json(data=agent_runner.stats.to_dict())], + ) + ) + break except Exception as e: diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 0e7b3bab6..050e36521 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -629,12 +629,11 @@ class Nodes(BaseMessageComponent): class Json(BaseMessageComponent): type = ComponentType.Json - data: str | dict - resid: int | None = 0 + data: dict - def __init__(self, data, **_): - if isinstance(data, dict): - data = json.dumps(data) + def __init__(self, data: str | dict, **_): + if isinstance(data, str): + data = json.loads(data) super().__init__(data=data, **_) diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 37f60e65a..5faba6803 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -200,6 +200,15 @@ class TelegramPlatformEvent(AstrMessageEvent): if isinstance(chain, MessageChain): if chain.type == "break": # 分割符 + if message_id: + try: + await self.client.edit_message_text( + text=delta, + chat_id=payload["chat_id"], + message_id=message_id, + ) + except Exception as e: + logger.warning(f"编辑消息失败(streaming-break): {e!s}") message_id = None # 重置消息 ID delta = "" # 重置 delta continue diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 9f1a6d059..2e529bb1d 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -1,11 +1,12 @@ import base64 +import json import os import shutil import uuid from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import File, Image, Plain, Record +from astrbot.api.message_components import File, Image, Json, Plain, Record from astrbot.core.utils.astrbot_path import get_astrbot_data_path from .webchat_queue_mgr import webchat_queue_mgr @@ -41,12 +42,20 @@ class WebChatMessageEvent(AstrMessageEvent): await web_chat_back_queue.put( { "type": "plain", - "cid": cid, "data": data, "streaming": streaming, "chain_type": message.type, }, ) + elif isinstance(comp, Json): + await web_chat_back_queue.put( + { + "type": "plain", + "data": json.dumps(comp.data, ensure_ascii=False), + "streaming": streaming, + "chain_type": message.type, + }, + ) elif isinstance(comp, Image): # save image to local filename = f"{str(uuid.uuid4())}.jpg" @@ -58,7 +67,6 @@ class WebChatMessageEvent(AstrMessageEvent): await web_chat_back_queue.put( { "type": "image", - "cid": cid, "data": data, "streaming": streaming, }, @@ -74,7 +82,6 @@ class WebChatMessageEvent(AstrMessageEvent): await web_chat_back_queue.put( { "type": "record", - "cid": cid, "data": data, "streaming": streaming, }, @@ -91,7 +98,6 @@ class WebChatMessageEvent(AstrMessageEvent): await web_chat_back_queue.put( { "type": "file", - "cid": cid, "data": data, "streaming": streaming, }, @@ -111,18 +117,17 @@ class WebChatMessageEvent(AstrMessageEvent): cid = self.session_id.split("!")[-1] web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) async for chain in generator: - if chain.type == "break" and final_data: - # 分割符 - await web_chat_back_queue.put( - { - "type": "break", # break means a segment end - "data": final_data, - "streaming": True, - "cid": cid, - }, - ) - final_data = "" - continue + # if chain.type == "break" and final_data: + # # 分割符 + # await web_chat_back_queue.put( + # { + # "type": "break", # break means a segment end + # "data": final_data, + # "streaming": True, + # }, + # ) + # final_data = "" + # continue r = await WebChatMessageEvent._send( chain, @@ -142,7 +147,6 @@ class WebChatMessageEvent(AstrMessageEvent): "data": final_data, "reasoning": reasoning_content, "streaming": True, - "cid": cid, }, ) await super().send_streaming(generator, use_fallback) diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index dc188f141..d13e9b56a 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import enum import json @@ -199,6 +201,38 @@ class ProviderRequest: return "" +@dataclass +class TokenUsage: + input_other: int = 0 + """The number of input tokens, excluding cached tokens.""" + input_cached: int = 0 + """The number of input cached tokens.""" + output: int = 0 + """The number of output tokens.""" + + @property + def total(self) -> int: + return self.input_other + self.input_cached + self.output + + @property + def input(self) -> int: + return self.input_other + self.input_cached + + def __add__(self, other: TokenUsage) -> TokenUsage: + return TokenUsage( + input_other=self.input_other + other.input_other, + input_cached=self.input_cached + other.input_cached, + output=self.output + other.output, + ) + + def __sub__(self, other: TokenUsage) -> TokenUsage: + return TokenUsage( + input_other=self.input_other - other.input_other, + input_cached=self.input_cached - other.input_cached, + output=self.output - other.output, + ) + + @dataclass class LLMResponse: role: str @@ -227,6 +261,11 @@ class LLMResponse: is_chunk: bool = False """Indicates if the response is a chunked response.""" + id: str | None = None + """The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response.""" + usage: TokenUsage | None = None + """The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response.""" + def __init__( self, role: str, @@ -241,6 +280,8 @@ class LLMResponse: | AnthropicMessage | None = None, is_chunk: bool = False, + id: str | None = None, + usage: TokenUsage | None = None, ): """初始化 LLMResponse diff --git a/astrbot/core/provider/sources/anthropic_source.py b/astrbot/core/provider/sources/anthropic_source.py index bd0f06fba..7e33f40d9 100644 --- a/astrbot/core/provider/sources/anthropic_source.py +++ b/astrbot/core/provider/sources/anthropic_source.py @@ -6,10 +6,12 @@ from mimetypes import guess_type import anthropic from anthropic import AsyncAnthropic from anthropic.types import Message +from anthropic.types.message_delta_usage import MessageDeltaUsage +from anthropic.types.usage import Usage from astrbot import logger from astrbot.api.provider import Provider -from astrbot.core.provider.entities import LLMResponse +from astrbot.core.provider.entities import LLMResponse, TokenUsage from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.utils.io import download_image_by_url @@ -107,6 +109,22 @@ class ProviderAnthropic(Provider): return system_prompt, new_messages + def _extract_usage(self, usage: Usage) -> TokenUsage: + # https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance + return TokenUsage( + input_other=usage.input_tokens or 0, + input_cached=usage.cache_read_input_tokens or 0, + output=usage.output_tokens, + ) + + def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None: + if usage.input_tokens is not None: + token_usage.input_other = usage.input_tokens + if usage.cache_read_input_tokens is not None: + token_usage.input_cached = usage.cache_read_input_tokens + if usage.output_tokens is not None: + token_usage.output = usage.output_tokens + async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: if tools: if tool_list := tools.get_func_desc_anthropic_style(): @@ -131,6 +149,10 @@ class ProviderAnthropic(Provider): llm_response.tools_call_args.append(content_block.input) llm_response.tools_call_name.append(content_block.name) llm_response.tools_call_ids.append(content_block.id) + + llm_response.id = completion.id + llm_response.usage = self._extract_usage(completion.usage) + # TODO(Soulter): 处理 end_turn 情况 if not llm_response.completion_text and not llm_response.tools_call_args: raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。") @@ -152,9 +174,16 @@ class ProviderAnthropic(Provider): final_text = "" final_tool_calls = [] + id = None + usage = TokenUsage() + async with self.client.messages.stream(**payloads) as stream: assert isinstance(stream, anthropic.AsyncMessageStream) async for event in stream: + if event.type == "message_start": + # the usage contains input token usage + id = event.message.id + usage = self._extract_usage(event.message.usage) if event.type == "content_block_start": if event.content_block.type == "text": # 文本块开始 @@ -162,6 +191,8 @@ class ProviderAnthropic(Provider): role="assistant", completion_text="", is_chunk=True, + usage=usage, + id=id, ) elif event.content_block.type == "tool_use": # 工具使用块开始,初始化缓冲区 @@ -179,6 +210,8 @@ class ProviderAnthropic(Provider): role="assistant", completion_text=event.delta.text, is_chunk=True, + usage=usage, + id=id, ) elif event.delta.type == "input_json_delta": # 工具调用参数增量 @@ -215,6 +248,8 @@ class ProviderAnthropic(Provider): tools_call_name=[tool_info["name"]], tools_call_ids=[tool_info["id"]], is_chunk=True, + usage=usage, + id=id, ) except json.JSONDecodeError: # JSON 解析失败,跳过这个工具调用 @@ -223,11 +258,17 @@ class ProviderAnthropic(Provider): # 清理缓冲区 del tool_use_buffer[event.index] + elif event.type == "message_delta": + if event.usage: + self._update_usage(usage, event.usage) + # 返回最终的完整结果 final_response = LLMResponse( role="assistant", completion_text=final_text, is_chunk=False, + usage=usage, + id=id, ) if final_tool_calls: diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index e2efc6aab..8e0b89081 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -14,7 +14,7 @@ import astrbot.core.message.components as Comp from astrbot import logger from astrbot.api.provider import Provider from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.provider.entities import LLMResponse +from astrbot.core.provider.entities import LLMResponse, TokenUsage from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.utils.io import download_image_by_url @@ -347,6 +347,16 @@ class ProviderGoogleGenAI(Provider): ] return "".join(thought_buf).strip() + def _extract_usage( + self, usage_metadata: types.GenerateContentResponseUsageMetadata + ) -> TokenUsage: + """Extract usage from candidate""" + return TokenUsage( + input_other=usage_metadata.prompt_token_count or 0, + input_cached=usage_metadata.cached_content_token_count or 0, + output=usage_metadata.candidates_token_count or 0, + ) + def _process_content_parts( self, candidate: types.Candidate, @@ -501,6 +511,9 @@ class ProviderGoogleGenAI(Provider): result.candidates[0], llm_response, ) + llm_response.id = result.response_id + if result.usage_metadata: + llm_response.usage = self._extract_usage(result.usage_metadata) return llm_response async def _query_stream( @@ -569,6 +582,9 @@ class ProviderGoogleGenAI(Provider): chunk.candidates[0], llm_response, ) + llm_response.id = chunk.response_id + if chunk.usage_metadata: + llm_response.usage = self._extract_usage(chunk.usage_metadata) yield llm_response return @@ -596,6 +612,9 @@ class ProviderGoogleGenAI(Provider): chunk.candidates[0], final_response, ) + final_response.id = chunk.response_id + if chunk.usage_metadata: + final_response.usage = self._extract_usage(chunk.usage_metadata) break # Yield final complete response with accumulated text diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 788b649a9..4aeacf672 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -12,6 +12,7 @@ from openai._exceptions import NotFoundError from openai.lib.streaming.chat._completions import ChatCompletionStreamState from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.completion_usage import CompletionUsage import astrbot.core.message.components as Comp from astrbot import logger @@ -19,7 +20,7 @@ from astrbot.api.provider import Provider from astrbot.core.agent.message import Message from astrbot.core.agent.tool import ToolSet from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.provider.entities import LLMResponse, ToolCallsResult +from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult from astrbot.core.utils.io import download_image_by_url from ..register import register_provider_adapter @@ -208,6 +209,7 @@ class ProviderOpenAIOfficial(Provider): # handle the content delta reasoning = self._extract_reasoning_content(chunk) _y = False + llm_response.id = chunk.id if reasoning: llm_response.reasoning_content = reasoning _y = True @@ -217,6 +219,8 @@ class ProviderOpenAIOfficial(Provider): chain=[Comp.Plain(completion_text)], ) _y = True + if chunk.usage: + llm_response.usage = self._extract_usage(chunk.usage) if _y: yield llm_response @@ -245,6 +249,15 @@ class ProviderOpenAIOfficial(Provider): reasoning_text = str(reasoning_attr) return reasoning_text + def _extract_usage(self, usage: CompletionUsage) -> TokenUsage: + ptd = usage.prompt_tokens_details + cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0 + return TokenUsage( + input_other=usage.prompt_tokens - cached, + input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0, + output=usage.completion_tokens, + ) + async def _parse_openai_completion( self, completion: ChatCompletion, tools: ToolSet | None ) -> LLMResponse: @@ -321,6 +334,10 @@ class ProviderOpenAIOfficial(Provider): raise Exception(f"API 返回的 completion 无法解析:{completion}。") llm_response.raw_completion = completion + llm_response.id = completion.id + + if completion.usage: + llm_response.usage = self._extract_usage(completion.usage) return llm_response diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 9a52ec8bc..2561762f1 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -296,6 +296,10 @@ class Context: provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) + if prov is None: + raise ProviderNotFoundError( + "provider not found, please choose provider first" + ) if not isinstance(prov, Provider): raise ValueError("返回的 Provider 不是 Provider 类型") return prov diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index f2439c058..c2b991ef7 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -227,16 +227,19 @@ class ChatRoute(Route): text: str, media_parts: list, reasoning: str, + agent_stats: dict, ): """保存 bot 消息到历史记录,返回保存的记录""" bot_message_parts = [] + bot_message_parts.extend(media_parts) if text: bot_message_parts.append({"type": "plain", "text": text}) - bot_message_parts.extend(media_parts) new_his = {"type": "bot", "message": bot_message_parts} if reasoning: new_his["reasoning"] = reasoning + if agent_stats: + new_his["agent_stats"] = agent_stats record = await self.platform_history_mgr.insert( platform_id="webchat", @@ -294,7 +297,8 @@ class ChatRoute(Route): accumulated_parts = [] accumulated_text = "" accumulated_reasoning = "" - + tool_calls = {} + agent_stats = {} try: async with track_conversation(self.running_convs, webchat_conv_id): while True: @@ -314,6 +318,16 @@ class ChatRoute(Route): result_text = result["data"] msg_type = result.get("type") streaming = result.get("streaming", False) + chain_type = result.get("chain_type") + + if chain_type == "agent_stats": + stats_info = { + "type": "agent_stats", + "data": json.loads(result_text), + } + yield f"data: {json.dumps(stats_info, ensure_ascii=False)}\n\n" + agent_stats = stats_info["data"] + continue # 发送 SSE 数据 try: @@ -335,8 +349,30 @@ class ChatRoute(Route): # 累积消息部分 if msg_type == "plain": - chain_type = result.get("chain_type", "normal") - if chain_type == "reasoning": + chain_type = result.get("chain_type") + if chain_type == "tool_call": + tool_call = json.loads(result_text) + tool_calls[tool_call.get("id")] = tool_call + if accumulated_text: + # 如果累积了文本,则先保存文本 + accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + elif chain_type == "tool_call_result": + tcr = json.loads(result_text) + tc_id = tcr.get("id") + if tc_id in tool_calls: + tool_calls[tc_id]["result"] = tcr.get("result") + tool_calls[tc_id]["finished_ts"] = tcr.get("ts") + accumulated_parts.append( + { + "type": "tool_call", + "tool_calls": [tool_calls[tc_id]], + } + ) + tool_calls.pop(tc_id, None) + elif chain_type == "reasoning": accumulated_reasoning += result_text elif streaming: accumulated_text += result_text @@ -369,15 +405,20 @@ class ChatRoute(Route): if msg_type == "end": break elif ( - (streaming and msg_type == "complete") - or not streaming - or msg_type == "break" + (streaming and msg_type == "complete") or not streaming + # or msg_type == "break" ): + if ( + chain_type == "tool_call" + or chain_type == "tool_call_result" + ): + continue saved_record = await self._save_bot_message( webchat_conv_id, accumulated_text, accumulated_parts, accumulated_reasoning, + agent_stats, ) # 发送保存的消息信息给前端 if saved_record and not client_disconnected: @@ -392,11 +433,11 @@ class ChatRoute(Route): yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n" except Exception: pass - # 重置累积变量 (对于 break 后的下一段消息) - if msg_type == "break": - accumulated_parts = [] - accumulated_text = "" - accumulated_reasoning = "" + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + tool_calls = {} + agent_stats = {} except BaseException as e: logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 509971ca8..5524e787d 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -575,5 +575,9 @@ onBeforeUnmount(() => { .chat-page-container { padding: 0 !important; } + + .conversation-header { + padding: 2px; + } } diff --git a/dashboard/src/components/chat/MessageList.vue b/dashboard/src/components/chat/MessageList.vue index 8361b5176..cd14c6574 100644 --- a/dashboard/src/components/chat/MessageList.vue +++ b/dashboard/src/components/chat/MessageList.vue @@ -5,56 +5,66 @@