feat: 添加发送消息后的事件钩子
This commit is contained in:
@@ -7,7 +7,8 @@ from astrbot.core.star.register import (
|
||||
register_permission_type as permission_type,
|
||||
register_on_llm_request as on_llm_request,
|
||||
register_llm_tool as llm_tool,
|
||||
register_on_decorating_result as on_decorating_result
|
||||
register_on_decorating_result as on_decorating_result,
|
||||
register_after_message_sent as after_message_sent
|
||||
)
|
||||
|
||||
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
|
||||
@@ -29,5 +30,6 @@ __all__ = [
|
||||
'PermissionType',
|
||||
'on_llm_request',
|
||||
'llm_tool',
|
||||
'on_decorating_result'
|
||||
'on_decorating_result',
|
||||
'after_message_sent'
|
||||
]
|
||||
@@ -3,6 +3,7 @@ from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
@register_stage
|
||||
class RespondStage(Stage):
|
||||
@@ -17,4 +18,8 @@ class RespondStage(Stage):
|
||||
if len(result.chain) > 0:
|
||||
await event.send(result)
|
||||
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
|
||||
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
|
||||
for handler in handlers:
|
||||
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
|
||||
await handler.handler(event)
|
||||
@@ -35,16 +35,22 @@ class ProviderOpenAIOfficial(Provider):
|
||||
timeout=provider_config.get("timeout", NOT_GIVEN),
|
||||
)
|
||||
self.set_model(provider_config['model_config']['model'])
|
||||
|
||||
|
||||
async def get_human_readable_context(self, session_id, page, page_size):
|
||||
if session_id not in self.session_memory:
|
||||
raise Exception("会话 ID 不存在")
|
||||
contexts = []
|
||||
temp_contexts = []
|
||||
for record in self.session_memory[session_id]:
|
||||
if record['role'] == "user":
|
||||
contexts.append(f"User: {record['content']}")
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record['role'] == "assistant":
|
||||
contexts.append(f"Assistant: {record['content']}")
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
# 展平 contexts 列表
|
||||
contexts = [item for sublist in contexts for item in sublist]
|
||||
|
||||
# 计算分页
|
||||
paged_contexts = contexts[(page-1)*page_size:page*page_size]
|
||||
|
||||
@@ -8,7 +8,8 @@ from .star_handler import (
|
||||
register_permission_type,
|
||||
register_on_llm_request,
|
||||
register_llm_tool,
|
||||
register_on_decorating_result
|
||||
register_on_decorating_result,
|
||||
register_after_message_sent
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -21,5 +22,6 @@ __all__ = [
|
||||
'register_permission_type',
|
||||
'register_on_llm_request',
|
||||
'register_llm_tool',
|
||||
'register_on_decorating_result'
|
||||
'register_on_decorating_result',
|
||||
'register_after_message_sent'
|
||||
]
|
||||
@@ -198,4 +198,12 @@ def register_on_decorating_result():
|
||||
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
|
||||
def register_after_message_sent():
|
||||
'''在消息发送后的事件'''
|
||||
def decorator(awaitable):
|
||||
_ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent)
|
||||
return awaitable
|
||||
|
||||
return decorator
|
||||
@@ -40,6 +40,7 @@ class EventType(enum.Enum):
|
||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||
OnDecoratingResultEvent = enum.auto() # 发送消息前
|
||||
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
|
||||
OnAfterMessageSentEvent = enum.auto() # 发送消息后
|
||||
|
||||
@dataclass
|
||||
class StarHandlerMetadata():
|
||||
|
||||
@@ -20,14 +20,18 @@ class Main(star.Star):
|
||||
f.write("{}")
|
||||
with open("data/astrbot-reminder.json", "r") as f:
|
||||
self.reminder_data = json.load(f)
|
||||
|
||||
|
||||
self._init_scheduler()
|
||||
self.scheduler.start()
|
||||
|
||||
|
||||
async def _init_scheduler(self):
|
||||
'''Initialize the scheduler.'''
|
||||
for group in self.reminder_data:
|
||||
for reminder in self.reminder_data[group]:
|
||||
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"]], id=group, trigger=reminder["cron"])
|
||||
if "datetime" in reminder:
|
||||
self.scheduler.add_job(self._reminder_callback, 'date', args=[reminder["text"]], id=group, run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"))
|
||||
elif "cron" in reminder:
|
||||
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"]], id=group, trigger=reminder["cron"])
|
||||
|
||||
async def _save_data(self):
|
||||
'''Save the reminder data.'''
|
||||
|
||||
Reference in New Issue
Block a user