From 3a6749268001446baa6ecc014a2547fd79f4cbf5 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Sat, 14 Dec 2024 20:11:28 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E6=8F=92=E4=BB=B6=E6=8A=A5=E9=94=99?= =?UTF-8?q?=E7=9B=B4=E6=8E=A5=E7=BB=88=E6=AD=A2=E4=BA=8B=E4=BB=B6=EF=BC=9B?= =?UTF-8?q?=E6=94=AF=E6=8C=81=E7=94=9F=E6=88=90=E5=99=A8=E5=8F=91=E9=80=81?= =?UTF-8?q?=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../process_stage/method/star_request.py | 52 +++++++++++++++---- astrbot/core/provider/provider.py | 2 + .../core/provider/sources/llmtuner_source.py | 12 +++-- 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 2394ff7ac..70d9285d9 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -6,6 +6,8 @@ from astrbot.core.message.message_event_result import MessageEventResult, Comman from astrbot.core import logger from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.star.star import star_map +import traceback +import inspect class StarRequestSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: @@ -28,23 +30,53 @@ class StarRequestSubStage(Stage): star_cls_obj = star_map.get(handler.handler_module_str).star_cls # 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性) + ready_to_call = None if hasattr(handler.handler, '__self__'): # 猜测没有通过装饰器去注册 try: - ret = await handler.handler(event, **params) + ready_to_call = handler.handler(event, **params) except TypeError: # 向下兼容 - ret = await handler.handler(event, self.ctx.plugin_manager.context, **params) + ready_to_call = handler.handler(event, self.ctx.plugin_manager.context, **params) else: logger.debug("calling star handler: %s" % handler.handler_full_name) - ret = await handler.handler(star_cls_obj, event, **params) + ready_to_call = handler.handler(star_cls_obj, event, **params) logger.debug("star handler %s called" % handler.handler_full_name) - if ret: - assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" - event.stop_event() - event.set_result(ret) - # 执行后续步骤来发送消息 - yield + + if isinstance(ready_to_call, AsyncGenerator): + async for mer in ready_to_call: + # 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值) + if mer: + assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" + event.set_result(mer) + yield + else: + if event.get_result(): + yield + elif inspect.iscoroutine(ready_to_call): + # 如果只是一个 coroutine + ret = await ready_to_call + if ret: + # 如果有返回值 + assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" + event.stop_event() + event.set_result(ret) + # 执行后续步骤来发送消息 + if event.is_stopped() and event.get_result(): + # 插件主动停止事件传播,并且有结果 + event.continue_event() + yield + event.stop_event() + yield + elif not event.is_stopped and not event.get_result(): + continue + else: + yield event.clear_result() # 清除上一个 handler 的结果 except Exception as e: - logger.error(f"Star {handler.handler_full_name} handle error: {e}") \ No newline at end of file + logger.error(traceback.format_exc()) + logger.error(f"Star {handler.handler_full_name} handle error: {e}") + ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}" + event.set_result(MessageEventResult().message(ret)) + yield + event.stop_event() \ No newline at end of file diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 1cc267679..fffc398a9 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -79,6 +79,8 @@ class Provider(abc.ABC): @abc.abstractmethod async def get_human_readable_context(self, session_id: str, page: int, page_size: int): '''获取人类可读的上下文 + + page 从 1 开始 Example: diff --git a/astrbot/core/provider/sources/llmtuner_source.py b/astrbot/core/provider/sources/llmtuner_source.py index 751b2ce13..2f1f278c8 100644 --- a/astrbot/core/provider/sources/llmtuner_source.py +++ b/astrbot/core/provider/sources/llmtuner_source.py @@ -68,7 +68,7 @@ class LLMTunerModelLoader(Provider): responses = await self.model.achat(**conf) logger.debug(f"返回上下文:{responses}") - self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id])) + self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type) self.session_memory[session_id].append({"role": "user", "content": prompt}) self.session_memory[session_id].append({"role": "assistant", "content": responses[-1].response_text}) return responses[-1].response_text @@ -92,11 +92,17 @@ class LLMTunerModelLoader(Provider): if session_id not in self.session_memory: raise Exception("会话 ID 不存在") contexts = [] + temp_contexts = [] for record in self.session_memory[session_id]: if record['role'] == "user": - contexts.append(f"User: {record['content']}") + temp_contexts.append(f"User: {record['content']}") elif record['role'] == "assistant": - contexts.append(f"Assistant: {record['content']}") + temp_contexts.append(f"Assistant: {record['content']}") + contexts.insert(0, temp_contexts) + temp_contexts = [] + + # 展平 contexts 列表 + contexts = [item for sublist in contexts for item in sublist] # 计算分页 paged_contexts = contexts[(page-1)*page_size:page*page_size]