From c67501737404e77a25a06cf9a31722d692729744 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 19 Dec 2024 21:33:03 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9ELLM=E8=AF=B7=E6=B1=82?= =?UTF-8?q?=E4=BA=8B=E4=BB=B6=E9=92=A9=E5=AD=90=E5=92=8C=E8=A3=85=E9=A5=B0?= =?UTF-8?q?=E6=B6=88=E6=81=AF=E7=BB=93=E6=9E=9C=E9=92=A9=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/api/__init__.py | 2 +- astrbot/api/all.py | 2 +- astrbot/api/event/filter/__init__.py | 8 +- astrbot/api/provider/__init__.py | 1 + astrbot/core/message/message_event_result.py | 20 ++- .../process_stage/method/llm_request.py | 118 +++++++----------- .../process_stage/method/star_request.py | 49 +------- astrbot/core/pipeline/process_stage/stage.py | 12 +- astrbot/core/pipeline/respond/stage.py | 6 +- .../core/pipeline/result_decorate/stage.py | 6 + astrbot/core/pipeline/stage.py | 49 +++++++- astrbot/core/pipeline/waking_check/stage.py | 4 +- astrbot/core/platform/astr_message_event.py | 10 +- astrbot/core/provider/manager.py | 9 +- astrbot/core/provider/provider.py | 1 + astrbot/core/provider/provider_request.py | 18 +++ astrbot/core/provider/register.py | 45 +------ .../core/provider/sources/openai_source.py | 9 +- astrbot/core/provider/tool.py | 17 ++- astrbot/core/star/context.py | 13 +- astrbot/core/star/filter/permission.py | 4 +- astrbot/core/star/register/__init__.py | 10 +- astrbot/core/star/register/star_handler.py | 114 +++++++++++++---- astrbot/core/star/star_handler.py | 44 ++++++- astrbot/core/star/star_manager.py | 8 ++ 25 files changed, 353 insertions(+), 226 deletions(-) create mode 100644 astrbot/core/provider/provider_request.py diff --git a/astrbot/api/__init__.py b/astrbot/api/__init__.py index d484e7c9e..e75cde6ff 100644 --- a/astrbot/api/__init__.py +++ b/astrbot/api/__init__.py @@ -2,7 +2,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot import logger from astrbot.core.utils.personality import personalities from astrbot.core import html_renderer -from astrbot.core.provider.register import register_llm_tool as llm_tool +from astrbot.core.star.register import register_llm_tool as llm_tool __all__ = [ "AstrBotConfig", diff --git a/astrbot/api/all.py b/astrbot/api/all.py index 226e2b30f..6248fd6eb 100644 --- a/astrbot/api/all.py +++ b/astrbot/api/all.py @@ -3,7 +3,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot import logger from astrbot.core.utils.personality import personalities from astrbot.core import html_renderer -from astrbot.core.provider.register import register_llm_tool as llm_tool +from astrbot.core.star.register import register_llm_tool as llm_tool # event from astrbot.core.message.message_event_result import ( diff --git a/astrbot/api/event/filter/__init__.py b/astrbot/api/event/filter/__init__.py index 6607ad372..f7c58c2e6 100644 --- a/astrbot/api/event/filter/__init__.py +++ b/astrbot/api/event/filter/__init__.py @@ -4,7 +4,10 @@ from astrbot.core.star.register import ( register_event_message_type as event_message_type, register_regex as regex, register_platform_adapter_type as platform_adapter_type, - register_permission_type as permission_type + register_permission_type as permission_type, + register_on_llm_request as on_llm_request, + register_llm_tool as llm_tool, + register_on_decorating_result as on_decorating_result ) from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType @@ -24,4 +27,7 @@ __all__ = [ 'PlatformAdapterType', 'PermissionTypeFilter', 'PermissionType', + 'on_llm_request', + 'llm_tool', + 'on_decorating_result' ] \ No newline at end of file diff --git a/astrbot/api/provider/__init__.py b/astrbot/api/provider/__init__.py index 52c5c59d4..453ffb876 100644 --- a/astrbot/api/provider/__init__.py +++ b/astrbot/api/provider/__init__.py @@ -1 +1,2 @@ from astrbot.core.provider import Provider, Personality, ProviderMetaData +from astrbot.core.provider.provider_request import ProviderRequest \ No newline at end of file diff --git a/astrbot/core/message/message_event_result.py b/astrbot/core/message/message_event_result.py index 148287580..d214544f2 100644 --- a/astrbot/core/message/message_event_result.py +++ b/astrbot/core/message/message_event_result.py @@ -97,7 +97,14 @@ class EventResultType(enum.Enum): ''' CONTINUE = enum.auto() STOP = enum.auto() - + +class ResultContentType(enum.Enum): + '''用于描述事件结果的内容的类型。 + ''' + LLM_RESULT = enum.auto() + '''调用 LLM 产生的结果''' + GENERAL_RESULT = enum.auto() + '''普通的消息结果''' @dataclass class MessageEventResult(MessageChain): '''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。 @@ -112,6 +119,8 @@ class MessageEventResult(MessageChain): result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE) + result_content_type: Optional[ResultContentType] = field(default_factory=lambda: ResultContentType.GENERAL_RESULT) + def stop_event(self) -> 'MessageEventResult': '''终止事件传播。 ''' @@ -130,5 +139,14 @@ class MessageEventResult(MessageChain): ''' return self.result_type == EventResultType.STOP + def set_result_content_type(self, result_type: EventResultType) -> 'MessageEventResult': + '''设置事件处理的结果类型。 + + Args: + result_type (EventResultType): 事件处理的结果类型。 + ''' + self.result_type = result_type + return self + CommandResult = MessageEventResult \ No newline at end of file diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index cfa0acc58..a0a2e1a26 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -1,113 +1,87 @@ import traceback -import inspect +import datetime from typing import Union, AsyncGenerator from ...context import PipelineContext from ..stage import Stage from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageEventResult, CommandResult +from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType from astrbot.core.message.components import Image from astrbot.core import logger from astrbot.core.utils.metrics import Metric -from astrbot.core.star.star import star_map - +from astrbot.core.provider.provider_request import ProviderRequest +from astrbot.core.star.star_handler import star_handlers_registry, EventType class LLMRequestSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix'] self.identifier = ctx.astrbot_config['provider_settings']['identifier'] + self.enable_datetime = ctx.astrbot_config['provider_settings']["datetime_system_prompt"] self.ctx = ctx async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]: - # Chat 唤醒前缀 - if self.ctx.astrbot_config['provider_settings']['wake_prefix']: - if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']): - return - event.message_str = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):] + req: ProviderRequest = None + + if event.get_extra("provider_request"): + print("provider_request") + req = event.get_extra("provider_request") + assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。" + else: + req = ProviderRequest(prompt="", image_urls=[]) + if self.ctx.astrbot_config['provider_settings']['wake_prefix']: + if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']): + return + req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):] + req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager() + for comp in event.message_obj.message: + if isinstance(comp, Image): + image_url = comp.url if comp.url else comp.file + req.image_urls.append(image_url) + req.session_id = event.session_id + + provider = self.ctx.plugin_manager.context.get_using_provider() if self.prompt_prefix: - event.message_str = self.prompt_prefix + event.message_str + req.prompt = self.prompt_prefix + req.prompt if self.identifier: user_id = event.message_obj.sender.user_id user_nickname = event.message_obj.sender.nickname user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n" - event.message_str = user_info + event.message_str + req.prompt = user_info + req.prompt + if self.enable_datetime: + req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}" + if provider.curr_personality['prompt']: + req.system_prompt += f"\n{provider.curr_personality['prompt']}" - image_urls = [] - for comp in event.message_obj.message: - if isinstance(comp, Image): - image_url = comp.url if comp.url else comp.file - image_urls.append(image_url) - - tools = self.ctx.plugin_manager.context.get_llm_tool_manager() - - provider = self.ctx.plugin_manager.context.get_using_provider() + # 执行请求 LLM 前事件。 + # 装饰 system_prompt 等功能 + handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent) + for handler in handlers: + try: + await handler.handler(event, req) + except BaseException: + logger.error(traceback.format_exc()) try: - llm_response = await provider.text_chat( - prompt=event.message_str, - session_id=event.session_id, - image_urls=image_urls, - func_tool=tools - ) + llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type) if llm_response.role == 'assistant': # text completion - event.set_result(MessageEventResult().message(llm_response.completion_text)) + event.set_result(MessageEventResult().message(llm_response.completion_text) + .set_result_content_type(ResultContentType.LLM_RESULT)) elif llm_response.role == 'tool': # function calling for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args): - func_tool = tools.get_func(func_tool_name) + func_tool = req.func_tool.get_func(func_tool_name) logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}") try: # 尝试调用工具函数 - - star_cls_obj = star_map.get(func_tool.module_name).star_cls - # 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性) - ready_to_call = None - if hasattr(func_tool.func_obj, '__self__'): - # 猜测没有通过装饰器去注册 - try: - ready_to_call = func_tool.func_obj(event, **func_tool_args) - except TypeError: - # 向下兼容 - ready_to_call = func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args) - else: - ready_to_call = func_tool.func_obj(star_cls_obj, event, **func_tool_args) - if isinstance(ready_to_call, AsyncGenerator): - async for mer in ready_to_call: - # 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值) - if mer: - assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" - event.set_result(mer) - yield - else: - if event.get_result(): - yield - elif inspect.iscoroutine(ready_to_call): - # 如果只是一个 coroutine - ret = await ready_to_call - if ret: - # 如果有返回值 - assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" - event.set_result(ret) - # 执行后续步骤来发送消息 - if event.is_stopped() and event.get_result(): - # 主动停止事件传播,并且有结果 - event.continue_event() - yield - event.clear_result() - event.stop_event() - yield - elif not event.is_stopped and not event.get_result(): - continue - else: - yield + wrapper = self._call_handler(self.ctx, event, func_tool.star_handler_metadata.handler, **func_tool_args) + async for _ in wrapper: + yield event.clear_result() # 清除上一个 handler 的结果 - except BaseException: logger.error(traceback.format_exc()) - except BaseException as e: logger.error(traceback.format_exc()) event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e))) diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index b10ad8334..177c7470e 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -2,12 +2,12 @@ from ...context import PipelineContext from ..stage import Stage from typing import Dict, Any, List, AsyncGenerator, Union from astrbot.core.platform.astr_message_event import AstrMessageEvent -from astrbot.core.message.message_event_result import MessageEventResult, CommandResult +from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core import logger from astrbot.core.star.star_handler import StarHandlerMetadata from astrbot.core.star.star import star_map import traceback -import inspect + class StarRequestSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: @@ -27,50 +27,11 @@ class StarRequestSubStage(Stage): if handler.handler_module_str not in star_map: # 孤立无援的 star handler continue - star_cls_obj = star_map.get(handler.handler_module_str).star_cls logger.debug(f"执行 Star Handler {handler.handler_full_name}") - # 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性) - ready_to_call = None - if hasattr(handler.handler, '__self__'): - # 猜测没有通过装饰器去注册 - try: - ready_to_call = handler.handler(event, **params) - except TypeError: - # 向下兼容 - ready_to_call = handler.handler(event, self.ctx.plugin_manager.context, **params) - else: - ready_to_call = handler.handler(star_cls_obj, event, **params) - - if isinstance(ready_to_call, AsyncGenerator): - async for mer in ready_to_call: - # 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值) - if mer: - assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" - event.set_result(mer) - yield - else: - if event.get_result(): - yield - elif inspect.iscoroutine(ready_to_call): - # 如果只是一个 coroutine - ret = await ready_to_call - if ret: - # 如果有返回值 - assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" - event.set_result(ret) - # 执行后续步骤来发送消息 - if event.is_stopped() and event.get_result(): - # 插件主动停止事件传播,并且有结果 - event.continue_event() - yield - event.clear_result() - event.stop_event() - yield - elif not event.is_stopped and not event.get_result(): - continue - else: - yield + wrapper = self._call_handler(self.ctx, event, handler.handler, **params) + async for _ in wrapper: + yield event.clear_result() # 清除上一个 handler 的结果 except Exception as e: logger.error(traceback.format_exc()) diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index faf3bba45..275d45b99 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -5,6 +5,7 @@ from .method.llm_request import LLMRequestSubStage from .method.star_request import StarRequestSubStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.star.star_handler import StarHandlerMetadata +from astrbot.core.provider.provider_request import ProviderRequest @register_stage class ProcessStage(Stage): @@ -25,8 +26,15 @@ class ProcessStage(Stage): activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers") if activated_handlers: - async for _ in self.star_request_sub_stage.process(event): - yield + async for resp in self.star_request_sub_stage.process(event): + # 生成器返回值处理 + if isinstance(resp, ProviderRequest): + # Handler 的 LLM 请求 + event.set_extra("provider_request", resp) + async for _ in self.llm_request_sub_stage.process(event): + yield + else: + yield if self.ctx.astrbot_config['provider_settings'].get('enable', True): if not event._has_send_oper: diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index b2328fead..6c6af9042 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -1,11 +1,11 @@ from typing import Union, AsyncGenerator -from ..stage import register_stage +from ..stage import register_stage, Stage from ..context import PipelineContext from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core import logger @register_stage -class RespondStage: +class RespondStage(Stage): async def initialize(self, ctx: PipelineContext): self.ctx = ctx @@ -13,7 +13,7 @@ class RespondStage: result = event.get_result() if result is None: return - + if len(result.chain) > 0: await event.send(result) logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}") diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index 725bff90d..faafe6dbd 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -6,6 +6,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core import logger from astrbot.core.message.components import Plain, Image from astrbot.core import html_renderer +from astrbot.core.star.star_handler import star_handlers_registry, EventType @register_stage class ResultDecorateStage: @@ -19,6 +20,11 @@ class ResultDecorateStage: if result is None: return + handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent) + for handler in handlers: + # TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。 + await handler.handler(event) + if len(result.chain) > 0: # 回复前缀 if self.reply_prefix: diff --git a/astrbot/core/pipeline/stage.py b/astrbot/core/pipeline/stage.py index 0fd64e696..9bb9339d5 100644 --- a/astrbot/core/pipeline/stage.py +++ b/astrbot/core/pipeline/stage.py @@ -1,8 +1,10 @@ from __future__ import annotations import abc -from typing import List, AsyncGenerator, Union +import inspect +from typing import List, AsyncGenerator, Union, Awaitable from astrbot.core.platform.astr_message_event import AstrMessageEvent from .context import PipelineContext +from astrbot.core.message.message_event_result import MessageEventResult, CommandResult registered_stages: List[Stage] = [] '''维护了所有已注册的 Stage 实现类''' @@ -29,4 +31,47 @@ class Stage(abc.ABC): ''' raise NotImplementedError - \ No newline at end of file + async def _call_handler( + self, + ctx: PipelineContext, + event: AstrMessageEvent, + handler: Awaitable, + **params + ) -> AsyncGenerator[None, None]: + '''调用 Handler。''' + # 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性) + ready_to_call = None + try: + ready_to_call = handler(event, **params) + except TypeError as e: + print(e) + # 向下兼容 + ready_to_call = handler(event, ctx.plugin_manager.context, **params) + + if isinstance(ready_to_call, AsyncGenerator): + async for mer in ready_to_call: + # 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值) + if mer: + assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" + event.set_result(mer) + yield + else: + if event.get_result(): + yield + elif inspect.iscoroutine(ready_to_call): + # 如果只是一个 coroutine + ret = await ready_to_call + if ret: + # 如果有返回值 + assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" + event.set_result(ret) + # 执行后续步骤来发送消息 + if event.is_stopped() and event.get_result(): + # 插件主动停止事件传播,并且有结果 + event.continue_event() + yield + event.clear_result() + event.stop_event() + yield + else: + yield \ No newline at end of file diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index f27700318..b1aa6059d 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -4,7 +4,7 @@ from typing import Union, AsyncGenerator from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core.message.components import At -from astrbot.core.star.star_handler import star_handlers_registry +from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.filter.command_group import CommandGroupFilter @@ -70,7 +70,7 @@ class WakingCheckStage(Stage): # 检查插件的 handler filter activated_handlers = [] handlers_parsed_params = {} # 注册了指令的 handler - for handler in star_handlers_registry: + for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent): # filter 需要满足 AND 的逻辑关系 passed = True child_command_handler_md = None diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index aa619246c..2f89d11ad 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -236,7 +236,9 @@ class AstrMessageEvent(abc.ABC): 清除消息事件的结果。 ''' self._result = None - + + '''消息链相关''' + def make_result(self) -> MessageEventResult: ''' 创建一个空的消息事件结果。 @@ -275,4 +277,8 @@ class AstrMessageEvent(abc.ABC): ''' mer = MessageEventResult() mer.chain = chain - return mer \ No newline at end of file + return mer + + '''LLM 请求相关''' + + \ No newline at end of file diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 6d6610ba7..7e9470c10 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,3 +1,4 @@ +import traceback from astrbot.core.config.astrbot_config import AstrBotConfig from .provider import Provider from typing import List @@ -42,8 +43,12 @@ class ProviderManager(): continue cls_type = provider_cls_map[provider_config['type']] logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...") - inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True)) - self.provider_insts.append(inst) + try: + inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True)) + self.provider_insts.append(inst) + except Exception as e: + traceback.print_exc() + logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 失败:{e}") if len(self.provider_insts) > 0: self.curr_provider_inst = self.provider_insts[0] diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index fffc398a9..c4ab049e4 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -99,6 +99,7 @@ class Provider(abc.ABC): image_urls: List[str]=None, func_tool: FuncCall=None, contexts: List=None, + system_prompt: str=None, **kwargs) -> LLMResponse: '''获得 LLM 的文本对话结果。会使用当前的模型进行对话。 diff --git a/astrbot/core/provider/provider_request.py b/astrbot/core/provider/provider_request.py new file mode 100644 index 000000000..7b8ebc450 --- /dev/null +++ b/astrbot/core/provider/provider_request.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from typing import List +from .tool import FuncCall + +@dataclass +class ProviderRequest(): + prompt: str + '''提示词''' + session_id: str = "" + '''会话 ID''' + image_urls: List[str] = None + '''图片 URL 列表''' + func_tool: FuncCall = None + '''工具''' + contexts: List = None + '''上下文''' + system_prompt: str = "" + '''系统提示词''' \ No newline at end of file diff --git a/astrbot/core/provider/register.py b/astrbot/core/provider/register.py index 1fe0f1041..f1086cb4f 100644 --- a/astrbot/core/provider/register.py +++ b/astrbot/core/provider/register.py @@ -1,8 +1,7 @@ -import docstring_parser -from typing import List, Dict, Type, Awaitable +from typing import List, Dict, Type from .provider_metadata import ProviderMetaData from astrbot.core import logger -from .tool import FuncCall, SUPPORTED_TYPES +from .tool import FuncCall provider_registry: List[ProviderMetaData] = [] '''维护了通过装饰器注册的 Provider''' @@ -27,43 +26,3 @@ def register_provider_adapter(provider_type_name: str, desc: str): return cls return decorator - -def register_llm_tool(name: str = None): - '''为函数调用(function-calling / tools-use)添加工具。 - - 请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释) - - ``` - @llm_tool(name="get_weather") # 如果 name 不填,将使用函数名 - async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult: - \'\'\'获取天气信息。 - - Args: - location(string): 地点 - \'\'\' - # 处理逻辑 - ``` - - 可接受的参数类型有:string, number, object, array, boolean。 - ''' - name_ = name - - def decorator(func_obj: Awaitable): - llm_tool_name = name_ if name_ else func_obj.__name__ - module_name = func_obj.__module__ - docstring = docstring_parser.parse(func_obj.__doc__) - args = [] - for arg in docstring.params: - if arg.type_name not in SUPPORTED_TYPES: - raise ValueError(f"LLM 函数工具 {func_obj.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}") - args.append({ - "type": arg.type_name, - "name": arg.arg_name, - "description": arg.description - }) - llm_tools.add_func(llm_tool_name, args, docstring.short_description, func_obj, module_name) - - logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册") - return func_obj - - return decorator \ No newline at end of file diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 3ab28a9c6..09c0d0ce5 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -1,7 +1,6 @@ import traceback import base64 import json -import datetime from openai import AsyncOpenAI, NOT_GIVEN from openai.types.chat.chat_completion import ChatCompletion @@ -29,7 +28,6 @@ class ProviderOpenAIOfficial(Provider): self.chosen_api_key = None self.api_keys: List = provider_config.get("key", []) self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None - self.enable_datetime = provider_config.get("datetime_system_prompt", True) self.client = AsyncOpenAI( api_key=self.chosen_api_key, @@ -133,18 +131,13 @@ class ProviderOpenAIOfficial(Provider): image_urls: List[str]=None, func_tool: FuncCall=None, contexts=None, + system_prompt=None, **kwargs ) -> LLMResponse: new_record = await self.assemble_context(prompt, image_urls) - context_query = [] if not contexts: context_query = [*self.session_memory[session_id], new_record] - system_prompt = "" - if self.curr_personality["prompt"]: - system_prompt = self.curr_personality["prompt"] - if self.enable_datetime: - system_prompt += f"Current datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}" if system_prompt: context_query.insert(0, {"role": "system", "content": system_prompt}) else: diff --git a/astrbot/core/provider/tool.py b/astrbot/core/provider/tool.py index f843b9029..76b0d8d15 100644 --- a/astrbot/core/provider/tool.py +++ b/astrbot/core/provider/tool.py @@ -1,8 +1,8 @@ import json import textwrap -from typing import Awaitable, Dict, List +from typing import Dict, List from dataclasses import dataclass - +from astrbot.core.star.star_handler import StarHandlerMetadata class FuncCallJsonFormatError(Exception): def __init__(self, msg): @@ -29,8 +29,7 @@ class FuncTool: name: str parameters: Dict description: str - func_obj: Awaitable - module_name: str = None + star_handler_metadata: StarHandlerMetadata active: bool = True '''是否激活''' @@ -56,8 +55,7 @@ class FuncCall: name: str, func_args: list, desc: str, - func_obj: Awaitable, - module_name: str = None, + star_handler_metadata: StarHandlerMetadata, ) -> None: """ 为函数调用(function-calling / tools-use)添加工具。 @@ -80,8 +78,7 @@ class FuncCall: name=name, parameters=params, description=desc, - func_obj=func_obj, - module_name=module_name, + star_handler_metadata=star_handler_metadata, ) self.func_list.append(_func) @@ -179,8 +176,8 @@ class FuncCall: # 调用函数 tool_callable = None for func in self.func_list: - if func["name"] == func_name: - tool_callable = func["func_obj"] + if func.name == func_name: + tool_callable = func.star_handler_metadata.handler break if not tool_callable: raise FuncNotFoundError(f"Request function {func_name} not found.") diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index b828f597f..c2f366e7d 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -10,7 +10,7 @@ from astrbot.core.message.message_event_result import MessageChain from astrbot.core.provider.manager import ProviderManager from astrbot.core.platform.manager import PlatformManager from .star import star_registry, StarMetadata -from .star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata +from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType from .filter.command import CommandFilter from .filter.regex import RegexFilter from typing import Awaitable @@ -69,6 +69,15 @@ class Context: 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。 ''' + md = StarHandlerMetadata( + event_type=EventType.OnLLMRequestEvent, + handler_full_name=func_obj.__module__ + "_" + func_obj.__name__, + handler_name=func_obj.__name__, + handler_module_str=func_obj.__module__, + handler=func_obj, + event_filters=[], + desc=desc + ) self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__) def unregister_llm_tool(self, name: str) -> None: @@ -112,6 +121,7 @@ class Context: ''' md = StarHandlerMetadata( + event_type=EventType.AdapterMessageEvent, handler_full_name=awaitable.__module__ + "_" + awaitable.__name__, handler_name=awaitable.__name__, handler_module_str=awaitable.__module__, @@ -129,7 +139,6 @@ class Context: handler_md=md )) star_handlers_registry.append(md) - star_handlers_map[md.handler_full_name] = md def register_provider(self, provider: Provider): ''' diff --git a/astrbot/core/star/filter/permission.py b/astrbot/core/star/filter/permission.py index 012a22337..1bdb74f18 100644 --- a/astrbot/core/star/filter/permission.py +++ b/astrbot/core/star/filter/permission.py @@ -6,8 +6,8 @@ from astrbot.core.config import AstrBotConfig class PermissionType(enum.Flag): '''权限类型。当选择 MEMBER,ADMIN 也可以通过。 ''' - ADMIN = "admin" - MEMBER = "member" + ADMIN = enum.auto() + MEMBER = enum.auto() class PermissionTypeFilter(HandlerFilter): def __init__(self, permission_type: PermissionType, raise_error: bool = True): diff --git a/astrbot/core/star/register/__init__.py b/astrbot/core/star/register/__init__.py index 4619f2669..911abc6d5 100644 --- a/astrbot/core/star/register/__init__.py +++ b/astrbot/core/star/register/__init__.py @@ -5,7 +5,10 @@ from .star_handler import ( register_event_message_type, register_platform_adapter_type, register_regex, - register_permission_type + register_permission_type, + register_on_llm_request, + register_llm_tool, + register_on_decorating_result ) __all__ = [ @@ -15,5 +18,8 @@ __all__ = [ 'register_event_message_type', 'register_platform_adapter_type', 'register_regex', - 'register_permission_type' + 'register_permission_type', + 'register_on_llm_request', + 'register_llm_tool', + 'register_on_decorating_result' ] \ No newline at end of file diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index cf337146f..8f7916984 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -1,6 +1,7 @@ from __future__ import annotations +import docstring_parser -from ..star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata +from ..star_handler import star_handlers_registry, StarHandlerMetadata, EventType from ..filter.command import CommandFilter from ..filter.command_group import CommandGroupFilter from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType @@ -8,19 +9,23 @@ from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAd from ..filter.permission import PermissionTypeFilter, PermissionType from ..filter.regex import RegexFilter from typing import Awaitable +from astrbot.core.provider.tool import SUPPORTED_TYPES +from astrbot.core.provider.register import llm_tools +from astrbot.core import logger - -def get_handler_full_name(awatable: Awaitable) -> str: +def get_handler_full_name(awaitable: Awaitable) -> str: '''获取 Handler 的全名''' - return f"{awatable.__module__}_{awatable.__name__}" + return f"{awaitable.__module__}_{awaitable.__name__}" -def get_handler_or_create(handler: Awaitable, dont_add = False) -> StarHandlerMetadata: +def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = False) -> StarHandlerMetadata: '''获取 Handler 或者创建一个新的 Handler''' handler_full_name = get_handler_full_name(handler) - if handler_full_name in star_handlers_map: - return star_handlers_map[handler_full_name] + md = star_handlers_registry.get_handler_by_full_name(handler_full_name) + if md: + return md else: md = StarHandlerMetadata( + event_type=event_type, handler_full_name=handler_full_name, handler_name=handler.__name__, handler_module_str=handler.__module__, @@ -29,7 +34,6 @@ def get_handler_or_create(handler: Awaitable, dont_add = False) -> StarHandlerMe ) if not dont_add: star_handlers_registry.append(md) - star_handlers_map[handler_full_name] = md return md def register_command(command_name: str = None, *args): @@ -47,7 +51,7 @@ def register_command(command_name: str = None, *args): add_to_event_filters = True def decorator(awaitable): - handler_md = get_handler_or_create(awaitable) + handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) new_command.init_handler_md(handler_md) if add_to_event_filters: # 裸指令 @@ -74,7 +78,7 @@ def register_command_group(command_group_name: str = None, *args): def decorator(obj): if add_to_event_filters: # 根指令组 - handler_md = get_handler_or_create(obj) + handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent) handler_md.event_filters.append(new_group) return RegisteringCommandable(new_group) @@ -91,28 +95,28 @@ class RegisteringCommandable(): def register_event_message_type(event_message_type: EventMessageType): '''注册一个 EventMessageType''' - def decorator(awatable): - handler_md = get_handler_or_create(awatable) + def decorator(awaitable): + handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) handler_md.event_filters.append(EventMessageTypeFilter(event_message_type)) - return awatable + return awaitable return decorator def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType): '''注册一个 PlatformAdapterType''' - def decorator(awatable): - handler_md = get_handler_or_create(awatable) + def decorator(awaitable): + handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) handler_md.event_filters.append(PlatformAdapterTypeFilter(platform_adapter_type)) - return awatable + return awaitable return decorator def register_regex(regex: str): '''注册一个 Regex''' - def decorator(awatable): - handler_md = get_handler_or_create(awatable) + def decorator(awaitable): + handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) handler_md.event_filters.append(RegexFilter(regex)) - return awatable + return awaitable return decorator @@ -123,9 +127,75 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool permission_type: PermissionType raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True ''' - def decorator(awatable): - handler_md = get_handler_or_create(awatable) + def decorator(awaitable): + handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent) handler_md.event_filters.append(PermissionTypeFilter(permission_type, raise_error)) - return awatable + return awaitable + return decorator + +def register_on_llm_request(): + '''当有 LLM 请求时的事件 + + Examples: + ```py + @on_llm_request() + async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None: + request.system_prompt += "你是一个猫娘..." + ``` + + 请务必接收两个参数:event, request + ''' + def decorator(awaitable): + _ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent) + return awaitable + + return decorator + +def register_llm_tool(name: str = None): + '''为函数调用(function-calling / tools-use)添加工具。 + + 请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释) + + ``` + @llm_tool(name="get_weather") # 如果 name 不填,将使用函数名 + async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult: + \'\'\'获取天气信息。 + + Args: + location(string): 地点 + \'\'\' + # 处理逻辑 + ``` + + 可接受的参数类型有:string, number, object, array, boolean。 + ''' + name_ = name + + def decorator(awaitable: Awaitable): + llm_tool_name = name_ if name_ else awaitable.__name__ + docstring = docstring_parser.parse(awaitable.__doc__) + args = [] + for arg in docstring.params: + if arg.type_name not in SUPPORTED_TYPES: + raise ValueError(f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}") + args.append({ + "type": arg.type_name, + "name": arg.arg_name, + "description": arg.description + }) + md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent) + llm_tools.add_func(llm_tool_name, args, docstring.short_description, md) + + logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册") + return awaitable + + return decorator + +def register_on_decorating_result(): + '''在发送消息前的事件''' + def decorator(awaitable): + _ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent) + return awaitable + return decorator \ No newline at end of file diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 5ba6429de..ecc575497 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -1,17 +1,53 @@ from __future__ import annotations +import enum from dataclasses import dataclass from typing import Awaitable, List, Dict from .filter import HandlerFilter -star_handlers_registry: List[StarHandlerMetadata] = [] -star_handlers_map: Dict[str, StarHandlerMetadata] = {} -'''用于快速查找。key 是 handler_full_name''' +class StarHandlerRegistry(List): + '''用于存储所有的 Star Handler''' + + star_handlers_map: Dict[str, StarHandlerMetadata] = {} + '''用于快速查找。key 是 handler_full_name''' + + def append(self, handler: StarHandlerMetadata): + '''添加一个 Handler''' + super().append(handler) + self.star_handlers_map[handler.handler_full_name] = handler + + def get_handlers_by_event_type(self, event_type: EventType) -> List[StarHandlerMetadata]: + '''通过事件类型获取 Handler''' + return [handler for handler in self if handler.event_type == event_type] + + def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata: + '''通过 Handler 的全名获取 Handler''' + return self.star_handlers_map.get(full_name, None) + + def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]: + '''通过模块名获取 Handler''' + return [handler for handler in self if handler.handler_module_str == module_name] + + +star_handlers_registry = StarHandlerRegistry() + +class EventType(enum.Enum): + '''表示一个 AstrBot 内部事件的类型。如适配器消息事件、LLM 请求事件、发送消息前的事件等 + + 用于对 Handler 的职能分组。 + ''' + AdapterMessageEvent = enum.auto() # 收到适配器发来的消息 + OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件) + OnDecoratingResultEvent = enum.auto() # 发送消息前 + OnCallingFuncToolEvent = enum.auto() # 调用函数工具 @dataclass class StarHandlerMetadata(): '''描述一个 Star 所注册的某一个 Handler。''' + event_type: EventType + '''Handler 的事件类型''' + handler_full_name: str '''格式为 f"{handler.__module__}_{handler.__name__}"''' @@ -25,7 +61,7 @@ class StarHandlerMetadata(): '''Handler 的函数对象,应当是一个异步函数''' event_filters: List[HandlerFilter] - '''一个事件过滤器,用于描述这个 Handler 能够处理、应该处理的事件''' + '''一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件''' desc: str = "" '''Handler 的描述信息''' diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 07804acaf..83ed36aa9 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -1,4 +1,5 @@ import inspect +import functools import os import traceback import yaml @@ -174,6 +175,13 @@ class PluginManager: star_metadata.module = module star_metadata.root_dir_name = root_dir_name star_metadata.reserved = reserved + + related_handlers = star_handlers_registry.get_handlers_by_module_name(star_metadata.module_path) + for handler in related_handlers: + logger.debug(f"bind handler {handler.handler_name} to {star_metadata.name}") + # handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self + handler.handler = functools.partial(handler.handler, star_metadata.star_cls) + else: # v3.4.0 以前的方式注册插件 logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")