feat: reminder

This commit is contained in:
Soulter
2024-12-16 20:02:50 +08:00
parent 9c50889371
commit acb3af8ab8
3 changed files with 152 additions and 12 deletions
@@ -1,4 +1,5 @@
import traceback
import inspect
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
@@ -63,23 +64,46 @@ class LLMRequestSubStage(Stage):
star_cls_obj = star_map.get(func_tool.module_name).star_cls
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
if hasattr(func_tool.func_obj, '__self__'):
# 猜测没有通过装饰器去注册
try:
ret = await func_tool.func_obj(event, **func_tool_args)
ready_to_call = func_tool.func_obj(event, **func_tool_args)
except TypeError:
# 向下兼容
ret = await func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
ready_to_call = func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
else:
ret = await func_tool.func_obj(star_cls_obj, event, **func_tool_args)
if ret:
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.stop_event()
event.set_result(ret)
# 执行后续步骤来发送消息
yield
event.clear_result() # 清除上一个 func tool 的结果
ready_to_call = func_tool.func_obj(star_cls_obj, event, **func_tool_args)
if isinstance(ready_to_call, AsyncGenerator):
async for mer in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
if mer:
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(mer)
yield
else:
if event.get_result():
yield
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if ret:
# 如果有返回值
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(ret)
# 执行后续步骤来发送消息
if event.is_stopped() and event.get_result():
# 主动停止事件传播,并且有结果
event.continue_event()
yield
event.clear_result()
event.stop_event()
yield
elif not event.is_stopped and not event.get_result():
continue
else:
yield
event.clear_result() # 清除上一个 handler 的结果
except BaseException:
logger.error(traceback.format_exc())
+1 -1
View File
@@ -196,7 +196,7 @@ class Context:
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.registered_platforms:
for platform in self.platform_manager.platform_insts:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
+116
View File
@@ -0,0 +1,116 @@
import os
import json
import datetime
import astrbot.api.star as star
from astrbot.api.event import filter
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import llm_tool, logger
@star.register(name="astrbot-reminder", desc="使用 LLM 待办提醒", author="Soulter", version="0.0.1")
class Main(star.Star):
'''使用 LLM 待办提醒。只需对 LLM 说想要提醒的事情和时间即可。比如:`之后每天这个时候都提醒我做多邻国`'''
def __init__(self, context: star.Context) -> None:
self.context = context
self.scheduler = AsyncIOScheduler()
# set and load config
if not os.path.exists("data/astrbot-reminder.json"):
with open("data/astrbot-reminder.json", "w") as f:
f.write("{}")
with open("data/astrbot-reminder.json", "r") as f:
self.reminder_data = json.load(f)
self.scheduler.start()
async def _init_scheduler(self):
'''Initialize the scheduler.'''
for group in self.reminder_data:
for reminder in self.reminder_data[group]:
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"]], id=group, trigger=reminder["cron"])
async def _save_data(self):
'''Save the reminder data.'''
with open("data/astrbot-reminder.json", "w") as f:
json.dump(self.reminder_data, f, ensure_ascii=False)
def _parse_cron_expr(self, cron_expr: str):
fields = cron_expr.split(" ")
return {
"minute": fields[0],
"hour": fields[1],
"day": fields[2],
"month": fields[3],
"day_of_week": fields[4],
}
@llm_tool("reminder")
async def reminder_tool(self, event: AstrMessageEvent, text: str, datetime_str: str = None, cron_expression: str = None, human_readable_cron: str = None):
'''Call this function when user ask for setting a reminder.
Args:
text(string): The content of the reminder.
datetime_str(string): Required when user's reminder is a single reminder. The datetime string of the reminder, Must format with %Y-%m-%d %H:%M
cron_expression(string): Required when user's reminder is a repeated reminder. The cron expression of the reminder.
human_readable_cron(string): Optional. The human readable cron expression of the reminder.
'''
if event.unified_msg_origin not in self.reminder_data:
self.reminder_data[event.unified_msg_origin] = []
if not cron_expression and not datetime_str:
raise ValueError("The cron_expression and datetime_str cannot be both None.")
reminder_time = ""
if cron_expression:
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron }
self.reminder_data[event.unified_msg_origin].append(d)
self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d], id=event.unified_msg_origin)
if human_readable_cron:
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
else:
d = { "text": text, "datetime": datetime_str }
self.reminder_data[event.unified_msg_origin].append(d)
datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M")
self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], id=event.unified_msg_origin, run_date=datetime_scheduled)
reminder_time = datetime_str
await self._save_data()
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。")
@filter.command_group("reminder")
def reminder(self):
'''The command group of the reminder.'''
pass
@reminder.command("ls")
async def reminder_ls(self, event: AstrMessageEvent):
'''List all reminders.'''
reminders = self.reminder_data.get(event.unified_msg_origin, [])
if not reminders:
yield event.plain_result("没有待办事项。")
else:
reminder_str = "待办事项:\n"
for i, reminder in enumerate(reminders):
time_ = reminder.get("datetime", "")
if not time_:
time_ = reminder.get("cron_h", "") + f"(Cron: {reminder.get('cron', "")})"
reminder_str += f"{i + 1}. {reminder['text']} - {time_}\n"
reminder_str += "\n使用 /reminder rm <index> 删除待办事项。"
yield event.plain_result(reminder_str)
@reminder.command("rm")
async def reminder_rm(self, event: AstrMessageEvent, index: int):
'''Remove a reminder by index.'''
reminders = self.reminder_data.get(event.unified_msg_origin, [])
if not reminders:
yield event.plain_result("没有待办事项。")
elif index < 1 or index > len(reminders):
yield event.plain_result("索引越界。")
else:
reminder = reminders.pop(index - 1)
self.scheduler.remove_job(event.unified_msg_origin)
await self._save_data()
yield event.plain_result("成功删除待办事项:\n" + reminder["text"])
async def _reminder_callback(self, unified_msg_origin: str, d: dict):
'''The callback function of the reminder.'''
logger.info(f"Reminder Activated: {d['text']}, created by {unified_msg_origin}")
await self.context.send_message(unified_msg_origin, MessageEventResult().message("待办提醒: \n\n" + d['text'] + "\n时间: " + d.get("datetime", "") + d.get("cron_h", "")))