diff --git a/astrbot/api/star/__init__.py b/astrbot/api/star/__init__.py index 630786de3..1b33923fe 100644 --- a/astrbot/api/star/__init__.py +++ b/astrbot/api/star/__init__.py @@ -2,11 +2,7 @@ from astrbot.core.star.register import ( register_star as register, # 注册插件(Star) ) -from astrbot.core.star import Context, Star +from astrbot.core.star import Context, Star, StarTools from astrbot.core.star.config import * -__all__ = [ - "register", - "Context", - "Star", -] +__all__ = ["register", "Context", "Star", "StarTools"] diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 7e2344816..2c4b60d2b 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -98,6 +98,7 @@ DEFAULT_CONFIG = { "plugin_repo_mirror": "", "knowledge_db": {}, "persona": [], + "timezone": "", } @@ -1172,6 +1173,12 @@ CONFIG_METADATA_2 = { "type": "string", "hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`", }, + "timezone": { + "description": "时区", + "type": "string", + "obvious_hint": True, + "hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab", + }, "log_level": { "description": "控制台日志级别", "type": "string", diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index a43f0b32d..86c165945 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -2,6 +2,7 @@ import random import asyncio import math import traceback +import astrbot.core.message.components as Comp from typing import Union, AsyncGenerator from ..stage import register_stage, Stage from ..context import PipelineContext @@ -11,11 +12,42 @@ from astrbot.core import logger from astrbot.core.message.message_event_result import BaseMessageComponent from astrbot.core.star.star_handler import star_handlers_registry, EventType from astrbot.core.star.star import star_map -from astrbot.core.message.components import Plain, Reply, At @register_stage class RespondStage(Stage): + # 组件类型到其非空判断函数的映射 + _component_validators = { + Comp.Plain: lambda comp: bool(comp.text and comp.text.strip()), # 纯文本消息需要strip + Comp.Face: lambda comp: comp.id is not None, # QQ表情 + Comp.Record: lambda comp: bool(comp.file), # 语音 + Comp.Video: lambda comp: bool(comp.file), # 视频 + Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @ + Comp.AtAll: lambda comp: True, # @所有人 + Comp.RPS: lambda comp: True, # 不知道是啥(未完成) + Comp.Dice: lambda comp: True, # 骰子(未完成) + Comp.Shake: lambda comp: True, # 摇一摇(未完成) + Comp.Anonymous: lambda comp: True, # 匿名(未完成) + Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享 + Comp.Contact: lambda comp: True, # 联系人(未完成) + Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置 + Comp.Music: lambda comp: bool(comp._type) and bool(comp.url) and bool(comp.audio), # 音乐 + Comp.Image: lambda comp: bool(comp.file), # 图片 + Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复 + Comp.RedBag: lambda comp: bool(comp.title), # 红包 + Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳 + Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发 + Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点 + Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点 + Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML + Comp.Json: lambda comp: bool(comp.data), # JSON + Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片 + Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成 + Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息 + Comp.File: lambda comp: bool(comp.file), # 文件 + Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情 + } + async def initialize(self, ctx: PipelineContext): self.ctx = ctx @@ -62,7 +94,7 @@ class RespondStage(Stage): async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float: """分段回复 计算间隔时间""" if self.interval_method == "log": - if isinstance(comp, Plain): + if isinstance(comp, Comp.Plain): wc = await self._word_cnt(comp.text) i = math.log(wc + 1, self.log_base) return random.uniform(i, i + 0.5) @@ -72,6 +104,28 @@ class RespondStage(Stage): # random return random.uniform(self.interval[0], self.interval[1]) + async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]): + """检查消息链是否为空 + + Args: + chain (list[BaseMessageComponent]): 包含消息对象的列表 + """ + if not chain: + return True + + for comp in chain: + comp_type = type(comp) + + # 检查组件类型是否在字典中 + if comp_type in self._component_validators: + if self._component_validators[comp_type](comp): + return False + else: + logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}") + + # 如果所有组件都为空 + return True + async def process( self, event: AstrMessageEvent ) -> Union[None, AsyncGenerator[None, None]]: @@ -82,6 +136,16 @@ class RespondStage(Stage): if len(result.chain) > 0: await event._pre_send() + # 检查消息链是否为空 + try: + if await self._is_empty_message_chain(result.chain): + logger.info("消息为空,跳过发送阶段") + event.clear_result() + event.stop_event() + return + except Exception as e: + logger.warning(f"空内容检查异常: {e}") + if self.enable_seg and ( (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result @@ -89,13 +153,13 @@ class RespondStage(Stage): decorated_comps = [] if self.reply_with_mention: for comp in result.chain: - if isinstance(comp, At): + if isinstance(comp, Comp.At): decorated_comps.append(comp) result.chain.remove(comp) break if self.reply_with_quote: for comp in result.chain: - if isinstance(comp, Reply): + if isinstance(comp, Comp.Reply): decorated_comps.append(comp) result.chain.remove(comp) break diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index c316544ff..9f5f7c3c1 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -147,7 +147,7 @@ class ProviderGoogleGenAI(Provider): if message["role"] == "user": if isinstance(message["content"], str): if not message["content"]: - message["content"] = "" + message["content"] = "" google_genai_conversation.append( {"role": "user", "parts": [{"text": message["content"]}]} @@ -158,7 +158,7 @@ class ProviderGoogleGenAI(Provider): for part in message["content"]: if part["type"] == "text": if not part["text"]: - part["text"] = "" + part["text"] = "" parts.append({"text": part["text"]}) elif part["type"] == "image_url": parts.append( @@ -176,7 +176,7 @@ class ProviderGoogleGenAI(Provider): elif message["role"] == "assistant": if "content" in message: if not message["content"]: - message["content"] = "" + message["content"] = "" google_genai_conversation.append( {"role": "model", "parts": [{"text": message["content"]}]} ) diff --git a/astrbot/core/star/__init__.py b/astrbot/core/star/__init__.py index b1bd5de81..ec1ee655b 100644 --- a/astrbot/core/star/__init__.py +++ b/astrbot/core/star/__init__.py @@ -4,12 +4,14 @@ from .context import Context from astrbot.core.provider import Provider from astrbot.core.utils.command_parser import CommandParserMixin from astrbot.core import html_renderer +from astrbot.core.star.star_tools import StarTools class Star(CommandParserMixin): """所有插件(Star)的父类,所有插件都应该继承于这个类""" def __init__(self, context: Context): + StarTools.initialize(context) self.context = context async def text_to_image(self, text: str, return_url=True) -> str: @@ -27,4 +29,4 @@ class Star(CommandParserMixin): pass -__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider"] +__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"] diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py new file mode 100644 index 000000000..68468e353 --- /dev/null +++ b/astrbot/core/star/star_tools.py @@ -0,0 +1,144 @@ +from typing import Union, Awaitable, List, Optional, ClassVar +from astrbot.core.message.components import BaseMessageComponent +from astrbot.core.message.message_event_result import MessageChain +from astrbot.api.platform import MessageMember, AstrBotMessage +from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.star.context import Context + + +class StarTools: + """ + 提供给插件使用的便捷工具函数集合 + 这些方法封装了一些常用操作,使插件开发更加简单便捷! + """ + + _context: ClassVar[Optional[Context]] = None + + @classmethod + def initialize(cls, context: Context) -> None: + """ + 初始化StarTools,设置context引用 + + Args: + context: 暴露给插件的上下文 + """ + cls._context = context + + @classmethod + async def send_message( + cls, session: Union[str, MessageSesion], message_chain: MessageChain + ) -> bool: + """ + 根据session(unified_msg_origin)主动发送消息 + + Args: + session: 消息会话。通过event.session或者event.unified_msg_origin获取 + message_chain: 消息链 + + Returns: + bool: 是否找到匹配的平台 + + Raises: + ValueError: 当session为字符串且解析失败时抛出 + + Note: + qq_official(QQ官方API平台)不支持此方法 + """ + return await cls._context.send_message(session, message_chain) + + @classmethod + async def create_message( + cls, + type: str, + self_id: str, + session_id: str, + message_id: str, + sender: MessageMember, + message: List[BaseMessageComponent], + message_str: str, + raw_message: object, + group_id: str = "", + ): + """ + 创建一个AstrBot消息对象 + + Args: + type (str): 消息类型 + self_id (str): 机器人自身ID + session_id (str): 会话ID(通常为用户ID)(QQ号, 群号等) + message_id (str): 消息ID + sender (MessageMember): 发送者信息 + message (List[BaseMessageComponent]): 消息组件列表 + message_str (str): 消息字符串 + raw_message (object): 原始消息对象 + group_id (str, optional): 群组ID, 如果为私聊则为空. Defaults to "". + + Returns: + AstrBotMessage: 创建的消息对象 + """ + abm = AstrBotMessage() + abm.type = type + abm.self_id = self_id + abm.session_id = session_id + abm.message_id = message_id + abm.sender = sender + abm.message = message + abm.message_str = message_str + abm.raw_message = raw_message + abm.group_id = group_id + return abm + + # todo: 添加构造事件的方法 + # async def create_event( + # self, platform: str, umo: str, sender_id: str, session_id: str + # ): + # platform = self._context.get_platform(platform) + + # todo: 添加找到对应平台并提交对应事件的方法 + + @classmethod + def activate_llm_tool(cls, name: str) -> bool: + """ + 激活一个已经注册的函数调用工具 + 注册的工具默认是激活状态 + + Args: + name (str): 工具名称 + """ + return cls._context.activate_llm_tool(name) + + @classmethod + def deactivate_llm_tool(cls, name: str) -> bool: + """ + 停用一个已经注册的函数调用工具 + + Args: + name (str): 工具名称 + """ + return cls._context.deactivate_llm_tool(name) + + @classmethod + def register_llm_tool( + cls, name: str, func_args: list, desc: str, func_obj: Awaitable + ) -> None: + """ + 为函数调用(function-calling/tools-use)添加工具 + + Args: + name (str): 工具名称 + func_args (list): 函数参数列表 + desc (str): 工具描述 + func_obj (Awaitable): 函数对象,必须是异步函数 + """ + cls._context.register_llm_tool(name, func_args, desc, func_obj) + + @classmethod + def unregister_llm_tool(cls, name: str) -> None: + """ + 删除一个函数调用工具 + 如果再要启用,需要重新注册 + + Args: + name (str): 工具名称 + """ + cls._context.unregister_llm_tool(name) diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 3887fc929..013dcb0f7 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -3,6 +3,7 @@ import datetime import builtins import traceback import re +import zoneinfo import astrbot.api.star as star import astrbot.api.event.filter as filter from astrbot.api.event import AstrMessageEvent, MessageEventResult @@ -22,7 +23,6 @@ from astrbot.core.config.default import VERSION from .long_term_memory import LongTermMemory from astrbot.core import logger from astrbot.api.message_components import Plain, Image, Reply - from typing import Union @@ -39,7 +39,12 @@ class Main(star.Star): self.prompt_prefix = cfg["provider_settings"]["prompt_prefix"] self.identifier = cfg["provider_settings"]["identifier"] self.enable_datetime = cfg["provider_settings"]["datetime_system_prompt"] - + self.timezone = cfg.get("timezone") + if not self.timezone: + # 系统默认时区 + self.timezone = None + else: + logger.info(f"Timezone set to: {self.timezone}") self.ltm = None if ( self.context.get_config()["provider_ltm_settings"]["group_icl_enable"] @@ -969,7 +974,8 @@ UID: {user_id} 此 ID 可用于设置管理员。 if len(l) == 1: message.set_result( MessageEventResult() - .message(f"""[Persona] + .message( + f"""[Persona] - 人格情景列表: `/persona list` - 设置人格情景: `/persona 人格` @@ -980,7 +986,8 @@ UID: {user_id} 此 ID 可用于设置管理员。 当前对话 {curr_cid_title} 的人格情景: {curr_persona_name} 配置人格情景请前往管理面板-配置页 -""") +""" + ) .use_t2i(False) ) elif l[1] == "list": @@ -1190,11 +1197,20 @@ UID: {user_id} 此 ID 可用于设置管理员。 user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n" req.prompt = user_info + req.prompt + # 启用附加时间戳 if self.enable_datetime: - # Including timezone - current_time = ( - datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)") - ) + current_time = None + if self.timezone: + # 启用时区 + try: + now = datetime.datetime.now(zoneinfo.ZoneInfo(self.timezone)) + current_time = now.strftime("%Y-%m-%d %H:%M (%Z)") + except Exception as e: + logger.error(f"时区设置错误: {e}, 使用本地时区") + if not current_time: + current_time = ( + datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)") + ) req.system_prompt += f"\nCurrent datetime: {current_time}\n" if req.conversation: diff --git a/packages/reminder/main.py b/packages/reminder/main.py index e3f6c0a97..6b1b8f3e8 100644 --- a/packages/reminder/main.py +++ b/packages/reminder/main.py @@ -2,6 +2,7 @@ import os import json import datetime import uuid +import zoneinfo import astrbot.api.star as star from astrbot.api.event import filter from apscheduler.schedulers.asyncio import AsyncIOScheduler @@ -17,7 +18,15 @@ class Main(star.Star): def __init__(self, context: star.Context) -> None: self.context = context - self.scheduler = AsyncIOScheduler(timezone="Asia/Shanghai") + self.timezone = self.context.get_config().get("timezone") + if not self.timezone: + self.timezone = None + try: + self.timezone = zoneinfo.ZoneInfo(self.timezone) if self.timezone else None + except Exception as e: + logger.error(f"时区设置错误: {e}, 使用本地时区") + self.timezone = None + self.scheduler = AsyncIOScheduler(timezone=self.timezone) # set and load config if not os.path.exists("data/astrbot-reminder.json"): @@ -65,10 +74,10 @@ class Main(star.Star): def check_is_outdated(self, reminder: dict): """Check if the reminder is outdated.""" if "datetime" in reminder: - return ( - datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M") - < datetime.datetime.now() - ) + reminder_time = datetime.datetime.strptime( + reminder["datetime"], "%Y-%m-%d %H:%M" + ).replace(tzinfo=self.timezone) + return reminder_time < datetime.datetime.now(self.timezone) return False async def _save_data(self): @@ -171,12 +180,15 @@ class Main(star.Star): reminders = self.reminder_data.get(unified_msg_origin, []) if not reminders: return [] - now = datetime.datetime.now() + now = datetime.datetime.now(self.timezone) upcoming_reminders = [ reminder for reminder in reminders if "datetime" not in reminder - or datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M") >= now + or datetime.datetime.strptime( + reminder["datetime"], "%Y-%m-%d %H:%M" + ).replace(tzinfo=self.timezone) + >= now ] return upcoming_reminders