feat: 新增LLM请求事件钩子和装饰消息结果钩子
This commit is contained in:
@@ -2,7 +2,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.provider.register import register_llm_tool as llm_tool
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
|
||||
__all__ = [
|
||||
"AstrBotConfig",
|
||||
|
||||
+1
-1
@@ -3,7 +3,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.provider.register import register_llm_tool as llm_tool
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
|
||||
# event
|
||||
from astrbot.core.message.message_event_result import (
|
||||
|
||||
@@ -4,7 +4,10 @@ from astrbot.core.star.register import (
|
||||
register_event_message_type as event_message_type,
|
||||
register_regex as regex,
|
||||
register_platform_adapter_type as platform_adapter_type,
|
||||
register_permission_type as permission_type
|
||||
register_permission_type as permission_type,
|
||||
register_on_llm_request as on_llm_request,
|
||||
register_llm_tool as llm_tool,
|
||||
register_on_decorating_result as on_decorating_result
|
||||
)
|
||||
|
||||
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
|
||||
@@ -24,4 +27,7 @@ __all__ = [
|
||||
'PlatformAdapterType',
|
||||
'PermissionTypeFilter',
|
||||
'PermissionType',
|
||||
'on_llm_request',
|
||||
'llm_tool',
|
||||
'on_decorating_result'
|
||||
]
|
||||
@@ -1 +1,2 @@
|
||||
from astrbot.core.provider import Provider, Personality, ProviderMetaData
|
||||
from astrbot.core.provider.provider_request import ProviderRequest
|
||||
@@ -97,7 +97,14 @@ class EventResultType(enum.Enum):
|
||||
'''
|
||||
CONTINUE = enum.auto()
|
||||
STOP = enum.auto()
|
||||
|
||||
|
||||
class ResultContentType(enum.Enum):
|
||||
'''用于描述事件结果的内容的类型。
|
||||
'''
|
||||
LLM_RESULT = enum.auto()
|
||||
'''调用 LLM 产生的结果'''
|
||||
GENERAL_RESULT = enum.auto()
|
||||
'''普通的消息结果'''
|
||||
@dataclass
|
||||
class MessageEventResult(MessageChain):
|
||||
'''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
|
||||
@@ -112,6 +119,8 @@ class MessageEventResult(MessageChain):
|
||||
|
||||
result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE)
|
||||
|
||||
result_content_type: Optional[ResultContentType] = field(default_factory=lambda: ResultContentType.GENERAL_RESULT)
|
||||
|
||||
def stop_event(self) -> 'MessageEventResult':
|
||||
'''终止事件传播。
|
||||
'''
|
||||
@@ -130,5 +139,14 @@ class MessageEventResult(MessageChain):
|
||||
'''
|
||||
return self.result_type == EventResultType.STOP
|
||||
|
||||
def set_result_content_type(self, result_type: EventResultType) -> 'MessageEventResult':
|
||||
'''设置事件处理的结果类型。
|
||||
|
||||
Args:
|
||||
result_type (EventResultType): 事件处理的结果类型。
|
||||
'''
|
||||
self.result_type = result_type
|
||||
return self
|
||||
|
||||
|
||||
CommandResult = MessageEventResult
|
||||
@@ -1,113 +1,87 @@
|
||||
import traceback
|
||||
import inspect
|
||||
import datetime
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.star.star import star_map
|
||||
|
||||
from astrbot.core.provider.provider_request import ProviderRequest
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
|
||||
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
|
||||
self.enable_datetime = ctx.astrbot_config['provider_settings']["datetime_system_prompt"]
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# Chat 唤醒前缀
|
||||
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
|
||||
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
|
||||
return
|
||||
event.message_str = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
|
||||
req: ProviderRequest = None
|
||||
|
||||
if event.get_extra("provider_request"):
|
||||
print("provider_request")
|
||||
req = event.get_extra("provider_request")
|
||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
|
||||
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
|
||||
return
|
||||
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
|
||||
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
req.image_urls.append(image_url)
|
||||
req.session_id = event.session_id
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
|
||||
if self.prompt_prefix:
|
||||
event.message_str = self.prompt_prefix + event.message_str
|
||||
req.prompt = self.prompt_prefix + req.prompt
|
||||
if self.identifier:
|
||||
user_id = event.message_obj.sender.user_id
|
||||
user_nickname = event.message_obj.sender.nickname
|
||||
user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
event.message_str = user_info + event.message_str
|
||||
req.prompt = user_info + req.prompt
|
||||
if self.enable_datetime:
|
||||
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
|
||||
if provider.curr_personality['prompt']:
|
||||
req.system_prompt += f"\n{provider.curr_personality['prompt']}"
|
||||
|
||||
image_urls = []
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
image_url = comp.url if comp.url else comp.file
|
||||
image_urls.append(image_url)
|
||||
|
||||
tools = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider()
|
||||
# 执行请求 LLM 前事件。
|
||||
# 装饰 system_prompt 等功能
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
|
||||
for handler in handlers:
|
||||
try:
|
||||
await handler.handler(event, req)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
try:
|
||||
llm_response = await provider.text_chat(
|
||||
prompt=event.message_str,
|
||||
session_id=event.session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=tools
|
||||
)
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
|
||||
|
||||
if llm_response.role == 'assistant':
|
||||
# text completion
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text))
|
||||
event.set_result(MessageEventResult().message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.LLM_RESULT))
|
||||
elif llm_response.role == 'tool':
|
||||
# function calling
|
||||
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
|
||||
func_tool = tools.get_func(func_tool_name)
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
|
||||
try:
|
||||
# 尝试调用工具函数
|
||||
|
||||
star_cls_obj = star_map.get(func_tool.module_name).star_cls
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
ready_to_call = None
|
||||
if hasattr(func_tool.func_obj, '__self__'):
|
||||
# 猜测没有通过装饰器去注册
|
||||
try:
|
||||
ready_to_call = func_tool.func_obj(event, **func_tool_args)
|
||||
except TypeError:
|
||||
# 向下兼容
|
||||
ready_to_call = func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
|
||||
else:
|
||||
ready_to_call = func_tool.func_obj(star_cls_obj, event, **func_tool_args)
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
async for mer in ready_to_call:
|
||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
if mer:
|
||||
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(mer)
|
||||
yield
|
||||
else:
|
||||
if event.get_result():
|
||||
yield
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个 coroutine
|
||||
ret = await ready_to_call
|
||||
if ret:
|
||||
# 如果有返回值
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
if event.is_stopped() and event.get_result():
|
||||
# 主动停止事件传播,并且有结果
|
||||
event.continue_event()
|
||||
yield
|
||||
event.clear_result()
|
||||
event.stop_event()
|
||||
yield
|
||||
elif not event.is_stopped and not event.get_result():
|
||||
continue
|
||||
else:
|
||||
yield
|
||||
wrapper = self._call_handler(self.ctx, event, func_tool.star_handler_metadata.handler, **func_tool_args)
|
||||
async for _ in wrapper:
|
||||
yield
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
|
||||
|
||||
@@ -2,12 +2,12 @@ from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
from typing import Dict, Any, List, AsyncGenerator, Union
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core.star.star import star_map
|
||||
import traceback
|
||||
import inspect
|
||||
|
||||
class StarRequestSubStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
@@ -27,50 +27,11 @@ class StarRequestSubStage(Stage):
|
||||
if handler.handler_module_str not in star_map:
|
||||
# 孤立无援的 star handler
|
||||
continue
|
||||
star_cls_obj = star_map.get(handler.handler_module_str).star_cls
|
||||
|
||||
logger.debug(f"执行 Star Handler {handler.handler_full_name}")
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
ready_to_call = None
|
||||
if hasattr(handler.handler, '__self__'):
|
||||
# 猜测没有通过装饰器去注册
|
||||
try:
|
||||
ready_to_call = handler.handler(event, **params)
|
||||
except TypeError:
|
||||
# 向下兼容
|
||||
ready_to_call = handler.handler(event, self.ctx.plugin_manager.context, **params)
|
||||
else:
|
||||
ready_to_call = handler.handler(star_cls_obj, event, **params)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
async for mer in ready_to_call:
|
||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
if mer:
|
||||
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(mer)
|
||||
yield
|
||||
else:
|
||||
if event.get_result():
|
||||
yield
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个 coroutine
|
||||
ret = await ready_to_call
|
||||
if ret:
|
||||
# 如果有返回值
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
if event.is_stopped() and event.get_result():
|
||||
# 插件主动停止事件传播,并且有结果
|
||||
event.continue_event()
|
||||
yield
|
||||
event.clear_result()
|
||||
event.stop_event()
|
||||
yield
|
||||
elif not event.is_stopped and not event.get_result():
|
||||
continue
|
||||
else:
|
||||
yield
|
||||
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
|
||||
async for _ in wrapper:
|
||||
yield
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -5,6 +5,7 @@ from .method.llm_request import LLMRequestSubStage
|
||||
from .method.star_request import StarRequestSubStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core.provider.provider_request import ProviderRequest
|
||||
|
||||
@register_stage
|
||||
class ProcessStage(Stage):
|
||||
@@ -25,8 +26,15 @@ class ProcessStage(Stage):
|
||||
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
|
||||
|
||||
if activated_handlers:
|
||||
async for _ in self.star_request_sub_stage.process(event):
|
||||
yield
|
||||
async for resp in self.star_request_sub_stage.process(event):
|
||||
# 生成器返回值处理
|
||||
if isinstance(resp, ProviderRequest):
|
||||
# Handler 的 LLM 请求
|
||||
event.set_extra("provider_request", resp)
|
||||
async for _ in self.llm_request_sub_stage.process(event):
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
|
||||
if not event._has_send_oper:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core import logger
|
||||
|
||||
@register_stage
|
||||
class RespondStage:
|
||||
class RespondStage(Stage):
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
|
||||
@@ -13,7 +13,7 @@ class RespondStage:
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
|
||||
|
||||
if len(result.chain) > 0:
|
||||
await event.send(result)
|
||||
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
|
||||
|
||||
@@ -6,6 +6,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
@register_stage
|
||||
class ResultDecorateStage:
|
||||
@@ -19,6 +20,11 @@ class ResultDecorateStage:
|
||||
if result is None:
|
||||
return
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
|
||||
for handler in handlers:
|
||||
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
|
||||
await handler.handler(event)
|
||||
|
||||
if len(result.chain) > 0:
|
||||
# 回复前缀
|
||||
if self.reply_prefix:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
from typing import List, AsyncGenerator, Union
|
||||
import inspect
|
||||
from typing import List, AsyncGenerator, Union, Awaitable
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
|
||||
registered_stages: List[Stage] = []
|
||||
'''维护了所有已注册的 Stage 实现类'''
|
||||
@@ -29,4 +31,47 @@ class Stage(abc.ABC):
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
async def _call_handler(
|
||||
self,
|
||||
ctx: PipelineContext,
|
||||
event: AstrMessageEvent,
|
||||
handler: Awaitable,
|
||||
**params
|
||||
) -> AsyncGenerator[None, None]:
|
||||
'''调用 Handler。'''
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
ready_to_call = None
|
||||
try:
|
||||
ready_to_call = handler(event, **params)
|
||||
except TypeError as e:
|
||||
print(e)
|
||||
# 向下兼容
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
async for mer in ready_to_call:
|
||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
if mer:
|
||||
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(mer)
|
||||
yield
|
||||
else:
|
||||
if event.get_result():
|
||||
yield
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个 coroutine
|
||||
ret = await ready_to_call
|
||||
if ret:
|
||||
# 如果有返回值
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
if event.is_stopped() and event.get_result():
|
||||
# 插件主动停止事件传播,并且有结果
|
||||
event.continue_event()
|
||||
yield
|
||||
event.clear_result()
|
||||
event.stop_event()
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
@@ -4,7 +4,7 @@ from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult
|
||||
from astrbot.core.message.components import At
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class WakingCheckStage(Stage):
|
||||
# 检查插件的 handler filter
|
||||
activated_handlers = []
|
||||
handlers_parsed_params = {} # 注册了指令的 handler
|
||||
for handler in star_handlers_registry:
|
||||
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
|
||||
# filter 需要满足 AND 的逻辑关系
|
||||
passed = True
|
||||
child_command_handler_md = None
|
||||
|
||||
@@ -236,7 +236,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
清除消息事件的结果。
|
||||
'''
|
||||
self._result = None
|
||||
|
||||
|
||||
'''消息链相关'''
|
||||
|
||||
def make_result(self) -> MessageEventResult:
|
||||
'''
|
||||
创建一个空的消息事件结果。
|
||||
@@ -275,4 +277,8 @@ class AstrMessageEvent(abc.ABC):
|
||||
'''
|
||||
mer = MessageEventResult()
|
||||
mer.chain = chain
|
||||
return mer
|
||||
return mer
|
||||
|
||||
'''LLM 请求相关'''
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import traceback
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .provider import Provider
|
||||
from typing import List
|
||||
@@ -42,8 +43,12 @@ class ProviderManager():
|
||||
continue
|
||||
cls_type = provider_cls_map[provider_config['type']]
|
||||
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...")
|
||||
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
|
||||
self.provider_insts.append(inst)
|
||||
try:
|
||||
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
|
||||
self.provider_insts.append(inst)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 失败:{e}")
|
||||
|
||||
if len(self.provider_insts) > 0:
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
|
||||
@@ -99,6 +99,7 @@ class Provider(abc.ABC):
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts: List=None,
|
||||
system_prompt: str=None,
|
||||
**kwargs) -> LLMResponse:
|
||||
'''获得 LLM 的文本对话结果。会使用当前的模型进行对话。
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from .tool import FuncCall
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest():
|
||||
prompt: str
|
||||
'''提示词'''
|
||||
session_id: str = ""
|
||||
'''会话 ID'''
|
||||
image_urls: List[str] = None
|
||||
'''图片 URL 列表'''
|
||||
func_tool: FuncCall = None
|
||||
'''工具'''
|
||||
contexts: List = None
|
||||
'''上下文'''
|
||||
system_prompt: str = ""
|
||||
'''系统提示词'''
|
||||
@@ -1,8 +1,7 @@
|
||||
import docstring_parser
|
||||
from typing import List, Dict, Type, Awaitable
|
||||
from typing import List, Dict, Type
|
||||
from .provider_metadata import ProviderMetaData
|
||||
from astrbot.core import logger
|
||||
from .tool import FuncCall, SUPPORTED_TYPES
|
||||
from .tool import FuncCall
|
||||
|
||||
provider_registry: List[ProviderMetaData] = []
|
||||
'''维护了通过装饰器注册的 Provider'''
|
||||
@@ -27,43 +26,3 @@ def register_provider_adapter(provider_type_name: str, desc: str):
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
def register_llm_tool(name: str = None):
|
||||
'''为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||
|
||||
```
|
||||
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
|
||||
async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult:
|
||||
\'\'\'获取天气信息。
|
||||
|
||||
Args:
|
||||
location(string): 地点
|
||||
\'\'\'
|
||||
# 处理逻辑
|
||||
```
|
||||
|
||||
可接受的参数类型有:string, number, object, array, boolean。
|
||||
'''
|
||||
name_ = name
|
||||
|
||||
def decorator(func_obj: Awaitable):
|
||||
llm_tool_name = name_ if name_ else func_obj.__name__
|
||||
module_name = func_obj.__module__
|
||||
docstring = docstring_parser.parse(func_obj.__doc__)
|
||||
args = []
|
||||
for arg in docstring.params:
|
||||
if arg.type_name not in SUPPORTED_TYPES:
|
||||
raise ValueError(f"LLM 函数工具 {func_obj.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
|
||||
args.append({
|
||||
"type": arg.type_name,
|
||||
"name": arg.arg_name,
|
||||
"description": arg.description
|
||||
})
|
||||
llm_tools.add_func(llm_tool_name, args, docstring.short_description, func_obj, module_name)
|
||||
|
||||
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
|
||||
return func_obj
|
||||
|
||||
return decorator
|
||||
@@ -1,7 +1,6 @@
|
||||
import traceback
|
||||
import base64
|
||||
import json
|
||||
import datetime
|
||||
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
@@ -29,7 +28,6 @@ class ProviderOpenAIOfficial(Provider):
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.enable_datetime = provider_config.get("datetime_system_prompt", True)
|
||||
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
@@ -133,18 +131,13 @@ class ProviderOpenAIOfficial(Provider):
|
||||
image_urls: List[str]=None,
|
||||
func_tool: FuncCall=None,
|
||||
contexts=None,
|
||||
system_prompt=None,
|
||||
**kwargs
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
|
||||
context_query = []
|
||||
if not contexts:
|
||||
context_query = [*self.session_memory[session_id], new_record]
|
||||
system_prompt = ""
|
||||
if self.curr_personality["prompt"]:
|
||||
system_prompt = self.curr_personality["prompt"]
|
||||
if self.enable_datetime:
|
||||
system_prompt += f"Current datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
else:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import json
|
||||
import textwrap
|
||||
from typing import Awaitable, Dict, List
|
||||
from typing import Dict, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
|
||||
class FuncCallJsonFormatError(Exception):
|
||||
def __init__(self, msg):
|
||||
@@ -29,8 +29,7 @@ class FuncTool:
|
||||
name: str
|
||||
parameters: Dict
|
||||
description: str
|
||||
func_obj: Awaitable
|
||||
module_name: str = None
|
||||
star_handler_metadata: StarHandlerMetadata
|
||||
|
||||
active: bool = True
|
||||
'''是否激活'''
|
||||
@@ -56,8 +55,7 @@ class FuncCall:
|
||||
name: str,
|
||||
func_args: list,
|
||||
desc: str,
|
||||
func_obj: Awaitable,
|
||||
module_name: str = None,
|
||||
star_handler_metadata: StarHandlerMetadata,
|
||||
) -> None:
|
||||
"""
|
||||
为函数调用(function-calling / tools-use)添加工具。
|
||||
@@ -80,8 +78,7 @@ class FuncCall:
|
||||
name=name,
|
||||
parameters=params,
|
||||
description=desc,
|
||||
func_obj=func_obj,
|
||||
module_name=module_name,
|
||||
star_handler_metadata=star_handler_metadata,
|
||||
)
|
||||
self.func_list.append(_func)
|
||||
|
||||
@@ -179,8 +176,8 @@ class FuncCall:
|
||||
# 调用函数
|
||||
tool_callable = None
|
||||
for func in self.func_list:
|
||||
if func["name"] == func_name:
|
||||
tool_callable = func["func_obj"]
|
||||
if func.name == func_name:
|
||||
tool_callable = func.star_handler_metadata.handler
|
||||
break
|
||||
if not tool_callable:
|
||||
raise FuncNotFoundError(f"Request function {func_name} not found.")
|
||||
|
||||
@@ -10,7 +10,7 @@ from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.manager import ProviderManager
|
||||
from astrbot.core.platform.manager import PlatformManager
|
||||
from .star import star_registry, StarMetadata
|
||||
from .star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
|
||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from .filter.command import CommandFilter
|
||||
from .filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
@@ -69,6 +69,15 @@ class Context:
|
||||
|
||||
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
|
||||
'''
|
||||
md = StarHandlerMetadata(
|
||||
event_type=EventType.OnLLMRequestEvent,
|
||||
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
|
||||
handler_name=func_obj.__name__,
|
||||
handler_module_str=func_obj.__module__,
|
||||
handler=func_obj,
|
||||
event_filters=[],
|
||||
desc=desc
|
||||
)
|
||||
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
|
||||
|
||||
def unregister_llm_tool(self, name: str) -> None:
|
||||
@@ -112,6 +121,7 @@ class Context:
|
||||
|
||||
'''
|
||||
md = StarHandlerMetadata(
|
||||
event_type=EventType.AdapterMessageEvent,
|
||||
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
|
||||
handler_name=awaitable.__name__,
|
||||
handler_module_str=awaitable.__module__,
|
||||
@@ -129,7 +139,6 @@ class Context:
|
||||
handler_md=md
|
||||
))
|
||||
star_handlers_registry.append(md)
|
||||
star_handlers_map[md.handler_full_name] = md
|
||||
|
||||
def register_provider(self, provider: Provider):
|
||||
'''
|
||||
|
||||
@@ -6,8 +6,8 @@ from astrbot.core.config import AstrBotConfig
|
||||
class PermissionType(enum.Flag):
|
||||
'''权限类型。当选择 MEMBER,ADMIN 也可以通过。
|
||||
'''
|
||||
ADMIN = "admin"
|
||||
MEMBER = "member"
|
||||
ADMIN = enum.auto()
|
||||
MEMBER = enum.auto()
|
||||
|
||||
class PermissionTypeFilter(HandlerFilter):
|
||||
def __init__(self, permission_type: PermissionType, raise_error: bool = True):
|
||||
|
||||
@@ -5,7 +5,10 @@ from .star_handler import (
|
||||
register_event_message_type,
|
||||
register_platform_adapter_type,
|
||||
register_regex,
|
||||
register_permission_type
|
||||
register_permission_type,
|
||||
register_on_llm_request,
|
||||
register_llm_tool,
|
||||
register_on_decorating_result
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -15,5 +18,8 @@ __all__ = [
|
||||
'register_event_message_type',
|
||||
'register_platform_adapter_type',
|
||||
'register_regex',
|
||||
'register_permission_type'
|
||||
'register_permission_type',
|
||||
'register_on_llm_request',
|
||||
'register_llm_tool',
|
||||
'register_on_decorating_result'
|
||||
]
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import docstring_parser
|
||||
|
||||
from ..star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
|
||||
from ..star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||
from ..filter.command import CommandFilter
|
||||
from ..filter.command_group import CommandGroupFilter
|
||||
from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType
|
||||
@@ -8,19 +9,23 @@ from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAd
|
||||
from ..filter.permission import PermissionTypeFilter, PermissionType
|
||||
from ..filter.regex import RegexFilter
|
||||
from typing import Awaitable
|
||||
from astrbot.core.provider.tool import SUPPORTED_TYPES
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
def get_handler_full_name(awatable: Awaitable) -> str:
|
||||
def get_handler_full_name(awaitable: Awaitable) -> str:
|
||||
'''获取 Handler 的全名'''
|
||||
return f"{awatable.__module__}_{awatable.__name__}"
|
||||
return f"{awaitable.__module__}_{awaitable.__name__}"
|
||||
|
||||
def get_handler_or_create(handler: Awaitable, dont_add = False) -> StarHandlerMetadata:
|
||||
def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = False) -> StarHandlerMetadata:
|
||||
'''获取 Handler 或者创建一个新的 Handler'''
|
||||
handler_full_name = get_handler_full_name(handler)
|
||||
if handler_full_name in star_handlers_map:
|
||||
return star_handlers_map[handler_full_name]
|
||||
md = star_handlers_registry.get_handler_by_full_name(handler_full_name)
|
||||
if md:
|
||||
return md
|
||||
else:
|
||||
md = StarHandlerMetadata(
|
||||
event_type=event_type,
|
||||
handler_full_name=handler_full_name,
|
||||
handler_name=handler.__name__,
|
||||
handler_module_str=handler.__module__,
|
||||
@@ -29,7 +34,6 @@ def get_handler_or_create(handler: Awaitable, dont_add = False) -> StarHandlerMe
|
||||
)
|
||||
if not dont_add:
|
||||
star_handlers_registry.append(md)
|
||||
star_handlers_map[handler_full_name] = md
|
||||
return md
|
||||
|
||||
def register_command(command_name: str = None, *args):
|
||||
@@ -47,7 +51,7 @@ def register_command(command_name: str = None, *args):
|
||||
add_to_event_filters = True
|
||||
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable)
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
new_command.init_handler_md(handler_md)
|
||||
if add_to_event_filters:
|
||||
# 裸指令
|
||||
@@ -74,7 +78,7 @@ def register_command_group(command_group_name: str = None, *args):
|
||||
def decorator(obj):
|
||||
if add_to_event_filters:
|
||||
# 根指令组
|
||||
handler_md = get_handler_or_create(obj)
|
||||
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent)
|
||||
handler_md.event_filters.append(new_group)
|
||||
|
||||
return RegisteringCommandable(new_group)
|
||||
@@ -91,28 +95,28 @@ class RegisteringCommandable():
|
||||
|
||||
def register_event_message_type(event_message_type: EventMessageType):
|
||||
'''注册一个 EventMessageType'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
|
||||
return awatable
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
|
||||
'''注册一个 PlatformAdapterType'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
handler_md.event_filters.append(PlatformAdapterTypeFilter(platform_adapter_type))
|
||||
return awatable
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_regex(regex: str):
|
||||
'''注册一个 Regex'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
handler_md.event_filters.append(RegexFilter(regex))
|
||||
return awatable
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -123,9 +127,75 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool
|
||||
permission_type: PermissionType
|
||||
raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True
|
||||
'''
|
||||
def decorator(awatable):
|
||||
handler_md = get_handler_or_create(awatable)
|
||||
def decorator(awaitable):
|
||||
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
|
||||
handler_md.event_filters.append(PermissionTypeFilter(permission_type, raise_error))
|
||||
return awatable
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_llm_request():
|
||||
'''当有 LLM 请求时的事件
|
||||
|
||||
Examples:
|
||||
```py
|
||||
@on_llm_request()
|
||||
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
|
||||
request.system_prompt += "你是一个猫娘..."
|
||||
```
|
||||
|
||||
请务必接收两个参数:event, request
|
||||
'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_llm_tool(name: str = None):
|
||||
'''为函数调用(function-calling / tools-use)添加工具。
|
||||
|
||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||
|
||||
```
|
||||
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
|
||||
async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult:
|
||||
\'\'\'获取天气信息。
|
||||
|
||||
Args:
|
||||
location(string): 地点
|
||||
\'\'\'
|
||||
# 处理逻辑
|
||||
```
|
||||
|
||||
可接受的参数类型有:string, number, object, array, boolean。
|
||||
'''
|
||||
name_ = name
|
||||
|
||||
def decorator(awaitable: Awaitable):
|
||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||
docstring = docstring_parser.parse(awaitable.__doc__)
|
||||
args = []
|
||||
for arg in docstring.params:
|
||||
if arg.type_name not in SUPPORTED_TYPES:
|
||||
raise ValueError(f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
|
||||
args.append({
|
||||
"type": arg.type_name,
|
||||
"name": arg.arg_name,
|
||||
"description": arg.description
|
||||
})
|
||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md)
|
||||
|
||||
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_on_decorating_result():
|
||||
'''在发送消息前的事件'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
@@ -1,17 +1,53 @@
|
||||
from __future__ import annotations
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, List, Dict
|
||||
from .filter import HandlerFilter
|
||||
|
||||
star_handlers_registry: List[StarHandlerMetadata] = []
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
'''用于快速查找。key 是 handler_full_name'''
|
||||
class StarHandlerRegistry(List):
|
||||
'''用于存储所有的 Star Handler'''
|
||||
|
||||
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
|
||||
'''用于快速查找。key 是 handler_full_name'''
|
||||
|
||||
def append(self, handler: StarHandlerMetadata):
|
||||
'''添加一个 Handler'''
|
||||
super().append(handler)
|
||||
self.star_handlers_map[handler.handler_full_name] = handler
|
||||
|
||||
def get_handlers_by_event_type(self, event_type: EventType) -> List[StarHandlerMetadata]:
|
||||
'''通过事件类型获取 Handler'''
|
||||
return [handler for handler in self if handler.event_type == event_type]
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
'''通过 Handler 的全名获取 Handler'''
|
||||
return self.star_handlers_map.get(full_name, None)
|
||||
|
||||
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
|
||||
'''通过模块名获取 Handler'''
|
||||
return [handler for handler in self if handler.handler_module_str == module_name]
|
||||
|
||||
|
||||
star_handlers_registry = StarHandlerRegistry()
|
||||
|
||||
class EventType(enum.Enum):
|
||||
'''表示一个 AstrBot 内部事件的类型。如适配器消息事件、LLM 请求事件、发送消息前的事件等
|
||||
|
||||
用于对 Handler 的职能分组。
|
||||
'''
|
||||
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||
OnDecoratingResultEvent = enum.auto() # 发送消息前
|
||||
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
|
||||
|
||||
@dataclass
|
||||
class StarHandlerMetadata():
|
||||
'''描述一个 Star 所注册的某一个 Handler。'''
|
||||
|
||||
event_type: EventType
|
||||
'''Handler 的事件类型'''
|
||||
|
||||
handler_full_name: str
|
||||
'''格式为 f"{handler.__module__}_{handler.__name__}"'''
|
||||
|
||||
@@ -25,7 +61,7 @@ class StarHandlerMetadata():
|
||||
'''Handler 的函数对象,应当是一个异步函数'''
|
||||
|
||||
event_filters: List[HandlerFilter]
|
||||
'''一个事件过滤器,用于描述这个 Handler 能够处理、应该处理的事件'''
|
||||
'''一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件'''
|
||||
|
||||
desc: str = ""
|
||||
'''Handler 的描述信息'''
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import functools
|
||||
import os
|
||||
import traceback
|
||||
import yaml
|
||||
@@ -174,6 +175,13 @@ class PluginManager:
|
||||
star_metadata.module = module
|
||||
star_metadata.root_dir_name = root_dir_name
|
||||
star_metadata.reserved = reserved
|
||||
|
||||
related_handlers = star_handlers_registry.get_handlers_by_module_name(star_metadata.module_path)
|
||||
for handler in related_handlers:
|
||||
logger.debug(f"bind handler {handler.handler_name} to {star_metadata.name}")
|
||||
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
|
||||
handler.handler = functools.partial(handler.handler, star_metadata.star_cls)
|
||||
|
||||
else:
|
||||
# v3.4.0 以前的方式注册插件
|
||||
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
|
||||
|
||||
Reference in New Issue
Block a user