From 343b153263dbb266e8ce2d056c1428ed089c9de1 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 5 Jan 2026 16:53:37 +0800 Subject: [PATCH] feat: add token_usage tracking to conversations and update related processing logic --- astrbot/core/agent/context/manager.py | 8 ++- astrbot/core/agent/context/token_counter.py | 14 ++++- .../agent/runners/tool_loop_agent_runner.py | 3 +- astrbot/core/conversation_mgr.py | 4 ++ astrbot/core/db/__init__.py | 1 + .../core/db/migration/migra_token_usage.py | 61 +++++++++++++++++++ astrbot/core/db/po.py | 7 +++ astrbot/core/db/sqlite.py | 6 +- .../method/agent_sub_stages/internal.py | 9 +++ astrbot/core/utils/migra_helper.py | 8 +++ 10 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 astrbot/core/db/migration/migra_token_usage.py diff --git a/astrbot/core/agent/context/manager.py b/astrbot/core/agent/context/manager.py index 902528cd3..b8e131d98 100644 --- a/astrbot/core/agent/context/manager.py +++ b/astrbot/core/agent/context/manager.py @@ -41,7 +41,9 @@ class ContextManager: truncate_turns=config.truncate_turns ) - async def process(self, messages: list[Message]) -> list[Message]: + async def process( + self, messages: list[Message], trusted_token_usage: int = 0 + ) -> list[Message]: """Process the messages. Args: @@ -63,7 +65,9 @@ class ContextManager: # 2. 基于 token 的压缩 if self.config.max_context_tokens > 0: - total_tokens = self.token_counter.count_tokens(result) + total_tokens = self.token_counter.count_tokens( + result, trusted_token_usage + ) if self.compressor.should_compress( result, total_tokens, self.config.max_context_tokens diff --git a/astrbot/core/agent/context/token_counter.py b/astrbot/core/agent/context/token_counter.py index a6a58aea3..1d4efbe8d 100644 --- a/astrbot/core/agent/context/token_counter.py +++ b/astrbot/core/agent/context/token_counter.py @@ -11,11 +11,16 @@ class TokenCounter(Protocol): Provides an interface for counting tokens in message lists. """ - def count_tokens(self, messages: list[Message]) -> int: + def count_tokens( + self, messages: list[Message], trusted_token_usage: int = 0 + ) -> int: """Count the total tokens in the message list. Args: messages: The message list. + trusted_token_usage: The total token usage that LLM API returned. + For some cases, this value is more accurate. + But some API does not return it, so the value defaults to 0. Returns: The total token count. @@ -28,7 +33,12 @@ class EstimateTokenCounter: Provides a simple estimation of token count based on character types. """ - def count_tokens(self, messages: list[Message]) -> int: + def count_tokens( + self, messages: list[Message], trusted_token_usage: int = 0 + ) -> int: + if trusted_token_usage > 0: + return trusted_token_usage + total = 0 for msg in messages: content = msg.content diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 984018e0f..606163685 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -152,8 +152,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): llm_resp_result = None # do truncate and compress + token_usage = self.req.conversation.token_usage if self.req.conversation else 0 self.run_context.messages = await self.context_manager.process( - self.run_context.messages + self.run_context.messages, trusted_token_usage=token_usage ) async for llm_response in self._iter_llm_responses(): diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 287fe03c4..a0a0c0e2f 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -69,6 +69,7 @@ class ConversationManager: persona_id=conv_v2.persona_id, created_at=created_at, updated_at=updated_at, + token_usage=conv_v2.token_usage, ) async def new_conversation( @@ -256,6 +257,7 @@ class ConversationManager: history: list[dict] | None = None, title: str | None = None, persona_id: str | None = None, + token_usage: int | None = None, ) -> None: """更新会话的对话. @@ -263,6 +265,7 @@ class ConversationManager: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段 + token_usage (int | None): token 使用量。None 表示不更新 """ if not conversation_id: @@ -274,6 +277,7 @@ class ConversationManager: title=title, persona_id=persona_id, content=history, + token_usage=token_usage, ) async def update_conversation_title( diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 192c7b263..3a79e41c2 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -152,6 +152,7 @@ class BaseDatabase(abc.ABC): title: str | None = None, persona_id: str | None = None, content: list[dict] | None = None, + token_usage: int | None = None, ) -> None: """Update a conversation's history.""" ... diff --git a/astrbot/core/db/migration/migra_token_usage.py b/astrbot/core/db/migration/migra_token_usage.py new file mode 100644 index 000000000..07938301d --- /dev/null +++ b/astrbot/core/db/migration/migra_token_usage.py @@ -0,0 +1,61 @@ +"""Migration script to add token_usage column to conversations table. + +This migration adds the token_usage field to track token consumption for each conversation. + +Changes: +- Adds token_usage column to conversations table (default: 0) +""" + +from sqlalchemy import text + +from astrbot.api import logger, sp +from astrbot.core.db import BaseDatabase + + +async def migrate_token_usage(db_helper: BaseDatabase): + """Add token_usage column to conversations table. + + This migration adds a new column to track token consumption in conversations. + """ + # 检查是否已经完成迁移 + migration_done = await db_helper.get_preference( + "global", "global", "migration_done_token_usage_1" + ) + if migration_done: + return + + logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...") + + # 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。 + + try: + async with db_helper.get_db() as session: + # 检查列是否已存在 + result = await session.execute(text("PRAGMA table_info(conversations)")) + columns = result.fetchall() + column_names = [col[1] for col in columns] + + if "token_usage" in column_names: + logger.info("token_usage 列已存在,跳过迁移") + await sp.put_async( + "global", "global", "migration_done_token_usage_1", True + ) + return + + # 添加 token_usage 列 + await session.execute( + text( + "ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0" + ) + ) + await session.commit() + + logger.info("token_usage 列添加成功") + + # 标记迁移完成 + await sp.put_async("global", "global", "migration_done_token_usage_1", True) + logger.info("token_usage 迁移完成") + + except Exception as e: + logger.error(f"迁移过程中发生错误: {e}", exc_info=True) + raise diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 64bcf4ce3..fdbf4aff3 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True): ) title: str | None = Field(default=None, max_length=255) persona_id: str | None = Field(default=None) + token_usage: int = Field(default=0, nullable=False) + """content is a list of OpenAI-formated messages in list[dict] format. + token_usage is the total token value of the messages. + when 0, will use estimated token counter. + """ __table_args__ = ( UniqueConstraint( @@ -313,6 +318,8 @@ class Conversation: persona_id: str | None = "" created_at: int = 0 updated_at: int = 0 + token_usage: int = 0 + """对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。""" class Personality(TypedDict): diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index fa3ca9a76..7422a5cc2 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase): session.add(new_conversation) return new_conversation - async def update_conversation(self, cid, title=None, persona_id=None, content=None): + async def update_conversation( + self, cid, title=None, persona_id=None, content=None, token_usage=None + ): async with self.get_db() as session: session: AsyncSession async with session.begin(): @@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase): values["persona_id"] = persona_id if content is not None: values["content"] = content + if token_usage is not None: + values["token_usage"] = token_usage if not values: return None query = query.values(**values) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py index 984b8f9bf..69bd04314 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py @@ -6,6 +6,7 @@ from collections.abc import AsyncGenerator from astrbot.core import logger from astrbot.core.agent.message import Message +from astrbot.core.agent.response import AgentStats from astrbot.core.agent.tool import ToolSet from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.conversation_mgr import Conversation @@ -282,6 +283,7 @@ class InternalAgentSubStage(Stage): req: ProviderRequest, llm_response: LLMResponse | None, all_messages: list[Message], + runner_stats: AgentStats | None, ): if ( not req @@ -308,10 +310,16 @@ class InternalAgentSubStage(Stage): continue message_to_save.append(message.model_dump()) + # get token usage from agent runner stats + token_usage = None + if runner_stats: + token_usage = runner_stats.token_usage.total + await self.conv_manager.update_conversation( event.unified_msg_origin, req.conversation.cid, history=message_to_save, + token_usage=token_usage, ) def _get_compress_provider(self) -> Provider | None: @@ -519,6 +527,7 @@ class InternalAgentSubStage(Stage): req, agent_runner.get_final_llm_resp(), agent_runner.run_context.messages, + agent_runner.stats, ) # 异步处理 WebChat 特殊情况 diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index b8ff677e1..6a300302d 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -3,6 +3,7 @@ import traceback from astrbot.core import astrbot_config, logger from astrbot.core.astrbot_config_mgr import AstrBotConfig, AstrBotConfigManager from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46 +from astrbot.core.db.migration.migra_token_usage import migrate_token_usage from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session @@ -139,6 +140,13 @@ async def migra( logger.error(f"Migration for webchat session failed: {e!s}") logger.error(traceback.format_exc()) + # migration for token_usage column + try: + await migrate_token_usage(db) + except Exception as e: + logger.error(f"Migration for token_usage column failed: {e!s}") + logger.error(traceback.format_exc()) + # migra third party agent runner configs _c = False providers = astrbot_config["provider"]