feat: 适配多节点的转发消息(OneBot V11)

This commit is contained in:
Soulter
2025-02-22 21:07:57 +08:00
parent 2188ea82de
commit b199bddb0b
12 changed files with 54 additions and 30 deletions
+1 -1
View File
@@ -58,7 +58,7 @@ class LogManager:
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_formatter = colorlog.ColoredFormatter(
fmt='%(log_color)s [%(asctime)s| %(levelname)s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s',
fmt='%(log_color)s [%(asctime)s] [%(levelname)-5s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s',
datefmt='%H:%M:%S',
log_colors=log_color_config
)
+26 -11
View File
@@ -30,11 +30,19 @@ from enum import Enum
from pydantic.v1 import BaseModel
class ComponentType(Enum):
Plain = "Plain"
Face = "Face"
Record = "Record"
Video = "Video"
At = "At"
Plain = "Plain" # 纯文本消息
Face = "Face" # QQ表情
Record = "Record" # 语音
Video = "Video" # 视频
At = "At" # At
Node = "Node" # 转发消息的一个节点
Nodes = "Nodes" # 转发消息的多个节点
Poke = "Poke" # QQ 戳一戳
Image = "Image" # 图片
Reply = "Reply" # 回复
Forward = "Forward" # 转发消息
File = "File" # 文件
RPS = "RPS" # TODO
Dice = "Dice" # TODO
Shake = "Shake" # TODO
@@ -43,18 +51,12 @@ class ComponentType(Enum):
Contact = "Contact" # TODO
Location = "Location" # TODO
Music = "Music"
Image = "Image"
Reply = "Reply"
RedBag = "RedBag"
Poke = "Poke"
Forward = "Forward"
Node = "Node"
Xml = "Xml"
Json = "Json"
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
File = "File"
class BaseMessageComponent(BaseModel):
@@ -362,6 +364,18 @@ class Node(BaseMessageComponent):
def toString(self):
# logger.warn("Protocol: node doesn't support stringify")
return ""
class Nodes(BaseMessageComponent):
type: ComponentType = "Nodes"
nodes: T.List[Node]
def __init__(self, nodes: T.List[Node], **_):
super().__init__(nodes=nodes, **_)
def toDict(self):
return {
"messages": [node.toDict() for node in self.nodes]
}
class Xml(BaseMessageComponent):
@@ -451,6 +465,7 @@ ComponentTypes = {
"poke": Poke,
"forward": Forward,
"node": Node,
"nodes": Nodes,
"xml": Xml,
"json": Json,
"cardimage": CardImage,
@@ -28,4 +28,3 @@ class ContentSafetyCheckStage(Stage):
event.stop_event()
logger.info(f"内容安全检查不通过,原因:{info}")
return
event.continue_event()
@@ -13,6 +13,7 @@ from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
class LLMRequestSubStage(Stage):
@@ -69,6 +70,7 @@ class LLMRequestSubStage(Stage):
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
for handler in handlers:
try:
logger.debug(f"hook(on_llm_request) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}")
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
@@ -82,10 +84,11 @@ class LLMRequestSubStage(Stage):
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
# 执行 LLM 响应后的事件。
# 执行 LLM 响应后的事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMResponseEvent)
for handler in handlers:
try:
logger.debug(f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}")
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
@@ -28,10 +28,8 @@ class StarRequestSubStage(Stage):
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_path not in star_map:
# 孤立无援的 star handler
continue
logger.debug(f"执行插件 handler {handler.handler_full_name}")
logger.debug(f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}")
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
async for ret in wrapper:
yield ret
@@ -73,8 +73,6 @@ class RateLimitStage(Stage):
timestamps.append(now)
return event.continue_event()
def _remove_expired_timestamps(self, timestamps: Deque[datetime], now: datetime) -> None:
"""
移除时间窗口外的时间戳。
+2
View File
@@ -10,6 +10,7 @@ from astrbot.core.message.message_event_result import MessageChain
from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.components import Plain, Reply, At
@register_stage
class RespondStage(Stage):
@@ -90,6 +91,7 @@ class RespondStage(Stage):
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
for handler in handlers:
try:
logger.debug(f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}")
await handler.handler(event)
except BaseException:
logger.error(traceback.format_exc())
@@ -10,6 +10,7 @@ from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
@register_stage
class ResultDecorateStage(Stage):
@@ -47,7 +48,7 @@ class ResultDecorateStage(Stage):
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
if result is None:
if result is None or not result.chain:
return
# 回复时检查内容安全
@@ -63,7 +64,10 @@ class ResultDecorateStage(Stage):
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
for handler in handlers:
try:
logger.debug(f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}")
await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
logger.debug(f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。")
except BaseException:
logger.error(traceback.format_exc())
@@ -1,7 +1,7 @@
import asyncio
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image, Record, At, Node, Music, Video
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
@@ -45,15 +45,25 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
send_one_by_one = False
for seg in message.chain:
if isinstance(seg, (Node, Music)):
if isinstance(seg, (Node, Nodes)):
# 转发消息不能和普通消息混在一起发送
send_one_by_one = True
break
if send_one_by_one:
for seg in message.chain:
await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg])))
await asyncio.sleep(0.5)
if isinstance(seg, Nodes):
# 带有多个节点的合并转发消息
payload = seg.toDict()
if self.get_group_id():
payload['group_id'] = self.get_group_id()
await self.bot.call_action('send_group_forward_msg', **payload)
else:
payload['user_id'] = self.get_sender_id()
await self.bot.call_action('send_private_forward_msg', **payload)
else:
await self.bot.send(self.message_obj.raw_message, await AiocqhttpMessageEvent._parse_onebot_json(MessageChain([seg])))
await asyncio.sleep(0.5)
else:
await self.bot.send(self.message_obj.raw_message, ret)
@@ -16,7 +16,7 @@ from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
from astrbot.core.utils.io import download_file
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
@register_platform_adapter("aiocqhttp", "适用于 OneBot V11 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
@@ -59,7 +59,6 @@ def register_command(command_name: str = None, sub_command: str = None, alias: s
if isinstance(command_name, RegisteringCommandable):
# 子指令
parent_command_names = command_name.parent_group.get_complete_command_names()
logger.debug(f"parent_command_names: {parent_command_names}")
new_command = CommandFilter(sub_command, alias, None, parent_command_names=parent_command_names)
command_name.parent_group.add_sub_command_filter(new_command)
else:
-4
View File
@@ -968,10 +968,6 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
# def kdb(self):
# pass
# @kdb.command("off")
# async def off_kdb(self, event: AstrMessageEvent):
# self.kdb_enabled = False
# yield event.plain_result("知识库已关闭")
# @filter.on_llm_request()
# async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):