From 5c4326c30279a2a34fb756bae209f15eebdbb9d1 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Tue, 25 Mar 2025 20:53:23 +0800 Subject: [PATCH 1/8] =?UTF-8?q?perf:=20=E9=83=A8=E5=88=86=E8=AF=A6?= =?UTF-8?q?=E7=BB=86=E6=B3=A8=E9=87=8A,=20=E7=AC=A6=E5=90=88PEP8=E6=A0=87?= =?UTF-8?q?=E5=87=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/conversation_mgr.py | 1 + astrbot/core/core_lifecycle.py | 71 ++++++++++++++++--- astrbot/core/event_bus.py | 32 +++++++-- astrbot/core/log.py | 114 ++++++++++++++++++++++++++----- 4 files changed, 186 insertions(+), 32 deletions(-) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 6cba41142..ed29dddd9 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -11,6 +11,7 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" def __init__(self, db_helper: BaseDatabase): + # session_conversations 字典记录会话ID-用户ID 映射关系 self.session_conversations: Dict[str, str] = sp.get("session_conversation", {}) self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index e52d94674..fe257d078 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -24,31 +24,51 @@ from astrbot.core.star.star_handler import star_map class AstrBotCoreLifecycle: - def __init__(self, log_broker: LogBroker, db: BaseDatabase): - self.log_broker = log_broker - self.astrbot_config = astrbot_config - self.db = db + """ + AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。 + 该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、 + EventBus 等。 + 该类还负责加载和执行插件, 以及处理事件总线的分发。 + """ + def __init__(self, log_broker: LogBroker, db: BaseDatabase): + self.log_broker = log_broker # 初始化日志代理 + self.astrbot_config = astrbot_config # 初始化配置 + self.db = db # 初始化数据库 + + # 根据环境变量设置代理 os.environ["https_proxy"] = self.astrbot_config["http_proxy"] os.environ["http_proxy"] = self.astrbot_config["http_proxy"] os.environ["no_proxy"] = "localhost" async def initialize(self): + """ + 初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。 + """ + + # 初始化日志代理 logger.info("AstrBot v" + VERSION) if os.environ.get("TESTING", ""): - logger.setLevel("DEBUG") + logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG else: - logger.setLevel(self.astrbot_config["log_level"]) + logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别 + + # 初始化事件队列 self.event_queue = Queue() + # 初始化供应商管理器 self.provider_manager = ProviderManager(self.astrbot_config, self.db) + # 初始化平台管理器 self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue) + # 初始化知识库管理器 self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config) + # 初始化对话管理器 self.conversation_manager = ConversationManager(self.db) + # 初始化提供给插件的上下文 self.star_context = Context( self.event_queue, self.astrbot_config, @@ -58,35 +78,50 @@ class AstrBotCoreLifecycle: self.conversation_manager, self.knowledge_db_manager, ) + + # 初始化插件管理器 self.plugin_manager = PluginManager(self.star_context, self.astrbot_config) + # 扫描、注册插件、实例化插件类 await self.plugin_manager.reload() - """扫描、注册插件、实例化插件类""" + # 根据配置实例化各个 Provider await self.provider_manager.initialize() - """根据配置实例化各个 Provider""" + # 初始化消息事件流水线调度器 self.pipeline_scheduler = PipelineScheduler( PipelineContext(self.astrbot_config, self.plugin_manager) ) await self.pipeline_scheduler.initialize() - """初始化消息事件流水线调度器""" + # 初始化更新器 self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"]) + + # 初始化事件总线 self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler) + + # 记录启动时间 self.start_time = int(time.time()) + + # 初始化当前任务列表 self.curr_tasks: List[asyncio.Task] = [] + # 根据配置实例化各个平台适配器 await self.platform_manager.initialize() - """根据配置实例化各个平台适配器""" + # 初始化关闭控制面板的事件 self.dashboard_shutdown_event = asyncio.Event() def _load(self): + """加载事件总线和任务并初始化""" + + # 创建一个异步任务来执行事件总线的 dispatch() 方法 + # dispatch是一个无限循环的协程, 从事件队列中获取事件并处理 event_bus_task = asyncio.create_task( self.event_bus.dispatch(), name="event_bus" ) + # 把插件中注册的所有协程函数注册到事件总线中 extra_tasks = [] for task in self.star_context._register_tasks: extra_tasks.append(asyncio.create_task(task, name=task.__name__)) @@ -100,17 +135,24 @@ class AstrBotCoreLifecycle: self.start_time = int(time.time()) async def _task_wrapper(self, task: asyncio.Task): + """异步任务包装器, 用于处理异步任务执行中出现的各种异常 + + Args: + task (asyncio.Task): 要执行的异步任务 + """ try: await task except asyncio.CancelledError: - pass + pass # 任务被取消, 静默处理 except Exception as e: + # 获取完整的异常堆栈信息, 按行分割并记录到日志中 logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}") for line in traceback.format_exc().split("\n"): logger.error(f"| {line}") logger.error("-------") async def start(self): + """启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子""" self._load() logger.info("AstrBot 启动完成。") @@ -127,16 +169,21 @@ class AstrBotCoreLifecycle: except BaseException: logger.error(traceback.format_exc()) + # 同时运行curr_tasks中的所有任务 await asyncio.gather(*self.curr_tasks, return_exceptions=True) async def stop(self): + """停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器""" + # 请求停止所有正在运行的异步任务 for task in self.curr_tasks: task.cancel() + # 终止各个管理器以及控制面板 await self.provider_manager.terminate() await self.platform_manager.terminate() self.dashboard_shutdown_event.set() + # 再次遍历curr_tasks等待每个任务真正结束 for task in self.curr_tasks: try: await task @@ -146,6 +193,7 @@ class AstrBotCoreLifecycle: logger.error(f"任务 {task.get_name()} 发生错误: {e}") async def restart(self): + """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例""" await self.provider_manager.terminate() await self.platform_manager.terminate() self.dashboard_shutdown_event.set() @@ -154,6 +202,7 @@ class AstrBotCoreLifecycle: ).start() def load_platform(self) -> List[asyncio.Task]: + """加载平台实例并返回所有平台实例的异步任务列表""" tasks = [] platform_insts = self.platform_manager.get_insts() for platform_inst in platform_insts: diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 2688bd400..91e6e46b0 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -1,3 +1,8 @@ +""" +事件总线, 用于处理事件的分发和处理 +事件总线是一个异步队列, 用于接收各种消息事件, 并将其分发到相应的处理器进行处理 +""" + import asyncio from asyncio import Queue from astrbot.core.pipeline.scheduler import PipelineScheduler @@ -6,21 +11,38 @@ from .platform import AstrMessageEvent class EventBus: + """事件总线: 用于处理事件的分发和处理 + + 维护一个异步队列, 来接受各种消息事件 + """ + def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler): - self.event_queue = event_queue - self.pipeline_scheduler = pipeline_scheduler + self.event_queue = event_queue # 事件队列 + self.pipeline_scheduler = pipeline_scheduler # 管道调度器 async def dispatch(self): + """无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑""" while True: - event: AstrMessageEvent = await self.event_queue.get() - self._print_event(event) - asyncio.create_task(self.pipeline_scheduler.execute(event)) + event: AstrMessageEvent = ( + await self.event_queue.get() + ) # 从事件队列中获取新的事件 + self._print_event(event) # 打印日志 + asyncio.create_task( + self.pipeline_scheduler.execute(event) + ) # 创建新的异步任务来执行管道调度器的处理逻辑 def _print_event(self, event: AstrMessageEvent): + """用于记录事件信息 + + Args: + event (AstrMessageEvent): 事件对象 + """ + # 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要 if event.get_sender_name(): logger.info( f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}" ) + # 没有发送者名称: [平台名] 发送者ID: 消息概要 else: logger.info( f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}" diff --git a/astrbot/core/log.py b/astrbot/core/log.py index b819fcc2b..01481f46e 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -1,3 +1,26 @@ +""" +日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能 + +const: + CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量 + log_color_config: 日志颜色配置, 定义了不同日志级别的颜色 + +class: + LogBroker: 日志代理类, 用于缓存和分发日志消息 + LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker + LogManager: 日志管理器, 用于创建和配置日志记录器 + +function: + is_plugin_path: 检查文件路径是否来自插件目录 + get_short_level_name: 将日志级别名称转换为四个字母的缩写 + +工作流程: +1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器 +2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker +3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者 +4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流 +""" + import logging import colorlog import asyncio @@ -6,7 +29,9 @@ from collections import deque from asyncio import Queue from typing import List +# 日志缓存大小 CACHED_SIZE = 200 +# 日志颜色配置 log_color_config = { "DEBUG": "green", "INFO": "bold_cyan", @@ -19,8 +44,13 @@ log_color_config = { def is_plugin_path(pathname): - """ - 检查文件路径是否来自插件目录 + """检查文件路径是否来自插件目录 + + Parameters: + pathname (str): 文件路径 + + Returns: + bool: 如果路径来自插件目录,则返回 True,否则返回 False """ if not pathname: return False @@ -30,8 +60,13 @@ def is_plugin_path(pathname): def get_short_level_name(level_name): - """ - 将日志级别名称转换为四个字母的缩写 + """将日志级别名称转换为四个字母的缩写 + + Parameters: + level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" + + Returns: + str: 四个字母的日志级别缩写 """ level_map = { "DEBUG": "DBUG", @@ -44,12 +79,21 @@ def get_short_level_name(level_name): class LogBroker: + """日志代理类, 用于缓存和分发日志消息 + + 发布-订阅模式 + """ + def __init__(self): - self.log_cache = deque(maxlen=CACHED_SIZE) - self.subscribers: List[Queue] = [] + self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志 + self.subscribers: List[Queue] = [] # 订阅者列表 def register(self) -> Queue: - """给每个订阅者返回一个带有日志缓存的队列""" + """注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列 + + Returns: + Queue: 订阅者的队列, 可用于接收日志消息 + """ q = Queue(maxsize=CACHED_SIZE + 10) for log in self.log_cache: q.put_nowait(log) @@ -57,11 +101,19 @@ class LogBroker: return q def unregister(self, q: Queue): - """取消订阅""" + """取消订阅 + + Parameters: + q (Queue): 需要取消订阅的队列 + """ self.subscribers.remove(q) def publish(self, log_entry: str): - """发布消息""" + """发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统 + + Parameters: + log_entry (str): 日志消息, 可以是字符串或字典 + """ self.log_cache.append(log_entry) for q in self.subscribers: try: @@ -71,24 +123,46 @@ class LogBroker: class LogQueueHandler(logging.Handler): + """日志处理器, 用于将日志消息发送到 LogBroker + + 继承自 logging.Handler + """ + def __init__(self, log_broker: LogBroker): super().__init__() self.log_broker = log_broker def emit(self, record): + """日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布 + 这个方法会在每次日志记录时被调用 + + Parameters: + record (logging.LogRecord): 日志记录对象, 包含日志信息 + """ log_entry = self.format(record) self.log_broker.publish(log_entry) class LogManager: + """日志管理器, 用于创建和配置日志记录器 + + 提供了获取默认日志记录器logger和设置队列处理器的方法 + """ + @classmethod def GetLogger(cls, log_name: str = "default"): + """获取指定名称的日志记录器logger""" logger = logging.getLogger(log_name) + # 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置 if logger.hasHandlers(): return logger - console_handler = logging.StreamHandler() - console_handler.setLevel(logging.DEBUG) + # 如果logger没有处理器 + console_handler = logging.StreamHandler() # 创建一个StreamHandler用于控制台输出 + console_handler.setLevel( + logging.DEBUG + ) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG + # 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息 console_formatter = colorlog.ColoredFormatter( fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s", datefmt="%H:%M:%S", @@ -96,6 +170,8 @@ class LogManager: ) class PluginFilter(logging.Filter): + """插件过滤器类, 用于标记日志来源是插件还是核心组件""" + def filter(self, record): record.plugin_tag = ( "[Plug]" if is_plugin_path(record.pathname) else "[Core]" @@ -103,6 +179,9 @@ class LogManager: return True class FileNameFilter(logging.Filter): + """文件名过滤器类, 用于修改日志记录的文件名格式 + 例如: 将文件路径 /path/to/file.py 转换为 file. 格式""" + # 获取这个文件和父文件夹的名字:. 并且去除 .py def filter(self, record): dirname = os.path.dirname(record.pathname) @@ -114,22 +193,25 @@ class LogManager: return True class LevelNameFilter(logging.Filter): + """短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写""" + # 添加短日志级别名称 def filter(self, record): record.short_levelname = get_short_level_name(record.levelname) return True - console_handler.setFormatter(console_formatter) - logger.addFilter(PluginFilter()) - logger.addFilter(FileNameFilter()) + console_handler.setFormatter(console_formatter) # 设置处理器的格式化器 + logger.addFilter(PluginFilter()) # 添加插件过滤器 + logger.addFilter(FileNameFilter()) # 添加文件名过滤器 logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器 - logger.setLevel(logging.DEBUG) - logger.addHandler(console_handler) + logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG + logger.addHandler(console_handler) # 添加处理器到logger return logger @classmethod def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker): + """设置队列处理器, 用于将日志消息发送到 LogBroker""" handler = LogQueueHandler(log_broker) handler.setLevel(logging.DEBUG) if logger.handlers: From 9e7fe773bd79e3d62424d99d199593b26719a77c Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Wed, 26 Mar 2025 11:14:46 +0800 Subject: [PATCH 2/8] =?UTF-8?q?perf:=20=E6=9B=B4=E6=96=B0=E9=83=A8?= =?UTF-8?q?=E5=88=86=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/__init__.py | 1 + astrbot/core/core_lifecycle.py | 13 ++++++++++++- astrbot/core/event_bus.py | 10 +++++++++- astrbot/core/initial_loader.py | 16 ++++++++++++++-- 4 files changed, 36 insertions(+), 4 deletions(-) diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 9749dee24..20ec6167f 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -8,6 +8,7 @@ from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.config.default import DB_PATH from astrbot.core.config import AstrBotConfig +# 初始化数据存储文件夹 os.makedirs("data", exist_ok=True) astrbot_config = AstrBotConfig() diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index fe257d078..7a0293116 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -1,3 +1,14 @@ +""" +Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。 +该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。 +该类还负责加载和执行插件, 以及处理事件总线的分发。 + +工作流程: +1. 初始化所有组件 +2. 启动事件总线和任务, 所有任务都在这里运行 +3. 执行启动完成事件钩子 +""" + import traceback import asyncio import time @@ -121,7 +132,7 @@ class AstrBotCoreLifecycle: self.event_bus.dispatch(), name="event_bus" ) - # 把插件中注册的所有协程函数注册到事件总线中 + # 把插件中注册的所有协程函数注册到事件总线中并执行 extra_tasks = [] for task in self.star_context._register_tasks: extra_tasks.append(asyncio.create_task(task, name=task.__name__)) diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 91e6e46b0..d4caa2910 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -1,6 +1,14 @@ """ 事件总线, 用于处理事件的分发和处理 -事件总线是一个异步队列, 用于接收各种消息事件, 并将其分发到相应的处理器进行处理 +事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理 +其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑 + +class: + EventBus: 事件总线, 用于处理事件的分发和处理 + +工作流程: +1. 维护一个异步队列, 来接受各种消息事件 +2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑 """ import asyncio diff --git a/astrbot/core/initial_loader.py b/astrbot/core/initial_loader.py index f91a71da3..bea1224f3 100644 --- a/astrbot/core/initial_loader.py +++ b/astrbot/core/initial_loader.py @@ -1,3 +1,11 @@ +""" +AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。 + +工作流程: +1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期 +2. 运行核心生命周期任务和仪表板服务器 +""" + import asyncio import traceback from astrbot.core import logger @@ -8,6 +16,8 @@ from astrbot.dashboard.server import AstrBotDashboard class InitialLoader: + """AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。""" + def __init__(self, db: BaseDatabase, log_broker: LogBroker): self.db = db self.logger = logger @@ -27,10 +37,12 @@ class InitialLoader: self.dashboard_server = AstrBotDashboard( core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event ) - task = asyncio.gather(core_task, self.dashboard_server.run()) + task = asyncio.gather( + core_task, self.dashboard_server.run() + ) # 启动核心任务和仪表板服务器 try: - await task + await task # 整个AstrBot在这里运行 except asyncio.CancelledError: logger.info("🌈 正在关闭 AstrBot...") await core_lifecycle.stop() From 9717a736b15d10a14dda9247a179a63b57223844 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Wed, 26 Mar 2025 13:50:54 +0800 Subject: [PATCH 3/8] =?UTF-8?q?perf:=20=E6=9B=B4=E6=96=B0=E9=83=A8?= =?UTF-8?q?=E5=88=86=E6=8F=8F=E8=BF=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/__init__.py | 4 +++- astrbot/core/conversation_mgr.py | 13 ++++++++++++- astrbot/core/pipeline/__init__.py | 1 + astrbot/core/pipeline/context.py | 6 ++++-- astrbot/core/pipeline/scheduler.py | 18 ++++++++++++------ astrbot/core/updator.py | 12 ++++++++++++ 6 files changed, 44 insertions(+), 10 deletions(-) diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 20ec6167f..13a630632 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -20,7 +20,9 @@ if os.environ.get("TESTING", ""): logger.setLevel("DEBUG") db_helper = SQLiteDatabase(DB_PATH) -sp = SharedPreferences() # 简单的偏好设置存储 +sp = ( + SharedPreferences() +) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", "")) web_chat_queue = asyncio.Queue(maxsize=32) web_chat_back_queue = asyncio.Queue(maxsize=32) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index ed29dddd9..0e11fd465 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -1,3 +1,10 @@ +""" +AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库 + +在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话, +在一个会话中可以建立多个对话, 并且支持对话的切换和删除 +""" + import uuid import json import asyncio @@ -11,21 +18,24 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" def __init__(self, db_helper: BaseDatabase): - # session_conversations 字典记录会话ID-用户ID 映射关系 + # session_conversations 字典记录会话ID-对话ID 映射关系 self.session_conversations: Dict[str, str] = sp.get("session_conversation", {}) self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 self._start_periodic_save() def _start_periodic_save(self): + """启动定时保存任务""" asyncio.create_task(self._periodic_save()) async def _periodic_save(self): + """定时保存会话对话映射关系到存储中""" while True: await asyncio.sleep(self.save_interval) self._save_to_storage() def _save_to_storage(self): + """保存会话对话映射关系到存储中""" sp.put("session_conversation", self.session_conversations) async def new_conversation(self, unified_msg_origin: str) -> str: @@ -97,6 +107,7 @@ class ConversationManager: async def get_human_readable_context( self, unified_msg_origin, conversation_id, page=1, page_size=10 ): + """获取人类可读的上下文""" conversation = await self.get_conversation(unified_msg_origin, conversation_id) history = json.loads(conversation.history) diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 76844f6fd..b97fc0f12 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -12,6 +12,7 @@ from .process_stage.stage import ProcessStage from .result_decorate.stage import ResultDecorateStage from .respond.stage import RespondStage +# 管道阶段顺序 STAGES_ORDER = [ "WakingCheckStage", # 检查是否需要唤醒 "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index 1abbca4e1..eb5ffb1cd 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -5,5 +5,7 @@ from astrbot.core.star import PluginManager @dataclass class PipelineContext: - astrbot_config: AstrBotConfig - plugin_manager: PluginManager + """上下文对象,包含管道执行所需的上下文信息""" + + astrbot_config: AstrBotConfig # AstrBot 配置对象 + plugin_manager: PluginManager # 插件管理器对象 diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 66874b80f..2ed3c0d2a 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -7,23 +7,29 @@ from astrbot.core import logger class PipelineScheduler: + """管道调度器,负责调度各个阶段的执行""" + def __init__(self, context: PipelineContext): - registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__.__name__)) - self.ctx = context + registered_stages.sort( + key=lambda x: STAGES_ORDER.index(x.__class__.__name__) + ) # 按照顺序排序 + self.ctx = context # 上下文对象 async def initialize(self): + """初始化管道调度器时, 初始化所有阶段""" for stage in registered_stages: # logger.debug(f"初始化阶段 {stage.__class__ .__name__}") await stage.initialize(self.ctx) async def _process_stages(self, event: AstrMessageEvent, from_stage=0): + """依次执行各个阶段""" for i in range(from_stage, len(registered_stages)): stage = registered_stages[i] # logger.debug(f"执行阶段 {stage.__class__ .__name__}") - coro = stage.process(event) - if isinstance(coro, AsyncGenerator): - async for _ in coro: + coroutine = stage.process(event) + if isinstance(coroutine, AsyncGenerator): + async for _ in coroutine: if event.is_stopped(): logger.debug( f"阶段 {stage.__class__.__name__} 已终止事件传播。" @@ -36,7 +42,7 @@ class PipelineScheduler: ) break else: - await coro + await coroutine if event.is_stopped(): logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 0d9860a60..1e7279a8c 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -9,6 +9,11 @@ from astrbot.core.utils.io import download_file class AstrBotUpdator(RepoZipUpdator): + """AstrBot 更新器,继承自 RepoZipUpdator 类 + 该类用于处理 AstrBot 的更新操作 + 功能包括检查更新、下载更新文件、解压缩更新文件等 + """ + def __init__(self, repo_mirror: str = "") -> None: super().__init__(repo_mirror) self.MAIN_PATH = os.path.abspath( @@ -17,6 +22,9 @@ class AstrBotUpdator(RepoZipUpdator): self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" def terminate_child_processes(self): + """终止当前进程的所有子进程 + 使用 psutil 库获取当前进程的所有子进程,并尝试终止它们 + """ try: parent = psutil.Process(os.getpid()) children = parent.children(recursive=True) @@ -35,6 +43,9 @@ class AstrBotUpdator(RepoZipUpdator): pass def _reboot(self, delay: int = 3): + """重启当前程序 + 在指定的延迟后,终止所有子进程并重新启动程序 + """ py = sys.executable time.sleep(delay) self.terminate_child_processes() @@ -46,6 +57,7 @@ class AstrBotUpdator(RepoZipUpdator): raise e async def check_update(self, url: str, current_version: str) -> ReleaseInfo: + """检查更新""" return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION) async def get_releases(self) -> list: From 6b6577006d21d27332d21304a0d4d08a531f98da Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Wed, 26 Mar 2025 17:59:30 +0800 Subject: [PATCH 4/8] =?UTF-8?q?perf:=20=E6=A0=BC=E5=BC=8F=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/db/po.py | 8 ++++++++ astrbot/core/pipeline/process_stage/method/llm_request.py | 6 +++--- astrbot/core/pipeline/result_decorate/stage.py | 4 +++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 59041d6dd..49adb2781 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -6,6 +6,8 @@ from typing import List @dataclass class Platform: + """平台使用统计数据""" + name: str count: int timestamp: int @@ -13,6 +15,8 @@ class Platform: @dataclass class Provider: + """供应商使用统计数据""" + name: str count: int timestamp: int @@ -20,6 +24,8 @@ class Provider: @dataclass class Plugin: + """插件使用统计数据""" + name: str count: int timestamp: int @@ -27,6 +33,8 @@ class Plugin: @dataclass class Command: + """命令使用统计数据""" + name: str count: int timestamp: int diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index 8d606d9e5..de2460bdc 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -58,9 +58,9 @@ class LLMRequestSubStage(Stage): if event.get_extra("provider_request"): req = event.get_extra("provider_request") - assert isinstance(req, ProviderRequest), ( - "provider_request 必须是 ProviderRequest 类型。" - ) + assert isinstance( + req, ProviderRequest + ), "provider_request 必须是 ProviderRequest 类型。" if req.conversation: req.contexts = json.loads(req.conversation.history) diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index d7bb9583c..4894b2e03 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -156,7 +156,9 @@ class ResultDecorateStage(Stage): self.ctx.astrbot_config["provider_tts_settings"]["enable"] and result.is_llm_result() ): - tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst + tts_provider = ( + self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst + ) new_chain = [] for comp in result.chain: if isinstance(comp, Plain) and len(comp.text) > 1: From 1746684e5280942c9f21da231374f3fc9cd9b328 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Wed, 26 Mar 2025 23:52:03 +0800 Subject: [PATCH 5/8] =?UTF-8?q?perf:=20=E4=BF=AE=E6=94=B9=E9=83=A8?= =?UTF-8?q?=E5=88=86=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/conversation_mgr.py | 78 ++++++++++++++++++--- astrbot/core/log.py | 26 +++++-- astrbot/core/pipeline/stage.py | 3 +- astrbot/core/pipeline/waking_check/stage.py | 5 ++ 4 files changed, 93 insertions(+), 19 deletions(-) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index 0e11fd465..c506fa8f1 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -39,7 +39,13 @@ class ConversationManager: sp.put("session_conversation", self.session_conversations) async def new_conversation(self, unified_msg_origin: str) -> str: - """新建对话,并将当前会话的对话转移到新对话""" + """新建对话,并将当前会话的对话转移到新对话 + + Args: + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + Returns: + conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + """ conversation_id = str(uuid.uuid4()) self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id) self.session_conversations[unified_msg_origin] = conversation_id @@ -47,14 +53,24 @@ class ConversationManager: return conversation_id async def switch_conversation(self, unified_msg_origin: str, conversation_id: str): - """切换会话的对话""" + """切换会话的对话 + + Args: + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + """ self.session_conversations[unified_msg_origin] = conversation_id sp.put("session_conversation", self.session_conversations) async def delete_conversation( self, unified_msg_origin: str, conversation_id: str = None ): - """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话""" + """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 + + Args: + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + """ conversation_id = self.session_conversations.get(unified_msg_origin) if conversation_id: self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id) @@ -62,23 +78,48 @@ class ConversationManager: sp.put("session_conversation", self.session_conversations) async def get_curr_conversation_id(self, unified_msg_origin: str) -> str: - """获取会话当前的对话 ID""" + """获取会话当前的对话 ID + + Args: + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + Returns: + conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + """ return self.session_conversations.get(unified_msg_origin, None) async def get_conversation( self, unified_msg_origin: str, conversation_id: str ) -> Conversation: - """获取会话的对话""" + """获取会话的对话 + + Args: + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + Returns: + conversation (Conversation): 对话对象 + """ return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id) async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]: - """获取会话的所有对话""" + """获取会话的所有对话 + + Args: + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + Returns: + conversations (List[Conversation]): 对话对象列表 + """ return self.db.get_conversations(unified_msg_origin) async def update_conversation( self, unified_msg_origin: str, conversation_id: str, history: List[Dict] ): - """更新会话的对话""" + """更新会话的对话 + + Args: + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段 + """ if conversation_id: self.db.update_conversation( user_id=unified_msg_origin, @@ -87,7 +128,12 @@ class ConversationManager: ) async def update_conversation_title(self, unified_msg_origin: str, title: str): - """更新会话的对话标题""" + """更新会话的对话标题 + + Args: + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + title (str): 对话标题 + """ conversation_id = self.session_conversations.get(unified_msg_origin) if conversation_id: self.db.update_conversation_title( @@ -97,7 +143,12 @@ class ConversationManager: async def update_conversation_persona_id( self, unified_msg_origin: str, persona_id: str ): - """更新会话的对话 Persona ID""" + """更新会话的对话 Persona ID + + Args: + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + persona_id (str): 对话 Persona ID + """ conversation_id = self.session_conversations.get(unified_msg_origin) if conversation_id: self.db.update_conversation_persona_id( @@ -107,7 +158,14 @@ class ConversationManager: async def get_human_readable_context( self, unified_msg_origin, conversation_id, page=1, page_size=10 ): - """获取人类可读的上下文""" + """获取人类可读的上下文 + + Args: + unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id + conversation_id (str): 对话 ID, 是 uuid 格式的字符串 + page (int): 页码 + page_size (int): 每页大小 + """ conversation = await self.get_conversation(unified_msg_origin, conversation_id) history = json.loads(conversation.history) diff --git a/astrbot/core/log.py b/astrbot/core/log.py index 01481f46e..501f0012e 100644 --- a/astrbot/core/log.py +++ b/astrbot/core/log.py @@ -46,7 +46,7 @@ log_color_config = { def is_plugin_path(pathname): """检查文件路径是否来自插件目录 - Parameters: + Args: pathname (str): 文件路径 Returns: @@ -62,7 +62,7 @@ def is_plugin_path(pathname): def get_short_level_name(level_name): """将日志级别名称转换为四个字母的缩写 - Parameters: + Args: level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL" Returns: @@ -103,7 +103,7 @@ class LogBroker: def unregister(self, q: Queue): """取消订阅 - Parameters: + Args: q (Queue): 需要取消订阅的队列 """ self.subscribers.remove(q) @@ -111,7 +111,7 @@ class LogBroker: def publish(self, log_entry: str): """发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统 - Parameters: + Args: log_entry (str): 日志消息, 可以是字符串或字典 """ self.log_cache.append(log_entry) @@ -136,7 +136,7 @@ class LogQueueHandler(logging.Handler): """日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布 这个方法会在每次日志记录时被调用 - Parameters: + Args: record (logging.LogRecord): 日志记录对象, 包含日志信息 """ log_entry = self.format(record) @@ -151,7 +151,14 @@ class LogManager: @classmethod def GetLogger(cls, log_name: str = "default"): - """获取指定名称的日志记录器logger""" + """获取指定名称的日志记录器logger + + Args: + log_name (str): 日志记录器的名称, 默认为 "default" + + Returns: + logging.Logger: 返回配置好的日志记录器 + """ logger = logging.getLogger(log_name) # 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置 if logger.hasHandlers(): @@ -211,7 +218,12 @@ class LogManager: @classmethod def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker): - """设置队列处理器, 用于将日志消息发送到 LogBroker""" + """设置队列处理器, 用于将日志消息发送到 LogBroker + + Args: + logger (logging.Logger): 日志记录器 + log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息 + """ handler = LogQueueHandler(log_broker) handler.setLevel(logging.DEBUG) if logger.handlers: diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index 7291dfc3e..ea87b29ee 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -8,8 +8,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent from .context import PipelineContext from astrbot.core.message.message_event_result import MessageEventResult, CommandResult -registered_stages: List[Stage] = [] -"""维护了所有已注册的 Stage 实现类""" +registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类 def register_stage(cls): diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index 9b2b20155..dfe19dc85 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -21,6 +21,11 @@ class WakingCheckStage(Stage): """ async def initialize(self, ctx: PipelineContext) -> None: + """初始化唤醒检查阶段 + + Args: + ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器 + """ self.ctx = ctx self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get( "no_permission_reply", True From 53b9497c18ccea208625fd2585fcc34b04658035 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Thu, 27 Mar 2025 21:32:38 +0800 Subject: [PATCH 6/8] =?UTF-8?q?perf:=20=E5=A2=9E=E5=8A=A0=E9=83=A8?= =?UTF-8?q?=E5=88=86=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/pipeline/scheduler.py | 20 ++++++++++++++------ main.py | 4 ++-- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 2ed3c0d2a..c5339ac4b 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -25,37 +25,45 @@ class PipelineScheduler: async def _process_stages(self, event: AstrMessageEvent, from_stage=0): """依次执行各个阶段""" for i in range(from_stage, len(registered_stages)): - stage = registered_stages[i] + stage = registered_stages[i] # 获取当前要执行的阶段 # logger.debug(f"执行阶段 {stage.__class__ .__name__}") - coroutine = stage.process(event) + coroutine = stage.process( + event + ) # 调用阶段的process方法, 返回协程或者异步生成器 + if isinstance(coroutine, AsyncGenerator): + # 如果返回的是异步生成器, 实现洋葱模型的核心 async for _ in coroutine: + # 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段 if event.is_stopped(): logger.debug( f"阶段 {stage.__class__.__name__} 已终止事件传播。" ) break + + # 递归调用, 处理所有后续阶段 await self._process_stages(event, i + 1) + + # 此处是后续所有阶段处理完毕后返回的点, 执行后置处理 if event.is_stopped(): logger.debug( f"阶段 {stage.__class__.__name__} 已终止事件传播。" ) break else: + # 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件) + # 简单地等待它执行完成, 然后继续执行下一个阶段 await coroutine if event.is_stopped(): logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") break - if event.is_stopped(): - logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") - break - async def execute(self, event: AstrMessageEvent): """执行 pipeline""" await self._process_stages(event) + # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 if not event._has_send_oper and event.get_platform_name() == "webchat": await event.send(None) diff --git a/main.py b/main.py index 9937bb10f..c0ca48274 100644 --- a/main.py +++ b/main.py @@ -79,5 +79,5 @@ if __name__ == "__main__": # print logo logger.info(logo_tmpl) - dashboard_lifecycle = InitialLoader(db, log_broker) - asyncio.run(dashboard_lifecycle.start()) + core_lifecycle = InitialLoader(db, log_broker) + asyncio.run(core_lifecycle.start()) From 5484b421ced69c0a374d4df5aeac3d8247c0b884 Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Mon, 31 Mar 2025 22:30:43 +0800 Subject: [PATCH 7/8] =?UTF-8?q?perf:=20=E5=A2=9E=E5=8A=A0=E9=83=A8?= =?UTF-8?q?=E5=88=86=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/message/message_event_result.py | 1 + astrbot/core/pipeline/scheduler.py | 13 ++++- astrbot/core/pipeline/stage.py | 55 +++++++++++++++----- 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 4cc7fb842..83c03b7fc 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -152,4 +152,5 @@ class MessageEventResult(MessageChain): return self.result_content_type == ResultContentType.LLM_RESULT +# 为了兼容旧版代码,保留 CommandResult 的别名 CommandResult = MessageEventResult diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index c5339ac4b..d29c7ec80 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -23,7 +23,12 @@ class PipelineScheduler: await stage.initialize(self.ctx) async def _process_stages(self, event: AstrMessageEvent, from_stage=0): - """依次执行各个阶段""" + """依次执行各个阶段 + + Args: + event (AstrMessageEvent): 事件对象 + from_stage (int): 从第几个阶段开始执行, 默认从0开始 + """ for i in range(from_stage, len(registered_stages)): stage = registered_stages[i] # 获取当前要执行的阶段 # logger.debug(f"执行阶段 {stage.__class__ .__name__}") @@ -60,7 +65,11 @@ class PipelineScheduler: break async def execute(self, event: AstrMessageEvent): - """执行 pipeline""" + """执行 pipeline + + Args: + event (AstrMessageEvent): 事件对象 + """ await self._process_stages(event) # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index ea87b29ee..c7d4ff792 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -22,14 +22,24 @@ class Stage(abc.ABC): @abc.abstractmethod async def initialize(self, ctx: PipelineContext) -> None: - """初始化阶段""" + """初始化阶段 + + Args: + ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器 + """ raise NotImplementedError @abc.abstractmethod async def process( self, event: AstrMessageEvent ) -> Union[None, AsyncGenerator[None, None]]: - """处理事件""" + """处理事件 + + Args: + event (AstrMessageEvent): 事件对象,包含事件的相关信息 + Returns: + Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段) + """ raise NotImplementedError async def _call_handler( @@ -40,9 +50,23 @@ class Stage(abc.ABC): *args, **kwargs, ) -> AsyncGenerator[None, None]: - """调用 Handler。""" - # 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性) - ready_to_call = None + """执行事件处理函数并处理其返回结果 + + 该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数: + 1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层 + 2. 协程: 执行一次并处理返回值 + + Args: + ctx (PipelineContext): 消息管道上下文对象 + event (AstrMessageEvent): 待处理的事件对象 + handler (Awaitable): 事件处理函数 + *args: 传递给handler的位置参数 + **kwargs: 传递给handler的关键字参数 + + Returns: + AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流 + """ + ready_to_call = None # 一个协程或者异步生成器(async def) trace_ = None @@ -51,29 +75,36 @@ class Stage(abc.ABC): except TypeError as _: # 向下兼容 trace_ = traceback.format_exc() + # 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份 ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs) if isinstance(ready_to_call, AsyncGenerator): - _has_yielded = False + # 如果是一个异步生成器, 进入洋葱模型 + _has_yielded = False # 是否返回过值 try: async for ret in ready_to_call: - # 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值) + # 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码 + # 返回值只能是 MessageEventResult 或者 None(无返回值) _has_yielded = True if isinstance(ret, (MessageEventResult, CommandResult)): + # 如果返回值是 MessageEventResult, 设置结果并继续 event.set_result(ret) - yield + yield # 传递控制权给上一层的process函数 else: - yield ret + # 如果返回值是 None, 则不设置结果并继续 + # 继续执行后续阶段 + yield ret # 传递控制权给上一层的process函数 if not _has_yielded: + # 如果这个异步生成器没有执行到yield分支 yield except Exception as e: logger.error(f"Previous Error: {trace_}") raise e elif inspect.iscoroutine(ready_to_call): - # 如果只是一个 coroutine + # 如果只是一个协程, 直接执行 ret = await ready_to_call if isinstance(ret, (MessageEventResult, CommandResult)): event.set_result(ret) - yield + yield # 传递控制权给上一层的process函数 else: - yield ret + yield ret # 传递控制权给上一层的process函数 From 6cef9c23f0a84c09b0170020b2077921871b910c Mon Sep 17 00:00:00 2001 From: anka <1350989414@qq.com> Date: Mon, 31 Mar 2025 22:41:23 +0800 Subject: [PATCH 8/8] =?UTF-8?q?bug=20fix:=20#1074=20=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=E6=9C=80=E5=A4=9A=E6=90=BA=E5=B8=A6=E5=AF=B9=E8=AF=9D=E6=95=B0?= =?UTF-8?q?=E9=87=8F=E6=97=B6=E5=87=BA=E7=8E=B0bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/dashboard/routes/config.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index d3079f0d2..629a424f1 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -12,8 +12,11 @@ from astrbot.core import logger def try_cast(value: str, type_: str): - if type_ == "int" and value.isdigit(): - return int(value) + if type_ == "int": + try: + return int(value) + except (ValueError, TypeError): + return None elif ( type_ == "float" and isinstance(value, str) @@ -22,6 +25,11 @@ def try_cast(value: str, type_: str): return float(value) elif type_ == "float" and isinstance(value, int): return float(value) + elif type_ == "float": + try: + return float(value) + except (ValueError, TypeError): + return None def validate_config( @@ -34,13 +42,21 @@ def validate_config( if key not in metadata: # 无 schema 的配置项,执行类型猜测 if isinstance(value, str): - if value.isdigit(): + try: data[key] = int(value) - elif value.replace(".", "", 1).isdigit(): + continue + except ValueError: + pass + + try: data[key] = float(value) - elif value == "true": + continue + except ValueError: + pass + + if value.lower() == "true": data[key] = True - elif value == "false": + elif value.lower() == "false": data[key] = False continue meta = metadata[key]