From 5c4326c30279a2a34fb756bae209f15eebdbb9d1 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Tue, 25 Mar 2025 20:53:23 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E9=83=A8=E5=88=86=E8=AF=A6=E7=BB=86?= =?UTF-8?q?=E6=B3=A8=E9=87=8A,=20=E7=AC=A6=E5=90=88PEP8=E6=A0=87=E5=87=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/conversation_mgr.py | 1 + astrbot/core/core_lifecycle.py | 71 ++++++++++++++++--- astrbot/core/event_bus.py | 32 +++++++-- astrbot/core/log.py | 114 ++++++++++++++++++++++++++----- 4 files changed, 186 insertions(+), 32 deletions(-) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 6cba41142..ed29dddd9 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -11,6 +11,7 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" def __init__(self, db_helper: BaseDatabase): + # session_conversations 字典记录会话ID-用户ID 映射关系 self.session_conversations: Dict[str, str] = sp.get("session_conversation", {}) self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index e52d94674..fe257d078 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -24,31 +24,51 @@ from astrbot.core.star.star_handler import star_map class AstrBotCoreLifecycle: - def __init__(self, log_broker: LogBroker, db: BaseDatabase): - self.log_broker = log_broker - self.astrbot_config = astrbot_config - self.db = db + """ + AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。 + 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、 + EventBus 等。 + 该类还负责加载和执行插件, 以及处理事件总线的分发。 + """ + def __init__(self, log_broker: LogBroker, db: BaseDatabase): + self.log_broker = log_broker # 初始化日志代理 + self.astrbot_config = astrbot_config # 初始化配置 + self.db = db # 初始化数据库 + + # 根据环境变量设置代理 os.environ["https_proxy"] = self.astrbot_config["http_proxy"] os.environ["http_proxy"] = self.astrbot_config["http_proxy"] os.environ["no_proxy"] = "localhost" async def initialize(self): + """ + 初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。 + """ + + # 初始化日志代理 logger.info("AstrBot v" + VERSION) if os.environ.get("TESTING", ""): - logger.setLevel("DEBUG") + logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG else: - logger.setLevel(self.astrbot_config["log_level"]) + logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别 + + # 初始化事件队列 self.event_queue = Queue() + # 初始化供应商管理器 self.provider_manager = ProviderManager(self.astrbot_config, self.db) + # 初始化平台管理器 self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue) + # 初始化知识库管理器 self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config) + # 初始化对话管理器 self.conversation_manager = ConversationManager(self.db) + # 初始化提供给插件的上下文 self.star_context = Context( self.event_queue, self.astrbot_config, @@ -58,35 +78,50 @@ class AstrBotCoreLifecycle: self.conversation_manager, self.knowledge_db_manager, ) + + # 初始化插件管理器 self.plugin_manager = PluginManager(self.star_context, self.astrbot_config) + # 扫描、注册插件、实例化插件类 await self.plugin_manager.reload() - """扫描、注册插件、实例化插件类""" + # 根据配置实例化各个 Provider await self.provider_manager.initialize() - """根据配置实例化各个 Provider""" + # 初始化消息事件流水线调度器 self.pipeline_scheduler = PipelineScheduler( PipelineContext(self.astrbot_config, self.plugin_manager) ) await self.pipeline_scheduler.initialize() - """初始化消息事件流水线调度器""" + # 初始化更新器 self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"]) + + # 初始化事件总线 self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler) + + # 记录启动时间 self.start_time = int(time.time()) + + # 初始化当前任务列表 self.curr_tasks: List[asyncio.Task] = [] + # 根据配置实例化各个平台适配器 await self.platform_manager.initialize() - """根据配置实例化各个平台适配器""" + # 初始化关闭控制面板的事件 self.dashboard_shutdown_event = asyncio.Event() def _load(self): + """加载事件总线和任务并初始化""" + + # 创建一个异步任务来执行事件总线的 dispatch() 方法 + # dispatch是一个无限循环的协程, 从事件队列中获取事件并处理 event_bus_task = asyncio.create_task( self.event_bus.dispatch(), name="event_bus" ) + # 把插件中注册的所有协程函数注册到事件总线中 extra_tasks = [] for task in self.star_context._register_tasks: extra_tasks.append(asyncio.create_task(task, name=task.__name__)) @@ -100,17 +135,24 @@ class AstrBotCoreLifecycle: self.start_time = int(time.time()) async def _task_wrapper(self, task: asyncio.Task): + """异步任务包装器, 用于处理异步任务执行中出现的各种异常 + + Args: + task (asyncio.Task): 要执行的异步任务 + """ try: await task except asyncio.CancelledError: - pass + pass # 任务被取消, 静默处理 except Exception as e: + # 获取完整的异常堆栈信息, 按行分割并记录到日志中 logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}") for line in traceback.format_exc().split("\n"): logger.error(f"| {line}") logger.error("-------") async def start(self): + """启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子""" self._load() logger.info("AstrBot 启动完成。") @@ -127,16 +169,21 @@ class AstrBotCoreLifecycle: except BaseException: logger.error(traceback.format_exc()) + # 同时运行curr_tasks中的所有任务 await asyncio.gather(*self.curr_tasks, return_exceptions=True) async def stop(self): + """停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器""" + # 请求停止所有正在运行的异步任务 for task in self.curr_tasks: task.cancel() + # 终止各个管理器以及控制面板 await self.provider_manager.terminate() await self.platform_manager.terminate() self.dashboard_shutdown_event.set() + # 再次遍历curr_tasks等待每个任务真正结束 for task in self.curr_tasks: try: await task @@ -146,6 +193,7 @@ class AstrBotCoreLifecycle: logger.error(f"任务 {task.get_name()} 发生错误: {e}") async def restart(self): + """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例""" await self.provider_manager.terminate() await self.platform_manager.terminate() self.dashboard_shutdown_event.set() @@ -154,6 +202,7 @@ class AstrBotCoreLifecycle: ).start() def load_platform(self) -> List[asyncio.Task]: + """加载平台实例并返回所有平台实例的异步任务列表""" tasks = [] platform_insts = self.platform_manager.get_insts() for platform_inst in platform_insts: diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 2688bd400..91e6e46b0 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -1,3 +1,8 @@ +""" +事件总线, 用于处理事件的分发和处理 +事件总线是一个异步队列, 用于接收各种消息事件, 并将其分发到相应的处理器进行处理 +""" + import asyncio from asyncio import Queue from astrbot.core.pipeline.scheduler import PipelineScheduler @@ -6,21 +11,38 @@ from .platform import AstrMessageEvent class EventBus: + """事件总线: 用于处理事件的分发和处理 + + 维护一个异步队列, 来接受各种消息事件 + """ + def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler): - self.event_queue = event_queue - self.pipeline_scheduler = pipeline_scheduler + self.event_queue = event_queue # 事件队列 + self.pipeline_scheduler = pipeline_scheduler # 管道调度器 async def dispatch(self): + """无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑""" while True: - event: AstrMessageEvent = await self.event_queue.get() - self._print_event(event) - asyncio.create_task(self.pipeline_scheduler.execute(event)) + event: AstrMessageEvent = ( + await self.event_queue.get() + ) # 从事件队列中获取新的事件 + self._print_event(event) # 打印日志 + asyncio.create_task( + self.pipeline_scheduler.execute(event) + ) # 创建新的异步任务来执行管道调度器的处理逻辑 def _print_event(self, event: AstrMessageEvent): + """用于记录事件信息 + + Args: + event (AstrMessageEvent): 事件对象 + """ + # 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要 if event.get_sender_name(): logger.info( f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}" ) + # 没有发送者名称: [平台名] 发送者ID: 消息概要 else: logger.info( f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}" diff --git a/astrbot/core/log.py b/astrbot/core/log.py index b819fcc2b..01481f46e 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -1,3 +1,26 @@ +""" +日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能 + +const: + CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量 + log_color_config: 日志颜色配置, 定义了不同日志级别的颜色 + +class: + LogBroker: 日志代理类, 用于缓存和分发日志消息 + LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker + LogManager: 日志管理器, 用于创建和配置日志记录器 + +function: + is_plugin_path: 检查文件路径是否来自插件目录 + get_short_level_name: 将日志级别名称转换为四个字母的缩写 + +工作流程: +1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器 +2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker +3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者 +4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流 +""" + import logging import colorlog import asyncio @@ -6,7 +29,9 @@ from collections import deque from asyncio import Queue from typing import List +# 日志缓存大小 CACHED_SIZE = 200 +# 日志颜色配置 log_color_config = { "DEBUG": "green", "INFO": "bold_cyan", @@ -19,8 +44,13 @@ log_color_config = { def is_plugin_path(pathname): - """ - 检查文件路径是否来自插件目录 + """检查文件路径是否来自插件目录 + + Parameters: + pathname (str): 文件路径 + + Returns: + bool: 如果路径来自插件目录,则返回 True,否则返回 False """ if not pathname: return False @@ -30,8 +60,13 @@ def is_plugin_path(pathname): def get_short_level_name(level_name): - """ - 将日志级别名称转换为四个字母的缩写 + """将日志级别名称转换为四个字母的缩写 + + Parameters: + level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" + + Returns: + str: 四个字母的日志级别缩写 """ level_map = { "DEBUG": "DBUG", @@ -44,12 +79,21 @@ def get_short_level_name(level_name): class LogBroker: + """日志代理类, 用于缓存和分发日志消息 + + 发布-订阅模式 + """ + def __init__(self): - self.log_cache = deque(maxlen=CACHED_SIZE) - self.subscribers: List[Queue] = [] + self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志 + self.subscribers: List[Queue] = [] # 订阅者列表 def register(self) -> Queue: - """给每个订阅者返回一个带有日志缓存的队列""" + """注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列 + + Returns: + Queue: 订阅者的队列, 可用于接收日志消息 + """ q = Queue(maxsize=CACHED_SIZE + 10) for log in self.log_cache: q.put_nowait(log) @@ -57,11 +101,19 @@ class LogBroker: return q def unregister(self, q: Queue): - """取消订阅""" + """取消订阅 + + Parameters: + q (Queue): 需要取消订阅的队列 + """ self.subscribers.remove(q) def publish(self, log_entry: str): - """发布消息""" + """发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统 + + Parameters: + log_entry (str): 日志消息, 可以是字符串或字典 + """ self.log_cache.append(log_entry) for q in self.subscribers: try: @@ -71,24 +123,46 @@ class LogBroker: class LogQueueHandler(logging.Handler): + """日志处理器, 用于将日志消息发送到 LogBroker + + 继承自 logging.Handler + """ + def __init__(self, log_broker: LogBroker): super().__init__() self.log_broker = log_broker def emit(self, record): + """日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布 + 这个方法会在每次日志记录时被调用 + + Parameters: + record (logging.LogRecord): 日志记录对象, 包含日志信息 + """ log_entry = self.format(record) self.log_broker.publish(log_entry) class LogManager: + """日志管理器, 用于创建和配置日志记录器 + + 提供了获取默认日志记录器logger和设置队列处理器的方法 + """ + @classmethod def GetLogger(cls, log_name: str = "default"): + """获取指定名称的日志记录器logger""" logger = logging.getLogger(log_name) + # 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置 if logger.hasHandlers(): return logger - console_handler = logging.StreamHandler() - console_handler.setLevel(logging.DEBUG) + # 如果logger没有处理器 + console_handler = logging.StreamHandler() # 创建一个StreamHandler用于控制台输出 + console_handler.setLevel( + logging.DEBUG + ) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG + # 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息 console_formatter = colorlog.ColoredFormatter( fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s", datefmt="%H:%M:%S", @@ -96,6 +170,8 @@ class LogManager: ) class PluginFilter(logging.Filter): + """插件过滤器类, 用于标记日志来源是插件还是核心组件""" + def filter(self, record): record.plugin_tag = ( "[Plug]" if is_plugin_path(record.pathname) else "[Core]" @@ -103,6 +179,9 @@ class LogManager: return True class FileNameFilter(logging.Filter): + """文件名过滤器类, 用于修改日志记录的文件名格式 + 例如: 将文件路径 /path/to/file.py 转换为 file. 格式""" + # 获取这个文件和父文件夹的名字:. 并且去除 .py def filter(self, record): dirname = os.path.dirname(record.pathname) @@ -114,22 +193,25 @@ class LogManager: return True class LevelNameFilter(logging.Filter): + """短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写""" + # 添加短日志级别名称 def filter(self, record): record.short_levelname = get_short_level_name(record.levelname) return True - console_handler.setFormatter(console_formatter) - logger.addFilter(PluginFilter()) - logger.addFilter(FileNameFilter()) + console_handler.setFormatter(console_formatter) # 设置处理器的格式化器 + logger.addFilter(PluginFilter()) # 添加插件过滤器 + logger.addFilter(FileNameFilter()) # 添加文件名过滤器 logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器 - logger.setLevel(logging.DEBUG) - logger.addHandler(console_handler) + logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG + logger.addHandler(console_handler) # 添加处理器到logger return logger @classmethod def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker): + """设置队列处理器, 用于将日志消息发送到 LogBroker""" handler = LogQueueHandler(log_broker) handler.setLevel(logging.DEBUG) if logger.handlers: