feat: 新增LLM请求事件钩子和装饰消息结果钩子

This commit is contained in:
Soulter
2024-12-19 21:33:03 +08:00
parent 86cb852507
commit c675017374
25 changed files with 353 additions and 226 deletions
+1 -1
View File
@@ -2,7 +2,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core.provider.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_llm_tool as llm_tool
__all__ = [
"AstrBotConfig",
+1 -1
View File
@@ -3,7 +3,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core.provider.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_llm_tool as llm_tool
# event
from astrbot.core.message.message_event_result import (
+7 -1
View File
@@ -4,7 +4,10 @@ from astrbot.core.star.register import (
register_event_message_type as event_message_type,
register_regex as regex,
register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type
register_permission_type as permission_type,
register_on_llm_request as on_llm_request,
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result
)
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
@@ -24,4 +27,7 @@ __all__ = [
'PlatformAdapterType',
'PermissionTypeFilter',
'PermissionType',
'on_llm_request',
'llm_tool',
'on_decorating_result'
]
+1
View File
@@ -1 +1,2 @@
from astrbot.core.provider import Provider, Personality, ProviderMetaData
from astrbot.core.provider.provider_request import ProviderRequest
+19 -1
View File
@@ -97,7 +97,14 @@ class EventResultType(enum.Enum):
'''
CONTINUE = enum.auto()
STOP = enum.auto()
class ResultContentType(enum.Enum):
'''用于描述事件结果的内容的类型。
'''
LLM_RESULT = enum.auto()
'''调用 LLM 产生的结果'''
GENERAL_RESULT = enum.auto()
'''普通的消息结果'''
@dataclass
class MessageEventResult(MessageChain):
'''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
@@ -112,6 +119,8 @@ class MessageEventResult(MessageChain):
result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE)
result_content_type: Optional[ResultContentType] = field(default_factory=lambda: ResultContentType.GENERAL_RESULT)
def stop_event(self) -> 'MessageEventResult':
'''终止事件传播。
'''
@@ -130,5 +139,14 @@ class MessageEventResult(MessageChain):
'''
return self.result_type == EventResultType.STOP
def set_result_content_type(self, result_type: EventResultType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
result_type (EventResultType): 事件处理的结果类型。
'''
self.result_type = result_type
return self
CommandResult = MessageEventResult
@@ -1,113 +1,87 @@
import traceback
import inspect
import datetime
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.star.star import star_map
from astrbot.core.provider.provider_request import ProviderRequest
from astrbot.core.star.star_handler import star_handlers_registry, EventType
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
self.enable_datetime = ctx.astrbot_config['provider_settings']["datetime_system_prompt"]
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
# Chat 唤醒前缀
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
return
event.message_str = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
req: ProviderRequest = None
if event.get_extra("provider_request"):
print("provider_request")
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
req = ProviderRequest(prompt="", image_urls=[])
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
return
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
req.session_id = event.session_id
provider = self.ctx.plugin_manager.context.get_using_provider()
if self.prompt_prefix:
event.message_str = self.prompt_prefix + event.message_str
req.prompt = self.prompt_prefix + req.prompt
if self.identifier:
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n"
event.message_str = user_info + event.message_str
req.prompt = user_info + req.prompt
if self.enable_datetime:
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
if provider.curr_personality['prompt']:
req.system_prompt += f"\n{provider.curr_personality['prompt']}"
image_urls = []
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
image_urls.append(image_url)
tools = self.ctx.plugin_manager.context.get_llm_tool_manager()
provider = self.ctx.plugin_manager.context.get_using_provider()
# 执行请求 LLM 前事件。
# 装饰 system_prompt 等功能
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
for handler in handlers:
try:
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
try:
llm_response = await provider.text_chat(
prompt=event.message_str,
session_id=event.session_id,
image_urls=image_urls,
func_tool=tools
)
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text))
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
elif llm_response.role == 'tool':
# function calling
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
func_tool = tools.get_func(func_tool_name)
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
try:
# 尝试调用工具函数
star_cls_obj = star_map.get(func_tool.module_name).star_cls
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
if hasattr(func_tool.func_obj, '__self__'):
# 猜测没有通过装饰器去注册
try:
ready_to_call = func_tool.func_obj(event, **func_tool_args)
except TypeError:
# 向下兼容
ready_to_call = func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
else:
ready_to_call = func_tool.func_obj(star_cls_obj, event, **func_tool_args)
if isinstance(ready_to_call, AsyncGenerator):
async for mer in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
if mer:
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(mer)
yield
else:
if event.get_result():
yield
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if ret:
# 如果有返回值
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(ret)
# 执行后续步骤来发送消息
if event.is_stopped() and event.get_result():
# 主动停止事件传播,并且有结果
event.continue_event()
yield
event.clear_result()
event.stop_event()
yield
elif not event.is_stopped and not event.get_result():
continue
else:
yield
wrapper = self._call_handler(self.ctx, event, func_tool.star_handler_metadata.handler, **func_tool_args)
async for _ in wrapper:
yield
event.clear_result() # 清除上一个 handler 的结果
except BaseException:
logger.error(traceback.format_exc())
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
@@ -2,12 +2,12 @@ from ...context import PipelineContext
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.star.star import star_map
import traceback
import inspect
class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
@@ -27,50 +27,11 @@ class StarRequestSubStage(Stage):
if handler.handler_module_str not in star_map:
# 孤立无援的 star handler
continue
star_cls_obj = star_map.get(handler.handler_module_str).star_cls
logger.debug(f"执行 Star Handler {handler.handler_full_name}")
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
if hasattr(handler.handler, '__self__'):
# 猜测没有通过装饰器去注册
try:
ready_to_call = handler.handler(event, **params)
except TypeError:
# 向下兼容
ready_to_call = handler.handler(event, self.ctx.plugin_manager.context, **params)
else:
ready_to_call = handler.handler(star_cls_obj, event, **params)
if isinstance(ready_to_call, AsyncGenerator):
async for mer in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
if mer:
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(mer)
yield
else:
if event.get_result():
yield
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if ret:
# 如果有返回值
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(ret)
# 执行后续步骤来发送消息
if event.is_stopped() and event.get_result():
# 插件主动停止事件传播,并且有结果
event.continue_event()
yield
event.clear_result()
event.stop_event()
yield
elif not event.is_stopped and not event.get_result():
continue
else:
yield
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
async for _ in wrapper:
yield
event.clear_result() # 清除上一个 handler 的结果
except Exception as e:
logger.error(traceback.format_exc())
+10 -2
View File
@@ -5,6 +5,7 @@ from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.provider_request import ProviderRequest
@register_stage
class ProcessStage(Stage):
@@ -25,8 +26,15 @@ class ProcessStage(Stage):
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
if activated_handlers:
async for _ in self.star_request_sub_stage.process(event):
yield
async for resp in self.star_request_sub_stage.process(event):
# 生成器返回值处理
if isinstance(resp, ProviderRequest):
# Handler 的 LLM 请求
event.set_extra("provider_request", resp)
async for _ in self.llm_request_sub_stage.process(event):
yield
else:
yield
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
if not event._has_send_oper:
+3 -3
View File
@@ -1,11 +1,11 @@
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
@register_stage
class RespondStage:
class RespondStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
@@ -13,7 +13,7 @@ class RespondStage:
result = event.get_result()
if result is None:
return
if len(result.chain) > 0:
await event.send(result)
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
@@ -6,6 +6,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@register_stage
class ResultDecorateStage:
@@ -19,6 +20,11 @@ class ResultDecorateStage:
if result is None:
return
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
for handler in handlers:
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
await handler.handler(event)
if len(result.chain) > 0:
# 回复前缀
if self.reply_prefix:
+47 -2
View File
@@ -1,8 +1,10 @@
from __future__ import annotations
import abc
from typing import List, AsyncGenerator, Union
import inspect
from typing import List, AsyncGenerator, Union, Awaitable
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
registered_stages: List[Stage] = []
'''维护了所有已注册的 Stage 实现类'''
@@ -29,4 +31,47 @@ class Stage(abc.ABC):
'''
raise NotImplementedError
async def _call_handler(
self,
ctx: PipelineContext,
event: AstrMessageEvent,
handler: Awaitable,
**params
) -> AsyncGenerator[None, None]:
'''调用 Handler。'''
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
try:
ready_to_call = handler(event, **params)
except TypeError as e:
print(e)
# 向下兼容
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
if isinstance(ready_to_call, AsyncGenerator):
async for mer in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
if mer:
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(mer)
yield
else:
if event.get_result():
yield
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if ret:
# 如果有返回值
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(ret)
# 执行后续步骤来发送消息
if event.is_stopped() and event.get_result():
# 插件主动停止事件传播,并且有结果
event.continue_event()
yield
event.clear_result()
event.stop_event()
yield
else:
yield
+2 -2
View File
@@ -4,7 +4,7 @@ from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.message.components import At
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.filter.command_group import CommandGroupFilter
@@ -70,7 +70,7 @@ class WakingCheckStage(Stage):
# 检查插件的 handler filter
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler
for handler in star_handlers_registry:
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
# filter 需要满足 AND 的逻辑关系
passed = True
child_command_handler_md = None
+8 -2
View File
@@ -236,7 +236,9 @@ class AstrMessageEvent(abc.ABC):
清除消息事件的结果。
'''
self._result = None
'''消息链相关'''
def make_result(self) -> MessageEventResult:
'''
创建一个空的消息事件结果。
@@ -275,4 +277,8 @@ class AstrMessageEvent(abc.ABC):
'''
mer = MessageEventResult()
mer.chain = chain
return mer
return mer
'''LLM 请求相关'''
+7 -2
View File
@@ -1,3 +1,4 @@
import traceback
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider
from typing import List
@@ -42,8 +43,12 @@ class ProviderManager():
continue
cls_type = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...")
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
self.provider_insts.append(inst)
try:
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
self.provider_insts.append(inst)
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 失败:{e}")
if len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
+1
View File
@@ -99,6 +99,7 @@ class Provider(abc.ABC):
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts: List=None,
system_prompt: str=None,
**kwargs) -> LLMResponse:
'''获得 LLM 的文本对话结果。会使用当前的模型进行对话。
+18
View File
@@ -0,0 +1,18 @@
from dataclasses import dataclass
from typing import List
from .tool import FuncCall
@dataclass
class ProviderRequest():
prompt: str
'''提示词'''
session_id: str = ""
'''会话 ID'''
image_urls: List[str] = None
'''图片 URL 列表'''
func_tool: FuncCall = None
'''工具'''
contexts: List = None
'''上下文'''
system_prompt: str = ""
'''系统提示词'''
+2 -43
View File
@@ -1,8 +1,7 @@
import docstring_parser
from typing import List, Dict, Type, Awaitable
from typing import List, Dict, Type
from .provider_metadata import ProviderMetaData
from astrbot.core import logger
from .tool import FuncCall, SUPPORTED_TYPES
from .tool import FuncCall
provider_registry: List[ProviderMetaData] = []
'''维护了通过装饰器注册的 Provider'''
@@ -27,43 +26,3 @@ def register_provider_adapter(provider_type_name: str, desc: str):
return cls
return decorator
def register_llm_tool(name: str = None):
'''为函数调用(function-calling / tools-use)添加工具。
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
```
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult:
\'\'\'获取天气信息。
Args:
location(string): 地点
\'\'\'
# 处理逻辑
```
可接受的参数类型有:string, number, object, array, boolean。
'''
name_ = name
def decorator(func_obj: Awaitable):
llm_tool_name = name_ if name_ else func_obj.__name__
module_name = func_obj.__module__
docstring = docstring_parser.parse(func_obj.__doc__)
args = []
for arg in docstring.params:
if arg.type_name not in SUPPORTED_TYPES:
raise ValueError(f"LLM 函数工具 {func_obj.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
args.append({
"type": arg.type_name,
"name": arg.arg_name,
"description": arg.description
})
llm_tools.add_func(llm_tool_name, args, docstring.short_description, func_obj, module_name)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return func_obj
return decorator
@@ -1,7 +1,6 @@
import traceback
import base64
import json
import datetime
from openai import AsyncOpenAI, NOT_GIVEN
from openai.types.chat.chat_completion import ChatCompletion
@@ -29,7 +28,6 @@ class ProviderOpenAIOfficial(Provider):
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.enable_datetime = provider_config.get("datetime_system_prompt", True)
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
@@ -133,18 +131,13 @@ class ProviderOpenAIOfficial(Provider):
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
system_prompt = ""
if self.curr_personality["prompt"]:
system_prompt = self.curr_personality["prompt"]
if self.enable_datetime:
system_prompt += f"Current datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
else:
+7 -10
View File
@@ -1,8 +1,8 @@
import json
import textwrap
from typing import Awaitable, Dict, List
from typing import Dict, List
from dataclasses import dataclass
from astrbot.core.star.star_handler import StarHandlerMetadata
class FuncCallJsonFormatError(Exception):
def __init__(self, msg):
@@ -29,8 +29,7 @@ class FuncTool:
name: str
parameters: Dict
description: str
func_obj: Awaitable
module_name: str = None
star_handler_metadata: StarHandlerMetadata
active: bool = True
'''是否激活'''
@@ -56,8 +55,7 @@ class FuncCall:
name: str,
func_args: list,
desc: str,
func_obj: Awaitable,
module_name: str = None,
star_handler_metadata: StarHandlerMetadata,
) -> None:
"""
为函数调用(function-calling / tools-use)添加工具。
@@ -80,8 +78,7 @@ class FuncCall:
name=name,
parameters=params,
description=desc,
func_obj=func_obj,
module_name=module_name,
star_handler_metadata=star_handler_metadata,
)
self.func_list.append(_func)
@@ -179,8 +176,8 @@ class FuncCall:
# 调用函数
tool_callable = None
for func in self.func_list:
if func["name"] == func_name:
tool_callable = func["func_obj"]
if func.name == func_name:
tool_callable = func.star_handler_metadata.handler
break
if not tool_callable:
raise FuncNotFoundError(f"Request function {func_name} not found.")
+11 -2
View File
@@ -10,7 +10,7 @@ from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform.manager import PlatformManager
from .star import star_registry, StarMetadata
from .star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
@@ -69,6 +69,15 @@ class Context:
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
'''
md = StarHandlerMetadata(
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
handler_module_str=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
def unregister_llm_tool(self, name: str) -> None:
@@ -112,6 +121,7 @@ class Context:
'''
md = StarHandlerMetadata(
event_type=EventType.AdapterMessageEvent,
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
handler_name=awaitable.__name__,
handler_module_str=awaitable.__module__,
@@ -129,7 +139,6 @@ class Context:
handler_md=md
))
star_handlers_registry.append(md)
star_handlers_map[md.handler_full_name] = md
def register_provider(self, provider: Provider):
'''
+2 -2
View File
@@ -6,8 +6,8 @@ from astrbot.core.config import AstrBotConfig
class PermissionType(enum.Flag):
'''权限类型。当选择 MEMBER,ADMIN 也可以通过。
'''
ADMIN = "admin"
MEMBER = "member"
ADMIN = enum.auto()
MEMBER = enum.auto()
class PermissionTypeFilter(HandlerFilter):
def __init__(self, permission_type: PermissionType, raise_error: bool = True):
+8 -2
View File
@@ -5,7 +5,10 @@ from .star_handler import (
register_event_message_type,
register_platform_adapter_type,
register_regex,
register_permission_type
register_permission_type,
register_on_llm_request,
register_llm_tool,
register_on_decorating_result
)
__all__ = [
@@ -15,5 +18,8 @@ __all__ = [
'register_event_message_type',
'register_platform_adapter_type',
'register_regex',
'register_permission_type'
'register_permission_type',
'register_on_llm_request',
'register_llm_tool',
'register_on_decorating_result'
]
+92 -22
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import docstring_parser
from ..star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
from ..star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from ..filter.command import CommandFilter
from ..filter.command_group import CommandGroupFilter
from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType
@@ -8,19 +9,23 @@ from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAd
from ..filter.permission import PermissionTypeFilter, PermissionType
from ..filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.provider.tool import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from astrbot.core import logger
def get_handler_full_name(awatable: Awaitable) -> str:
def get_handler_full_name(awaitable: Awaitable) -> str:
'''获取 Handler 的全名'''
return f"{awatable.__module__}_{awatable.__name__}"
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(handler: Awaitable, dont_add = False) -> StarHandlerMetadata:
def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = False) -> StarHandlerMetadata:
'''获取 Handler 或者创建一个新的 Handler'''
handler_full_name = get_handler_full_name(handler)
if handler_full_name in star_handlers_map:
return star_handlers_map[handler_full_name]
md = star_handlers_registry.get_handler_by_full_name(handler_full_name)
if md:
return md
else:
md = StarHandlerMetadata(
event_type=event_type,
handler_full_name=handler_full_name,
handler_name=handler.__name__,
handler_module_str=handler.__module__,
@@ -29,7 +34,6 @@ def get_handler_or_create(handler: Awaitable, dont_add = False) -> StarHandlerMe
)
if not dont_add:
star_handlers_registry.append(md)
star_handlers_map[handler_full_name] = md
return md
def register_command(command_name: str = None, *args):
@@ -47,7 +51,7 @@ def register_command(command_name: str = None, *args):
add_to_event_filters = True
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable)
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
new_command.init_handler_md(handler_md)
if add_to_event_filters:
# 裸指令
@@ -74,7 +78,7 @@ def register_command_group(command_group_name: str = None, *args):
def decorator(obj):
if add_to_event_filters:
# 根指令组
handler_md = get_handler_or_create(obj)
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
@@ -91,28 +95,28 @@ class RegisteringCommandable():
def register_event_message_type(event_message_type: EventMessageType):
'''注册一个 EventMessageType'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
return awatable
return awaitable
return decorator
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
'''注册一个 PlatformAdapterType'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(PlatformAdapterTypeFilter(platform_adapter_type))
return awatable
return awaitable
return decorator
def register_regex(regex: str):
'''注册一个 Regex'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(RegexFilter(regex))
return awatable
return awaitable
return decorator
@@ -123,9 +127,75 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool
permission_type: PermissionType
raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True
'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(PermissionTypeFilter(permission_type, raise_error))
return awatable
return awaitable
return decorator
def register_on_llm_request():
'''当有 LLM 请求时的事件
Examples:
```py
@on_llm_request()
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
request.system_prompt += "你是一个猫娘..."
```
请务必接收两个参数:event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent)
return awaitable
return decorator
def register_llm_tool(name: str = None):
'''为函数调用(function-calling / tools-use)添加工具。
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
```
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult:
\'\'\'获取天气信息。
Args:
location(string): 地点
\'\'\'
# 处理逻辑
```
可接受的参数类型有:string, number, object, array, boolean。
'''
name_ = name
def decorator(awaitable: Awaitable):
llm_tool_name = name_ if name_ else awaitable.__name__
docstring = docstring_parser.parse(awaitable.__doc__)
args = []
for arg in docstring.params:
if arg.type_name not in SUPPORTED_TYPES:
raise ValueError(f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
args.append({
"type": arg.type_name,
"name": arg.arg_name,
"description": arg.description
})
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return awaitable
return decorator
def register_on_decorating_result():
'''在发送消息前的事件'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent)
return awaitable
return decorator
+40 -4
View File
@@ -1,17 +1,53 @@
from __future__ import annotations
import enum
from dataclasses import dataclass
from typing import Awaitable, List, Dict
from .filter import HandlerFilter
star_handlers_registry: List[StarHandlerMetadata] = []
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
'''用于快速查找。key 是 handler_full_name'''
class StarHandlerRegistry(List):
'''用于存储所有的 Star Handler'''
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
'''用于快速查找。key 是 handler_full_name'''
def append(self, handler: StarHandlerMetadata):
'''添加一个 Handler'''
super().append(handler)
self.star_handlers_map[handler.handler_full_name] = handler
def get_handlers_by_event_type(self, event_type: EventType) -> List[StarHandlerMetadata]:
'''通过事件类型获取 Handler'''
return [handler for handler in self if handler.event_type == event_type]
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
'''通过 Handler 的全名获取 Handler'''
return self.star_handlers_map.get(full_name, None)
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
'''通过模块名获取 Handler'''
return [handler for handler in self if handler.handler_module_str == module_name]
star_handlers_registry = StarHandlerRegistry()
class EventType(enum.Enum):
'''表示一个 AstrBot 内部事件的类型。如适配器消息事件、LLM 请求事件、发送消息前的事件等
用于对 Handler 的职能分组。
'''
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
OnDecoratingResultEvent = enum.auto() # 发送消息前
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
@dataclass
class StarHandlerMetadata():
'''描述一个 Star 所注册的某一个 Handler。'''
event_type: EventType
'''Handler 的事件类型'''
handler_full_name: str
'''格式为 f"{handler.__module__}_{handler.__name__}"'''
@@ -25,7 +61,7 @@ class StarHandlerMetadata():
'''Handler 的函数对象,应当是一个异步函数'''
event_filters: List[HandlerFilter]
'''一个事件过滤器,用于描述这个 Handler 能够处理、应该处理的事件'''
'''一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件'''
desc: str = ""
'''Handler 的描述信息'''
+8
View File
@@ -1,4 +1,5 @@
import inspect
import functools
import os
import traceback
import yaml
@@ -174,6 +175,13 @@ class PluginManager:
star_metadata.module = module
star_metadata.root_dir_name = root_dir_name
star_metadata.reserved = reserved
related_handlers = star_handlers_registry.get_handlers_by_module_name(star_metadata.module_path)
for handler in related_handlers:
logger.debug(f"bind handler {handler.handler_name} to {star_metadata.name}")
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
handler.handler = functools.partial(handler.handler, star_metadata.star_cls)
else:
# v3.4.0 以前的方式注册插件
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")