From 52c9045a2831f66c3412f2536ced3f5699db1f39 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 25 May 2024 17:47:41 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=E4=BA=86?= =?UTF-8?q?=E7=BB=9F=E8=AE=A1=E4=BF=A1=E6=81=AF=E6=95=B0=E6=8D=AE=E7=BB=93?= =?UTF-8?q?=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core.py | 20 +++++----------- model/command/command.py | 1 + model/platform/_platfrom.py | 14 +++++++---- model/platform/qq_gocq.py | 4 ++++ model/platform/qq_official.py | 2 ++ model/provider/openai_official.py | 3 +++ model/provider/provider.py | 29 ++++++++++++++++++++--- type/config.py | 1 + type/register.py | 7 ++++++ type/types.py | 2 -- util/general_utils.py | 39 +++++++++++++++++++++++++------ util/updator.py | 5 ++-- 12 files changed, 95 insertions(+), 32 deletions(-) create mode 100644 type/config.py diff --git a/astrbot/core.py b/astrbot/core.py index 2b8470452..d5524dbfb 100644 --- a/astrbot/core.py +++ b/astrbot/core.py @@ -22,6 +22,7 @@ from util.cmd_config import init_astrbot_config_items from type.types import GlobalObject from type.register import * from type.message import AstrBotMessage +from type.config import * from addons.dashboard.helper import DashBoardHelper from addons.dashboard.server import DashBoardData from persist.session import dbConn @@ -38,9 +39,6 @@ frequency_time = 60 # 计数默认值 frequency_count = 10 -# 版本 -version = '3.1.13' - # 语言模型 OPENAI_OFFICIAL = 'openai_official' NONE_LLM = 'none_llm' @@ -61,8 +59,6 @@ init_astrbot_config_items() # 全局对象 _global_object: GlobalObject = None -# 语言模型选择 - def privider_chooser(cfg): l = [] @@ -70,13 +66,10 @@ def privider_chooser(cfg): l.append('openai_official') return l - -''' -初始化机器人 -''' - - def init(): + ''' + 初始化机器人 + ''' global llm_instance, llm_command_instance global baidu_judge, chosen_provider global frequency_count, frequency_time @@ -92,9 +85,9 @@ def init(): # 初始化 global_object _global_object = GlobalObject() - _global_object.version = version + _global_object.version = VERSION _global_object.base_config = cfg - logger.info("AstrBot v"+version) + logger.info("AstrBot v" + VERSION) if 'reply_prefix' in cfg: # 适配旧版配置 @@ -319,7 +312,6 @@ async def record_message(platform: str, session_id: str): db_inst.increment_stat_session(platform, session_id, 1) db_inst.increment_stat_message(curr_ts, 1) db_inst.increment_stat_platform(curr_ts, platform, 1) - _global_object.cnt_total += 1 async def oper_msg(message: AstrBotMessage, diff --git a/model/command/command.py b/model/command/command.py index bc9132e0b..05234a873 100644 --- a/model/command/command.py +++ b/model/command/command.py @@ -73,6 +73,7 @@ class Command: else: raise TypeError("插件返回值格式错误。") if hit: + plugin.trig() logger.debug("hit plugin: " + plugin.metadata.plugin_name) return True, res except TypeError as e: diff --git a/model/platform/_platfrom.py b/model/platform/_platfrom.py index 06e6f7d22..3efbb9852 100644 --- a/model/platform/_platfrom.py +++ b/model/platform/_platfrom.py @@ -14,34 +14,40 @@ class Platform(): 初始化平台的各种接口 ''' self.message_handler = message_handler + self.cnt_receive = 0 + self.cnt_reply = 0 pass @abc.abstractmethod - async def handle_msg(): + async def handle_msg(self): ''' 处理到来的消息 ''' + self.cnt_receive += 1 pass @abc.abstractmethod - async def reply_msg(): + async def reply_msg(self): ''' 回复消息(被动发送) ''' + self.cnt_reply += 1 pass @abc.abstractmethod - async def send_msg(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]): + async def send_msg(self, target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]): ''' 发送消息(主动发送) ''' + self.cnt_reply += 1 pass @abc.abstractmethod - async def send(target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]): + async def send(self, target: Union[GuildMessage, GroupMessage, FriendMessage, str], message: Union[str, list]): ''' 发送消息(主动发送)同 send_msg() ''' + self.cnt_reply += 1 pass def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str, list]) -> str: diff --git a/model/platform/qq_gocq.py b/model/platform/qq_gocq.py index ac651613a..c7dc7dea8 100644 --- a/model/platform/qq_gocq.py +++ b/model/platform/qq_gocq.py @@ -104,6 +104,7 @@ class QQGOCQ(Platform): self.client.run() async def handle_msg(self, message: AstrBotMessage): + await super().handle_msg() logger.info( f"{message.sender.nickname}/{message.sender.user_id} -> {self.parse_message_outline(message)}") @@ -176,6 +177,7 @@ class QQGOCQ(Platform): async def reply_msg(self, message: Union[AstrBotMessage, GuildMessage, GroupMessage, FriendMessage], result_message: list): + await super().reply_msg() """ 插件开发者请使用send方法, 可以不用直接调用这个方法。 """ @@ -254,6 +256,7 @@ class QQGOCQ(Platform): 提供给插件的发送QQ消息接口。 参数说明:第一个参数可以是消息对象,也可以是QQ群号。第二个参数是消息内容(消息内容可以是消息链列表,也可以是纯文字信息)。 ''' + await super().reply_msg() try: await self.reply_msg(message, result_message) except BaseException as e: @@ -265,6 +268,7 @@ class QQGOCQ(Platform): ''' 同 send_msg() ''' + await super().reply_msg() await self.reply_msg(to, res) def create_text_image(title: str, text: str, max_width=30, font_size=20): diff --git a/model/platform/qq_official.py b/model/platform/qq_official.py index daa5a0f5f..d7e754412 100644 --- a/model/platform/qq_official.py +++ b/model/platform/qq_official.py @@ -102,6 +102,7 @@ class QQOfficial(Platform): ) async def handle_msg(self, message: AstrBotMessage): + await super().handle_msg() assert isinstance(message.raw_message, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage)) is_group = message.type != MessageType.FRIEND_MESSAGE @@ -154,6 +155,7 @@ class QQOfficial(Platform): ''' 回复频道消息 ''' + await super().reply_msg() if isinstance(message, AstrBotMessage): source = message.raw_message else: diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index fff43ca15..39bd7d3ac 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -289,6 +289,7 @@ class ProviderOpenAIOfficial(Provider): extra_conf: Dict = None, **kwargs ) -> str: + super().accu_model_stat() if not session_id: session_id = "unknown" if "unknown" in self.session_memory: @@ -421,6 +422,7 @@ class ProviderOpenAIOfficial(Provider): ''' retry = 0 conf = self.image_generator_model_configs + super().accu_model_stat(model=conf['model']) if not conf: logger.error("OpenAI 图片生成模型配置不存在。") raise Exception("OpenAI 图片生成模型配置不存在。") @@ -481,6 +483,7 @@ class ProviderOpenAIOfficial(Provider): def set_model(self, model: str): self.model_configs['model'] = model + super().set_curr_model(model) def get_configs(self): return self.model_configs diff --git a/model/provider/provider.py b/model/provider/provider.py index 8bf437ae8..038698e25 100644 --- a/model/provider/provider.py +++ b/model/provider/provider.py @@ -1,4 +1,27 @@ +from collections import defaultdict + class Provider: + def __init__(self) -> None: + self.model_stat = defaultdict(int) # 用于记录 LLM Model 使用数据 + self.curr_model_name = "unknown" + + def reset_model_stat(self): + self.model_stat.clear() + + def set_curr_model(self, model_name: str): + self.curr_model_name = model_name + + def get_curr_model(self): + ''' + 返回当前正在使用的 LLM + ''' + return self.curr_model_name + + def accu_model_stat(self, model: str = None): + if not model: + model = self.get_curr_model() + self.model_stat[model] += 1 + async def text_chat(self, prompt: str, session_id: str, @@ -18,7 +41,7 @@ class Provider: extra_conf: 额外配置 default_personality: 默认人格 ''' - raise NotImplementedError + raise NotImplementedError() async def image_generate(self, prompt, session_id, **kwargs) -> str: ''' @@ -26,10 +49,10 @@ class Provider: prompt: 提示词 session_id: 会话id ''' - raise NotImplementedError + raise NotImplementedError() async def forget(self, session_id=None) -> bool: ''' 重置会话 ''' - raise NotImplementedError + raise NotImplementedError() diff --git a/type/config.py b/type/config.py new file mode 100644 index 000000000..98553c5e9 --- /dev/null +++ b/type/config.py @@ -0,0 +1 @@ +VERSION = '3.1.13' \ No newline at end of file diff --git a/type/register.py b/type/register.py index 8d99fc666..d26fbfd29 100644 --- a/type/register.py +++ b/type/register.py @@ -15,6 +15,13 @@ class RegisteredPlugin: module_path: str module: ModuleType root_dir_name: str + trig_cnt: int = 0 + + def reset_trig_cnt(self): + self.trig_cnt = 0 + + def trig(self): + self.trig_cnt += 1 def __str__(self) -> str: return f"RegisteredPlugin({self.metadata}, {self.module_path}, {self.root_dir_name})" diff --git a/type/types.py b/type/types.py index 2f9fd4928..7964ce26b 100644 --- a/type/types.py +++ b/type/types.py @@ -15,7 +15,6 @@ class GlobalObject: web_search: bool # 是否开启了网页搜索 reply_prefix: str # 回复前缀 unique_session: bool # 是否开启了独立会话 - cnt_total: int # 总消息数 default_personality: dict dashboard_data = None @@ -26,7 +25,6 @@ class GlobalObject: self.web_search = False # 是否开启了网页搜索 self.reply_prefix = None self.unique_session = False - self.cnt_total = 0 self.platforms = [] self.llms = [] self.default_personality = None diff --git a/util/general_utils.py b/util/general_utils.py index d9f2d851c..e97f00e66 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -15,6 +15,7 @@ from PIL import Image, ImageDraw, ImageFont from type.types import GlobalObject from SparkleLogging.utils.core import LogManager from logging import Logger +from collections import defaultdict logger: Logger = LogManager.GetLogger(log_name='astrbot-core') @@ -466,15 +467,39 @@ def get_sys_info(global_object: GlobalObject): def upload(_global_object: GlobalObject): + ''' + 上传相关非敏感统计数据 + ''' while True: - addr_ip = '' + platform_stats = {} + llm_stats = {} + plugin_stats = {} + for platform in _global_object.platforms: + platform_stats[platform.platform_name] = { + "cnt_receive": platform.platform_instance.cnt_receive, + "cnt_reply": platform.platform_instance.cnt_reply + } + + for llm in _global_object.llms: + for k, v in llm.llm_instance.model_stat: + llm_stats[llm.llm_name + "_" + k] = v + llm.llm_instance.reset_model_stat() + + for plugin in _global_object.cached_plugins: + plugin_stats[plugin.metadata.plugin_name] = { + "metadata": plugin.metadata, + "trig_cnt": plugin.trig_cnt + } + plugin.reset_trig_cnt() + try: res = { - "version": _global_object.version, - "count": _global_object.cnt_total, - "ip": addr_ip, - "sys": sys.platform, - "admin": "null", + "stat_version": "moon", + "version": _global_object.version, # 版本号 + "platform_stats": platform_stats, # 过去 30 分钟各消息平台交互消息数 + "llm_stats": llm_stats, + "plugin_stats": plugin_stats, + "sys": sys.platform, # 系统版本 } resp = requests.post( 'https://api.soulter.top/upload', data=json.dumps(res), timeout=5) @@ -484,7 +509,7 @@ def upload(_global_object: GlobalObject): _global_object.cnt_total = 0 except BaseException as e: pass - time.sleep(10*60) + time.sleep(30*60) def retry(n: int = 3): ''' diff --git a/util/updator.py b/util/updator.py index e97564f89..f624bb75a 100644 --- a/util/updator.py +++ b/util/updator.py @@ -6,6 +6,7 @@ except BaseException as e: has_git = False import sys, os import requests +from type.config import VERSION def _reboot(): py = sys.executable @@ -78,11 +79,11 @@ def check_update() -> str: print(f"当前版本: {curr_commit}") print(f"最新版本: {new_commit}") if curr_commit.startswith(new_commit): - return "当前已经是最新版本。" + return f"当前已经是最新版本: v{VERSION}" else: update_info = f"""有新版本可用。 === 当前版本 === -{curr_commit} +v{VERSION} === 新版本 === {update_data[0]['version']} From 123ee24f7e6c2933e7dd6a70d5ea1963eac0cd72 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 25 May 2024 18:01:16 +0800 Subject: [PATCH 2/2] fix: stat perf --- model/provider/openai_official.py | 1 + util/general_utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/model/provider/openai_official.py b/model/provider/openai_official.py index 39bd7d3ac..d270cc201 100644 --- a/model/provider/openai_official.py +++ b/model/provider/openai_official.py @@ -73,6 +73,7 @@ class ProviderOpenAIOfficial(Provider): base_url=self.base_url ) self.model_configs: Dict = cfg['chatGPTConfigs'] + super().set_curr_model(self.model_configs['model']) self.image_generator_model_configs: Dict = self.cc.get('openai_image_generate', None) self.session_memory: Dict[str, List] = {} # 会话记忆 self.session_memory_lock = threading.Lock() diff --git a/util/general_utils.py b/util/general_utils.py index e97f00e66..931b4314d 100644 --- a/util/general_utils.py +++ b/util/general_utils.py @@ -470,6 +470,7 @@ def upload(_global_object: GlobalObject): ''' 上传相关非敏感统计数据 ''' + time.sleep(10) while True: platform_stats = {} llm_stats = {} @@ -481,8 +482,9 @@ def upload(_global_object: GlobalObject): } for llm in _global_object.llms: - for k, v in llm.llm_instance.model_stat: - llm_stats[llm.llm_name + "_" + k] = v + stat = llm.llm_instance.model_stat + for k in stat: + llm_stats[llm.llm_name + "#" + k] = stat[k] llm.llm_instance.reset_model_stat() for plugin in _global_object.cached_plugins: