feat: reminder
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import traceback
|
||||
import inspect
|
||||
from typing import Union, AsyncGenerator
|
||||
from ...context import PipelineContext
|
||||
from ..stage import Stage
|
||||
@@ -63,23 +64,46 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
star_cls_obj = star_map.get(func_tool.module_name).star_cls
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
ready_to_call = None
|
||||
if hasattr(func_tool.func_obj, '__self__'):
|
||||
# 猜测没有通过装饰器去注册
|
||||
try:
|
||||
ret = await func_tool.func_obj(event, **func_tool_args)
|
||||
ready_to_call = func_tool.func_obj(event, **func_tool_args)
|
||||
except TypeError:
|
||||
# 向下兼容
|
||||
ret = await func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
|
||||
ready_to_call = func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
|
||||
else:
|
||||
ret = await func_tool.func_obj(star_cls_obj, event, **func_tool_args)
|
||||
|
||||
if ret:
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.stop_event()
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
yield
|
||||
event.clear_result() # 清除上一个 func tool 的结果
|
||||
ready_to_call = func_tool.func_obj(star_cls_obj, event, **func_tool_args)
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
async for mer in ready_to_call:
|
||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
if mer:
|
||||
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(mer)
|
||||
yield
|
||||
else:
|
||||
if event.get_result():
|
||||
yield
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个 coroutine
|
||||
ret = await ready_to_call
|
||||
if ret:
|
||||
# 如果有返回值
|
||||
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
|
||||
event.set_result(ret)
|
||||
# 执行后续步骤来发送消息
|
||||
if event.is_stopped() and event.get_result():
|
||||
# 主动停止事件传播,并且有结果
|
||||
event.continue_event()
|
||||
yield
|
||||
event.clear_result()
|
||||
event.stop_event()
|
||||
yield
|
||||
elif not event.is_stopped and not event.get_result():
|
||||
continue
|
||||
else:
|
||||
yield
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
@@ -196,7 +196,7 @@ class Context:
|
||||
except BaseException as e:
|
||||
raise ValueError("不合法的 session 字符串: " + str(e))
|
||||
|
||||
for platform in self.registered_platforms:
|
||||
for platform in self.platform_manager.platform_insts:
|
||||
if platform.meta().name == session.platform_name:
|
||||
await platform.send_by_session(session, message_chain)
|
||||
return True
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
import os
|
||||
import json
|
||||
import datetime
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import filter
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import llm_tool, logger
|
||||
|
||||
@star.register(name="astrbot-reminder", desc="使用 LLM 待办提醒", author="Soulter", version="0.0.1")
|
||||
class Main(star.Star):
|
||||
'''使用 LLM 待办提醒。只需对 LLM 说想要提醒的事情和时间即可。比如:`之后每天这个时候都提醒我做多邻国`'''
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
|
||||
# set and load config
|
||||
if not os.path.exists("data/astrbot-reminder.json"):
|
||||
with open("data/astrbot-reminder.json", "w") as f:
|
||||
f.write("{}")
|
||||
with open("data/astrbot-reminder.json", "r") as f:
|
||||
self.reminder_data = json.load(f)
|
||||
|
||||
self.scheduler.start()
|
||||
|
||||
async def _init_scheduler(self):
|
||||
'''Initialize the scheduler.'''
|
||||
for group in self.reminder_data:
|
||||
for reminder in self.reminder_data[group]:
|
||||
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"]], id=group, trigger=reminder["cron"])
|
||||
|
||||
async def _save_data(self):
|
||||
'''Save the reminder data.'''
|
||||
with open("data/astrbot-reminder.json", "w") as f:
|
||||
json.dump(self.reminder_data, f, ensure_ascii=False)
|
||||
|
||||
def _parse_cron_expr(self, cron_expr: str):
|
||||
fields = cron_expr.split(" ")
|
||||
return {
|
||||
"minute": fields[0],
|
||||
"hour": fields[1],
|
||||
"day": fields[2],
|
||||
"month": fields[3],
|
||||
"day_of_week": fields[4],
|
||||
}
|
||||
|
||||
@llm_tool("reminder")
|
||||
async def reminder_tool(self, event: AstrMessageEvent, text: str, datetime_str: str = None, cron_expression: str = None, human_readable_cron: str = None):
|
||||
'''Call this function when user ask for setting a reminder.
|
||||
|
||||
Args:
|
||||
text(string): The content of the reminder.
|
||||
datetime_str(string): Required when user's reminder is a single reminder. The datetime string of the reminder, Must format with %Y-%m-%d %H:%M
|
||||
cron_expression(string): Required when user's reminder is a repeated reminder. The cron expression of the reminder.
|
||||
human_readable_cron(string): Optional. The human readable cron expression of the reminder.
|
||||
'''
|
||||
if event.unified_msg_origin not in self.reminder_data:
|
||||
self.reminder_data[event.unified_msg_origin] = []
|
||||
|
||||
if not cron_expression and not datetime_str:
|
||||
raise ValueError("The cron_expression and datetime_str cannot be both None.")
|
||||
reminder_time = ""
|
||||
if cron_expression:
|
||||
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron }
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d], id=event.unified_msg_origin)
|
||||
if human_readable_cron:
|
||||
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
|
||||
else:
|
||||
d = { "text": text, "datetime": datetime_str }
|
||||
self.reminder_data[event.unified_msg_origin].append(d)
|
||||
datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M")
|
||||
self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], id=event.unified_msg_origin, run_date=datetime_scheduled)
|
||||
reminder_time = datetime_str
|
||||
await self._save_data()
|
||||
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。")
|
||||
|
||||
@filter.command_group("reminder")
|
||||
def reminder(self):
|
||||
'''The command group of the reminder.'''
|
||||
pass
|
||||
|
||||
@reminder.command("ls")
|
||||
async def reminder_ls(self, event: AstrMessageEvent):
|
||||
'''List all reminders.'''
|
||||
reminders = self.reminder_data.get(event.unified_msg_origin, [])
|
||||
if not reminders:
|
||||
yield event.plain_result("没有待办事项。")
|
||||
else:
|
||||
reminder_str = "待办事项:\n"
|
||||
for i, reminder in enumerate(reminders):
|
||||
time_ = reminder.get("datetime", "")
|
||||
if not time_:
|
||||
time_ = reminder.get("cron_h", "") + f"(Cron: {reminder.get('cron', "")})"
|
||||
reminder_str += f"{i + 1}. {reminder['text']} - {time_}\n"
|
||||
reminder_str += "\n使用 /reminder rm <index> 删除待办事项。"
|
||||
yield event.plain_result(reminder_str)
|
||||
|
||||
@reminder.command("rm")
|
||||
async def reminder_rm(self, event: AstrMessageEvent, index: int):
|
||||
'''Remove a reminder by index.'''
|
||||
reminders = self.reminder_data.get(event.unified_msg_origin, [])
|
||||
if not reminders:
|
||||
yield event.plain_result("没有待办事项。")
|
||||
elif index < 1 or index > len(reminders):
|
||||
yield event.plain_result("索引越界。")
|
||||
else:
|
||||
reminder = reminders.pop(index - 1)
|
||||
self.scheduler.remove_job(event.unified_msg_origin)
|
||||
await self._save_data()
|
||||
yield event.plain_result("成功删除待办事项:\n" + reminder["text"])
|
||||
|
||||
async def _reminder_callback(self, unified_msg_origin: str, d: dict):
|
||||
'''The callback function of the reminder.'''
|
||||
logger.info(f"Reminder Activated: {d['text']}, created by {unified_msg_origin}")
|
||||
await self.context.send_message(unified_msg_origin, MessageEventResult().message("待办提醒: \n\n" + d['text'] + "\n时间: " + d.get("datetime", "") + d.get("cron_h", "")))
|
||||
Reference in New Issue
Block a user