From a01e8650425e093bc920aaf3d393acc3ad53dd5d Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 2 Dec 2024 22:20:24 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9C=AC=E5=9C=B0=E6=8C=87=E6=A0=87?= =?UTF-8?q?=E6=94=B6=E9=9B=86=E5=88=B0=E6=95=B0=E6=8D=AE=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/__init__.py | 5 ++++- astrbot/core/provider/provider.py | 3 +-- astrbot/core/utils/metrics.py | 10 ++++++++++ main.py | 6 ++---- packages/astrbot_plugin_openai/openai_adapter.py | 1 - 5 files changed, 17 insertions(+), 8 deletions(-) diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index e87e58524..a1c61bb7c 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -1,5 +1,8 @@ from .log import LogManager, LogBroker from astrbot.core.utils.t2i.renderer import HtmlRenderer +from astrbot.core.db.sqlite import SQLiteDatabase +from astrbot.core.config.default import DB_PATH html_renderer = HtmlRenderer() -logger = LogManager.GetLogger(log_name='astrbot') \ No newline at end of file +logger = LogManager.GetLogger(log_name='astrbot') +db_helper = SQLiteDatabase(DB_PATH) \ No newline at end of file diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index f8d2289bb..6e3c3c9bf 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -15,10 +15,9 @@ class Provider(abc.ABC): # 维护了 session_id 的上下文,不包含 system 指令 self.session_memory = defaultdict(list) self.curr_personality = Personality(prompt=default_personality, name="") - + self.db_helper = db_helper if persistant_history: # 读取历史记录 - self.db_helper = db_helper try: for history in db_helper.get_llm_history(): self.session_memory[history.session_id] = json.loads(history.content) diff --git a/astrbot/core/utils/metrics.py b/astrbot/core/utils/metrics.py index 577a1e872..560c15311 100644 --- a/astrbot/core/utils/metrics.py +++ b/astrbot/core/utils/metrics.py @@ -2,6 +2,7 @@ import aiohttp import sys import logging from astrbot.core.config import VERSION +from astrbot.core import db_helper, logger logger = logging.getLogger("astrbot") @@ -19,6 +20,15 @@ class Metric(): payload = { "metrics_data": kwargs } + try: + if 'adapter_name' in kwargs: + db_helper.insert_platform_metrics({kwargs['adapter_name']: 1}) + if 'llm_name' in kwargs: + db_helper.insert_llm_metrics({kwargs['llm_name']: 1}) + except Exception as e: + logger.error(f"保存指标到数据库失败: {e}") + pass + try: async with aiohttp.ClientSession() as session: async with session.post(base_url, json=payload, timeout=3) as response: diff --git a/main.py b/main.py index 182233bdb..8b1b2cf30 100644 --- a/main.py +++ b/main.py @@ -7,10 +7,9 @@ import aiohttp import zipfile from typing import List from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.config import DB_PATH from astrbot.dashboard import AstrBotDashBoardLifecycle - +from astrbot.core import db_helper from astrbot.core import logger, LogManager, LogBroker # add parent path to sys.path @@ -92,8 +91,7 @@ if __name__ == "__main__": # check dashboard files asyncio.run(check_dashboard_files()) - # start db - db = SQLiteDatabase(DB_PATH) + db = db_helper # print logo logger.info(logo_tmpl) diff --git a/packages/astrbot_plugin_openai/openai_adapter.py b/packages/astrbot_plugin_openai/openai_adapter.py index 5987083c3..fbe19d503 100644 --- a/packages/astrbot_plugin_openai/openai_adapter.py +++ b/packages/astrbot_plugin_openai/openai_adapter.py @@ -19,7 +19,6 @@ from dataclasses import asdict class ProviderOpenAIOfficial(Provider): def __init__(self, llm_config: LLMConfig, db_helper: BaseDatabase, persistant_history = True) -> None: super().__init__(db_helper, llm_config.default_personality, persistant_history) - self.api_keys = [] self.chosen_api_key = None self.base_url = None