Merge pull request #1128 from anka-afk/anka-dev

feature: 实现了 #1127 还有 #1133 还有 #1143
This commit is contained in:
Soulter
2025-04-06 22:22:01 +08:00
committed by GitHub
8 changed files with 270 additions and 29 deletions
+2 -6
View File
@@ -2,11 +2,7 @@ from astrbot.core.star.register import (
register_star as register, # 注册插件(Star
)
from astrbot.core.star import Context, Star
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *
__all__ = [
"register",
"Context",
"Star",
]
__all__ = ["register", "Context", "Star", "StarTools"]
+7
View File
@@ -98,6 +98,7 @@ DEFAULT_CONFIG = {
"plugin_repo_mirror": "",
"knowledge_db": {},
"persona": [],
"timezone": "",
}
@@ -1172,6 +1173,12 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
},
"timezone": {
"description": "时区",
"type": "string",
"obvious_hint": True,
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
},
"log_level": {
"description": "控制台日志级别",
"type": "string",
+68 -4
View File
@@ -2,6 +2,7 @@ import random
import asyncio
import math
import traceback
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
@@ -11,11 +12,42 @@ from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.components import Plain, Reply, At
@register_stage
class RespondStage(Stage):
# 组件类型到其非空判断函数的映射
_component_validators = {
Comp.Plain: lambda comp: bool(comp.text and comp.text.strip()), # 纯文本消息需要strip
Comp.Face: lambda comp: comp.id is not None, # QQ表情
Comp.Record: lambda comp: bool(comp.file), # 语音
Comp.Video: lambda comp: bool(comp.file), # 视频
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
Comp.AtAll: lambda comp: True, # @所有人
Comp.RPS: lambda comp: True, # 不知道是啥(未完成)
Comp.Dice: lambda comp: True, # 骰子(未完成)
Comp.Shake: lambda comp: True, # 摇一摇(未完成)
Comp.Anonymous: lambda comp: True, # 匿名(未完成)
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
Comp.Contact: lambda comp: True, # 联系人(未完成)
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
Comp.Music: lambda comp: bool(comp._type) and bool(comp.url) and bool(comp.audio), # 音乐
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.RedBag: lambda comp: bool(comp.title), # 红包
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
Comp.Json: lambda comp: bool(comp.data), # JSON
Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片
Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成
Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息
Comp.File: lambda comp: bool(comp.file), # 文件
Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情
}
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
@@ -62,7 +94,7 @@ class RespondStage(Stage):
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
"""分段回复 计算间隔时间"""
if self.interval_method == "log":
if isinstance(comp, Plain):
if isinstance(comp, Comp.Plain):
wc = await self._word_cnt(comp.text)
i = math.log(wc + 1, self.log_base)
return random.uniform(i, i + 0.5)
@@ -72,6 +104,28 @@ class RespondStage(Stage):
# random
return random.uniform(self.interval[0], self.interval[1])
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
"""检查消息链是否为空
Args:
chain (list[BaseMessageComponent]): 包含消息对象的列表
"""
if not chain:
return True
for comp in chain:
comp_type = type(comp)
# 检查组件类型是否在字典中
if comp_type in self._component_validators:
if self._component_validators[comp_type](comp):
return False
else:
logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}")
# 如果所有组件都为空
return True
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
@@ -82,6 +136,16 @@ class RespondStage(Stage):
if len(result.chain) > 0:
await event._pre_send()
# 检查消息链是否为空
try:
if await self._is_empty_message_chain(result.chain):
logger.info("消息为空,跳过发送阶段")
event.clear_result()
event.stop_event()
return
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
if self.enable_seg and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
@@ -89,13 +153,13 @@ class RespondStage(Stage):
decorated_comps = []
if self.reply_with_mention:
for comp in result.chain:
if isinstance(comp, At):
if isinstance(comp, Comp.At):
decorated_comps.append(comp)
result.chain.remove(comp)
break
if self.reply_with_quote:
for comp in result.chain:
if isinstance(comp, Reply):
if isinstance(comp, Comp.Reply):
decorated_comps.append(comp)
result.chain.remove(comp)
break
@@ -147,7 +147,7 @@ class ProviderGoogleGenAI(Provider):
if message["role"] == "user":
if isinstance(message["content"], str):
if not message["content"]:
message["content"] = "<empty_content>"
message["content"] = ""
google_genai_conversation.append(
{"role": "user", "parts": [{"text": message["content"]}]}
@@ -158,7 +158,7 @@ class ProviderGoogleGenAI(Provider):
for part in message["content"]:
if part["type"] == "text":
if not part["text"]:
part["text"] = "<empty_content>"
part["text"] = ""
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
parts.append(
@@ -176,7 +176,7 @@ class ProviderGoogleGenAI(Provider):
elif message["role"] == "assistant":
if "content" in message:
if not message["content"]:
message["content"] = "<empty_content>"
message["content"] = ""
google_genai_conversation.append(
{"role": "model", "parts": [{"text": message["content"]}]}
)
+3 -1
View File
@@ -4,12 +4,14 @@ from .context import Context
from astrbot.core.provider import Provider
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core import html_renderer
from astrbot.core.star.star_tools import StarTools
class Star(CommandParserMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
def __init__(self, context: Context):
StarTools.initialize(context)
self.context = context
async def text_to_image(self, text: str, return_url=True) -> str:
@@ -27,4 +29,4 @@ class Star(CommandParserMixin):
pass
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider"]
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
+144
View File
@@ -0,0 +1,144 @@
from typing import Union, Awaitable, List, Optional, ClassVar
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.platform import MessageMember, AstrBotMessage
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.star.context import Context
class StarTools:
"""
提供给插件使用的便捷工具函数集合
这些方法封装了一些常用操作,使插件开发更加简单便捷!
"""
_context: ClassVar[Optional[Context]] = None
@classmethod
def initialize(cls, context: Context) -> None:
"""
初始化StarTools,设置context引用
Args:
context: 暴露给插件的上下文
"""
cls._context = context
@classmethod
async def send_message(
cls, session: Union[str, MessageSesion], message_chain: MessageChain
) -> bool:
"""
根据session(unified_msg_origin)主动发送消息
Args:
session: 消息会话。通过event.session或者event.unified_msg_origin获取
message_chain: 消息链
Returns:
bool: 是否找到匹配的平台
Raises:
ValueError: 当session为字符串且解析失败时抛出
Note:
qq_official(QQ官方API平台)不支持此方法
"""
return await cls._context.send_message(session, message_chain)
@classmethod
async def create_message(
cls,
type: str,
self_id: str,
session_id: str,
message_id: str,
sender: MessageMember,
message: List[BaseMessageComponent],
message_str: str,
raw_message: object,
group_id: str = "",
):
"""
创建一个AstrBot消息对象
Args:
type (str): 消息类型
self_id (str): 机器人自身ID
session_id (str): 会话ID(通常为用户ID)(QQ号, 群号等)
message_id (str): 消息ID
sender (MessageMember): 发送者信息
message (List[BaseMessageComponent]): 消息组件列表
message_str (str): 消息字符串
raw_message (object): 原始消息对象
group_id (str, optional): 群组ID, 如果为私聊则为空. Defaults to "".
Returns:
AstrBotMessage: 创建的消息对象
"""
abm = AstrBotMessage()
abm.type = type
abm.self_id = self_id
abm.session_id = session_id
abm.message_id = message_id
abm.sender = sender
abm.message = message
abm.message_str = message_str
abm.raw_message = raw_message
abm.group_id = group_id
return abm
# todo: 添加构造事件的方法
# async def create_event(
# self, platform: str, umo: str, sender_id: str, session_id: str
# ):
# platform = self._context.get_platform(platform)
# todo: 添加找到对应平台并提交对应事件的方法
@classmethod
def activate_llm_tool(cls, name: str) -> bool:
"""
激活一个已经注册的函数调用工具
注册的工具默认是激活状态
Args:
name (str): 工具名称
"""
return cls._context.activate_llm_tool(name)
@classmethod
def deactivate_llm_tool(cls, name: str) -> bool:
"""
停用一个已经注册的函数调用工具
Args:
name (str): 工具名称
"""
return cls._context.deactivate_llm_tool(name)
@classmethod
def register_llm_tool(
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
) -> None:
"""
为函数调用(function-calling/tools-use)添加工具
Args:
name (str): 工具名称
func_args (list): 函数参数列表
desc (str): 工具描述
func_obj (Awaitable): 函数对象,必须是异步函数
"""
cls._context.register_llm_tool(name, func_args, desc, func_obj)
@classmethod
def unregister_llm_tool(cls, name: str) -> None:
"""
删除一个函数调用工具
如果再要启用,需要重新注册
Args:
name (str): 工具名称
"""
cls._context.unregister_llm_tool(name)
+24 -8
View File
@@ -3,6 +3,7 @@ import datetime
import builtins
import traceback
import re
import zoneinfo
import astrbot.api.star as star
import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
@@ -22,7 +23,6 @@ from astrbot.core.config.default import VERSION
from .long_term_memory import LongTermMemory
from astrbot.core import logger
from astrbot.api.message_components import Plain, Image, Reply
from typing import Union
@@ -39,7 +39,12 @@ class Main(star.Star):
self.prompt_prefix = cfg["provider_settings"]["prompt_prefix"]
self.identifier = cfg["provider_settings"]["identifier"]
self.enable_datetime = cfg["provider_settings"]["datetime_system_prompt"]
self.timezone = cfg.get("timezone")
if not self.timezone:
# 系统默认时区
self.timezone = None
else:
logger.info(f"Timezone set to: {self.timezone}")
self.ltm = None
if (
self.context.get_config()["provider_ltm_settings"]["group_icl_enable"]
@@ -969,7 +974,8 @@ UID: {user_id} 此 ID 可用于设置管理员。
if len(l) == 1:
message.set_result(
MessageEventResult()
.message(f"""[Persona]
.message(
f"""[Persona]
- 人格情景列表: `/persona list`
- 设置人格情景: `/persona 人格`
@@ -980,7 +986,8 @@ UID: {user_id} 此 ID 可用于设置管理员。
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
配置人格情景请前往管理面板-配置页
""")
"""
)
.use_t2i(False)
)
elif l[1] == "list":
@@ -1190,11 +1197,20 @@ UID: {user_id} 此 ID 可用于设置管理员。
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
req.prompt = user_info + req.prompt
# 启用附加时间戳
if self.enable_datetime:
# Including timezone
current_time = (
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
)
current_time = None
if self.timezone:
# 启用时区
try:
now = datetime.datetime.now(zoneinfo.ZoneInfo(self.timezone))
current_time = now.strftime("%Y-%m-%d %H:%M (%Z)")
except Exception as e:
logger.error(f"时区设置错误: {e}, 使用本地时区")
if not current_time:
current_time = (
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
)
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
if req.conversation:
+19 -7
View File
@@ -2,6 +2,7 @@ import os
import json
import datetime
import uuid
import zoneinfo
import astrbot.api.star as star
from astrbot.api.event import filter
from apscheduler.schedulers.asyncio import AsyncIOScheduler
@@ -17,7 +18,15 @@ class Main(star.Star):
def __init__(self, context: star.Context) -> None:
self.context = context
self.scheduler = AsyncIOScheduler(timezone="Asia/Shanghai")
self.timezone = self.context.get_config().get("timezone")
if not self.timezone:
self.timezone = None
try:
self.timezone = zoneinfo.ZoneInfo(self.timezone) if self.timezone else None
except Exception as e:
logger.error(f"时区设置错误: {e}, 使用本地时区")
self.timezone = None
self.scheduler = AsyncIOScheduler(timezone=self.timezone)
# set and load config
if not os.path.exists("data/astrbot-reminder.json"):
@@ -65,10 +74,10 @@ class Main(star.Star):
def check_is_outdated(self, reminder: dict):
"""Check if the reminder is outdated."""
if "datetime" in reminder:
return (
datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M")
< datetime.datetime.now()
)
reminder_time = datetime.datetime.strptime(
reminder["datetime"], "%Y-%m-%d %H:%M"
).replace(tzinfo=self.timezone)
return reminder_time < datetime.datetime.now(self.timezone)
return False
async def _save_data(self):
@@ -171,12 +180,15 @@ class Main(star.Star):
reminders = self.reminder_data.get(unified_msg_origin, [])
if not reminders:
return []
now = datetime.datetime.now()
now = datetime.datetime.now(self.timezone)
upcoming_reminders = [
reminder
for reminder in reminders
if "datetime" not in reminder
or datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M") >= now
or datetime.datetime.strptime(
reminder["datetime"], "%Y-%m-%d %H:%M"
).replace(tzinfo=self.timezone)
>= now
]
return upcoming_reminders