From acb3af8ab8f7bbe3816cc2c8aec741dd00b04952 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 16 Dec 2024 20:02:50 +0800 Subject: [PATCH] feat: reminder --- .../process_stage/method/llm_request.py | 46 +++++-- astrbot/core/star/context.py | 2 +- packages/reminder/main.py | 116 ++++++++++++++++++ 3 files changed, 152 insertions(+), 12 deletions(-) create mode 100644 packages/reminder/main.py diff --git a/astrbot/core/pipeline/process_stage/method/llm_request.py b/astrbot/core/pipeline/process_stage/method/llm_request.py index fadd7b925..cfa0acc58 100644 --- a/astrbot/core/pipeline/process_stage/method/llm_request.py +++ b/astrbot/core/pipeline/process_stage/method/llm_request.py @@ -1,4 +1,5 @@ import traceback +import inspect from typing import Union, AsyncGenerator from ...context import PipelineContext from ..stage import Stage @@ -63,23 +64,46 @@ class LLMRequestSubStage(Stage): star_cls_obj = star_map.get(func_tool.module_name).star_cls # 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性) + ready_to_call = None if hasattr(func_tool.func_obj, '__self__'): # 猜测没有通过装饰器去注册 try: - ret = await func_tool.func_obj(event, **func_tool_args) + ready_to_call = func_tool.func_obj(event, **func_tool_args) except TypeError: # 向下兼容 - ret = await func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args) + ready_to_call = func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args) else: - ret = await func_tool.func_obj(star_cls_obj, event, **func_tool_args) - - if ret: - assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" - event.stop_event() - event.set_result(ret) - # 执行后续步骤来发送消息 - yield - event.clear_result() # 清除上一个 func tool 的结果 + ready_to_call = func_tool.func_obj(star_cls_obj, event, **func_tool_args) + if isinstance(ready_to_call, AsyncGenerator): + async for mer in ready_to_call: + # 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值) + if mer: + assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" + event.set_result(mer) + yield + else: + if event.get_result(): + yield + elif inspect.iscoroutine(ready_to_call): + # 如果只是一个 coroutine + ret = await ready_to_call + if ret: + # 如果有返回值 + assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。" + event.set_result(ret) + # 执行后续步骤来发送消息 + if event.is_stopped() and event.get_result(): + # 主动停止事件传播,并且有结果 + event.continue_event() + yield + event.clear_result() + event.stop_event() + yield + elif not event.is_stopped and not event.get_result(): + continue + else: + yield + event.clear_result() # 清除上一个 handler 的结果 except BaseException: logger.error(traceback.format_exc()) diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 0305bd041..b828f597f 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -196,7 +196,7 @@ class Context: except BaseException as e: raise ValueError("不合法的 session 字符串: " + str(e)) - for platform in self.registered_platforms: + for platform in self.platform_manager.platform_insts: if platform.meta().name == session.platform_name: await platform.send_by_session(session, message_chain) return True diff --git a/packages/reminder/main.py b/packages/reminder/main.py new file mode 100644 index 000000000..df896ec4f --- /dev/null +++ b/packages/reminder/main.py @@ -0,0 +1,116 @@ +import os +import json +import datetime +import astrbot.api.star as star +from astrbot.api.event import filter +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from astrbot.api.event import AstrMessageEvent, MessageEventResult +from astrbot.api import llm_tool, logger + +@star.register(name="astrbot-reminder", desc="使用 LLM 待办提醒", author="Soulter", version="0.0.1") +class Main(star.Star): + '''使用 LLM 待办提醒。只需对 LLM 说想要提醒的事情和时间即可。比如:`之后每天这个时候都提醒我做多邻国`''' + def __init__(self, context: star.Context) -> None: + self.context = context + self.scheduler = AsyncIOScheduler() + + # set and load config + if not os.path.exists("data/astrbot-reminder.json"): + with open("data/astrbot-reminder.json", "w") as f: + f.write("{}") + with open("data/astrbot-reminder.json", "r") as f: + self.reminder_data = json.load(f) + + self.scheduler.start() + + async def _init_scheduler(self): + '''Initialize the scheduler.''' + for group in self.reminder_data: + for reminder in self.reminder_data[group]: + self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"]], id=group, trigger=reminder["cron"]) + + async def _save_data(self): + '''Save the reminder data.''' + with open("data/astrbot-reminder.json", "w") as f: + json.dump(self.reminder_data, f, ensure_ascii=False) + + def _parse_cron_expr(self, cron_expr: str): + fields = cron_expr.split(" ") + return { + "minute": fields[0], + "hour": fields[1], + "day": fields[2], + "month": fields[3], + "day_of_week": fields[4], + } + + @llm_tool("reminder") + async def reminder_tool(self, event: AstrMessageEvent, text: str, datetime_str: str = None, cron_expression: str = None, human_readable_cron: str = None): + '''Call this function when user ask for setting a reminder. + + Args: + text(string): The content of the reminder. + datetime_str(string): Required when user's reminder is a single reminder. The datetime string of the reminder, Must format with %Y-%m-%d %H:%M + cron_expression(string): Required when user's reminder is a repeated reminder. The cron expression of the reminder. + human_readable_cron(string): Optional. The human readable cron expression of the reminder. + ''' + if event.unified_msg_origin not in self.reminder_data: + self.reminder_data[event.unified_msg_origin] = [] + + if not cron_expression and not datetime_str: + raise ValueError("The cron_expression and datetime_str cannot be both None.") + reminder_time = "" + if cron_expression: + d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron } + self.reminder_data[event.unified_msg_origin].append(d) + self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d], id=event.unified_msg_origin) + if human_readable_cron: + reminder_time = f"{human_readable_cron}(Cron: {cron_expression})" + else: + d = { "text": text, "datetime": datetime_str } + self.reminder_data[event.unified_msg_origin].append(d) + datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M") + self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], id=event.unified_msg_origin, run_date=datetime_scheduled) + reminder_time = datetime_str + await self._save_data() + yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。") + + @filter.command_group("reminder") + def reminder(self): + '''The command group of the reminder.''' + pass + + @reminder.command("ls") + async def reminder_ls(self, event: AstrMessageEvent): + '''List all reminders.''' + reminders = self.reminder_data.get(event.unified_msg_origin, []) + if not reminders: + yield event.plain_result("没有待办事项。") + else: + reminder_str = "待办事项:\n" + for i, reminder in enumerate(reminders): + time_ = reminder.get("datetime", "") + if not time_: + time_ = reminder.get("cron_h", "") + f"(Cron: {reminder.get('cron', "")})" + reminder_str += f"{i + 1}. {reminder['text']} - {time_}\n" + reminder_str += "\n使用 /reminder rm 删除待办事项。" + yield event.plain_result(reminder_str) + + @reminder.command("rm") + async def reminder_rm(self, event: AstrMessageEvent, index: int): + '''Remove a reminder by index.''' + reminders = self.reminder_data.get(event.unified_msg_origin, []) + if not reminders: + yield event.plain_result("没有待办事项。") + elif index < 1 or index > len(reminders): + yield event.plain_result("索引越界。") + else: + reminder = reminders.pop(index - 1) + self.scheduler.remove_job(event.unified_msg_origin) + await self._save_data() + yield event.plain_result("成功删除待办事项:\n" + reminder["text"]) + + async def _reminder_callback(self, unified_msg_origin: str, d: dict): + '''The callback function of the reminder.''' + logger.info(f"Reminder Activated: {d['text']}, created by {unified_msg_origin}") + await self.context.send_message(unified_msg_origin, MessageEventResult().message("待办提醒: \n\n" + d['text'] + "\n时间: " + d.get("datetime", "") + d.get("cron_h", ""))) \ No newline at end of file