From 4d8d9ecfc2a8a50289f31597458066139423bd77 Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Thu, 28 Nov 2024 21:39:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=20=E6=8E=A5=E5=85=A5=E7=BB=BF=E6=B3=A1?= =?UTF-8?q?=E6=B3=A1=E6=B6=88=E6=81=AF=E5=B9=B3=E5=8F=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/config/astrbot_config.py | 24 +++- astrbot/core/config/default.py | 7 ++ astrbot/core/message_event_handler.py | 5 +- astrbot/core/platform/message_type.py | 2 +- packages/astrbot_adapter_wechat/main.py | 18 +++ packages/astrbot_adapter_wechat/metadata.yaml | 6 + .../wechat_message_event.py | 38 ++++++ .../wechat_platform_adapter.py | 112 ++++++++++++++++++ packages/astrbot_plugin_openai/main.py | 2 +- requirements.txt | 1 + 10 files changed, 208 insertions(+), 7 deletions(-) create mode 100644 packages/astrbot_adapter_wechat/main.py create mode 100644 packages/astrbot_adapter_wechat/metadata.yaml create mode 100644 packages/astrbot_adapter_wechat/wechat_message_event.py create mode 100644 packages/astrbot_adapter_wechat/wechat_platform_adapter.py diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index e6111bd76..23699c9b7 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -44,6 +44,10 @@ class AiocqhttpPlatformConfig(PlatformConfig): qq_id_whitelist: List[str] = field(default_factory=list) qq_group_id_whitelist: List[str] = field(default_factory=list) +@dataclass +class WechatPlatformConfig(PlatformConfig): + wechat_id_whitelist: List[str] = field(default_factory=list) + @dataclass class ModelConfig: model: str = "gpt-4o" @@ -147,14 +151,30 @@ class AstrBotConfig(): ''' self.config_version=data.get("version", 2) self.platform=[] + + left_platforms = ["qq_official", "aiocqhttp", "wechat"] for p in data.get("platform", []): if 'name' not in p: logger.warning("A platform config missing name, skipping.") continue if p["name"] == "qq_official": self.platform.append(QQOfficialPlatformConfig(**p)) + left_platforms.remove(p["name"]) elif p["name"] == "aiocqhttp": self.platform.append(AiocqhttpPlatformConfig(**p)) + left_platforms.remove(p["name"]) + elif p["name"] == "wechat": + self.platform.append(WechatPlatformConfig(**p)) + left_platforms.remove(p["name"]) + # 注入默认配置 + for p in left_platforms: + if p == "qq_official": + self.platform.append(QQOfficialPlatformConfig(id="default", name=p)) + elif p == "aiocqhttp": + self.platform.append(AiocqhttpPlatformConfig(id="default", name=p)) + elif p == "wechat": + self.platform.append(WechatPlatformConfig(id="default", name=p)) + self.platform_settings=PlatformSettings(**data.get("platform_settings", {})) self.llm=[LLMConfig(**l) for l in data.get("llm", [])] self.llm_settings=LLMSettings(**data.get("llm_settings", {})) @@ -190,10 +210,6 @@ class AstrBotConfig(): config = DEFAULT_CONFIG_VERSION_2 else: config = self.get_all() - # check if the config is outdated - if 'config_version' not in config: # version 1 - config = self.migrate_config_1_2(config) - self.flush_config(config) # 加载配置到对象 self.load_from_dict(config) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 9f0f747a3..6e8ac6052 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -22,6 +22,12 @@ DEFAULT_CONFIG_VERSION_2 = { "ws_reverse_port": 6199, "qq_id_whitelist": [], "qq_group_id_whitelist": [] + }, + { + "id": "default", + "name": "wechat", + "enable": False, + "wechat_id_whitelist": [] } ], "platform_settings": { @@ -105,6 +111,7 @@ CONFIG_METADATA_2 = { "ws_reverse_port": {"description": "反向 Websocket 端口", "type": "int", "hint": "aiocqhttp 适配器的反向 Websocket 端口。"}, "qq_id_whitelist": {"description": "QQ 号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 号发来的消息事件。为空时表示不启用白名单过滤。"}, "qq_group_id_whitelist": {"description": "QQ 群号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 群发来的消息事件。为空时表示不启用白名单过滤。"}, + "wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。"}, } }, "platform_settings": { diff --git a/astrbot/core/message_event_handler.py b/astrbot/core/message_event_handler.py index 75e7e2cb1..f1323b4df 100644 --- a/astrbot/core/message_event_handler.py +++ b/astrbot/core/message_event_handler.py @@ -1,4 +1,4 @@ -import asyncio, re +import asyncio, re, time import inspect import traceback from typing import List, Union @@ -137,7 +137,10 @@ class MessageEventHandler(): else: break if plain_str and len(plain_str) > 150: + render_start = time.time() url = await html_renderer.render_t2i(plain_str, return_url=True) + if time.time() - render_start > 3: + logger.warning(f"图片转文本耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。") if url: result.chain = [Image.fromURL(url)] diff --git a/astrbot/core/platform/message_type.py b/astrbot/core/platform/message_type.py index b1c7721e5..6149277d9 100644 --- a/astrbot/core/platform/message_type.py +++ b/astrbot/core/platform/message_type.py @@ -3,4 +3,4 @@ from enum import Enum class MessageType(Enum): GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息 FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息 - \ No newline at end of file + OTHER_MESSAGE = 'OtherMessage' # 其他类型的消息,如系统消息等 \ No newline at end of file diff --git a/packages/astrbot_adapter_wechat/main.py b/packages/astrbot_adapter_wechat/main.py new file mode 100644 index 000000000..b6eaf25e2 --- /dev/null +++ b/packages/astrbot_adapter_wechat/main.py @@ -0,0 +1,18 @@ +from astrbot.api import Context, AstrMessageEvent, MessageEventResult +from .wechat_platform_adapter import WechatPlatformAdapter +from astrbot.api import logger + +class Main: + def __init__(self, context: Context) -> None: + self.context = context + platforms_config = context.get_config().platform + settings = context.get_config().platform_settings + for platform in platforms_config: + if platform.name == "wechat" and platform.enable: + self.context.register_platform(WechatPlatformAdapter(platform, settings, context.get_event_queue())) + logger.info(f"已注册 wechat({platform.id}) 消息适配器。") + + self.context.register_commands("astrbot_adapter_wechat", "wechatid", "查看微信ID", 1, self.get_wechat_id) + + async def get_wechat_id(self, event: AstrMessageEvent): + event.set_result(MessageEventResult().message("这个会话的微信ID是" + event.message_obj.raw_message.from_.username)) \ No newline at end of file diff --git a/packages/astrbot_adapter_wechat/metadata.yaml b/packages/astrbot_adapter_wechat/metadata.yaml new file mode 100644 index 000000000..16c8db775 --- /dev/null +++ b/packages/astrbot_adapter_wechat/metadata.yaml @@ -0,0 +1,6 @@ +name: astrbot_adapter_wechat # 插件名称 +desc: 支持 Wechat(UOS) 的消息平台适配器 +help: +version: v1.0.0 # 插件版本号。格式:v1.1.1 或者 v1.1 +author: Soulter # 作者 +repo: https://github.com/Soulter/AstrBot \ No newline at end of file diff --git a/packages/astrbot_adapter_wechat/wechat_message_event.py b/packages/astrbot_adapter_wechat/wechat_message_event.py new file mode 100644 index 000000000..91cbf0180 --- /dev/null +++ b/packages/astrbot_adapter_wechat/wechat_message_event.py @@ -0,0 +1,38 @@ +import random, asyncio +from astrbot.core.utils.io import download_image_by_url +from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata +from astrbot.api import Plain, Image +from vchat import Core + +class WechatPlatformEvent(AstrMessageEvent): + def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: Core): + super().__init__(message_str, message_obj, platform_meta, session_id) + self.client = client + + @staticmethod + async def send_with_client(client: Core, message: MessageChain, user_name: str): + plain = "" + for comp in message.chain: + if isinstance(comp, Plain): + plain += comp.text + elif isinstance(comp, Image): + if comp.file and comp.file.startswith("file:///"): + file_path = comp.file.replace("file:///", "") + with open(file_path, "rb") as f: + await client.send_image(user_name, fd=f) + elif comp.file and comp.file.startswith("http"): + image_path = await download_image_by_url(comp.file) + with open(image_path, "rb") as f: + await client.send_image(user_name, fd=f) + else: + logger.error(f"不支持的 vchat(微信适配器) 消息类型: {comp}") + await asyncio.sleep(random.uniform(0.5, 1.5)) # 🤓 + + if plain: + await client.send_msg(plain, user_name) + + + async def send(self, message: MessageChain): + await WechatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username) + await super().send(message) + \ No newline at end of file diff --git a/packages/astrbot_adapter_wechat/wechat_platform_adapter.py b/packages/astrbot_adapter_wechat/wechat_platform_adapter.py new file mode 100644 index 000000000..8d08b99b7 --- /dev/null +++ b/packages/astrbot_adapter_wechat/wechat_platform_adapter.py @@ -0,0 +1,112 @@ +import sys, time, datetime, uuid +import asyncio + +from astrbot.api import Platform +from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata +from typing import Union, List, Dict +from nakuru.entities.components import * +from astrbot.api import logger +from astrbot.core.platform.astr_message_event import MessageSesion +from .wechat_message_event import WechatPlatformEvent +from astrbot.core.config.astrbot_config import PlatformConfig, WechatPlatformConfig, PlatformSettings +from astrbot.core.utils.io import save_temp_img, download_image_by_url + +from vchat import Core +from vchat import model + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + +class WechatPlatformAdapter(Platform): + + def __init__(self, platform_config: WechatPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None: + super().__init__(event_queue) + self.config = platform_config + self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on' + self.client_self_id = uuid.uuid4().hex[:8] + + @override + async def send_by_session(self, session: MessageSesion, message_chain: MessageChain): + from_username = session.session_id.split('$$')[0] + await WechatPlatformEvent.send_with_client(self.client, message_chain, from_username) + await super().send_by_session(session, message_chain) + + @override + def meta(self) -> PlatformMetadata: + return PlatformMetadata( + "wechat", + "基于 VChat 的 Wechat 适配器", + ) + + @override + def run(self): + self.client = Core() + @self.client.msg_register(msg_types=model.ContentTypes.TEXT, + contact_type=model.ContactTypes.CHATROOM | model.ContactTypes.USER) + async def _(msg: model.Message): + if isinstance(msg.content, model.UselessContent): + return + if msg.create_time < self.start_time: + logger.debug(f"忽略旧消息: {msg}") + return + if self.config.wechat_id_whitelist and msg.from_.username not in self.config.wechat_id_whitelist: + logger.debug(f"忽略不在白名单的微信消息。username: {msg.from_.username}") + return + logger.info(f"收到消息: {msg.todict()}") + abmsg = self.convert_message(msg) + # await self.handle_msg(abmsg) # 不能直接调用,否则会阻塞 + asyncio.create_task(self.handle_msg(abmsg)) + + # TODO: 对齐微信服务器时间 + self.start_time = int(time.time()) + return self._run() + + + async def _run(self): + await self.client.init() + await self.client.auto_login(hot_reload=True) + await self.client.run() + + def convert_message(self, msg: model.Message) -> AstrBotMessage: + # credits: https://github.com/z2z63/astrbot_plugin_vchat/blob/master/main.py#L49 + assert isinstance(msg.content, model.TextContent) + amsg = AstrBotMessage() + amsg.message = [Plain(msg.content.content)] + amsg.self_id = self.client_self_id + if msg.content.is_at_me: + amsg.message.insert(0, At(qq=amsg.self_id)) + + sender = msg.chatroom_sender or msg.from_ + amsg.sender = MessageMember(sender.username, sender.nickname) + amsg.message_str = msg.content.content + amsg.message_id = msg.message_id + if isinstance(msg.from_, model.User): + amsg.type = MessageType.FRIEND_MESSAGE + elif isinstance(msg.from_, model.Chatroom): + amsg.type = MessageType.GROUP_MESSAGE + else: + logger.error(f"不支持的 Wechat 消息类型: {msg.from_}") + + amsg.raw_message = msg + + session_id = msg.from_.username + "$$" + msg.to.username + if msg.chatroom_sender is not None: + session_id += '$$' + msg.chatroom_sender.username + + amsg.session_id = session_id + return amsg + + async def handle_msg(self, message: AstrBotMessage): + message_event = WechatPlatformEvent( + message_str=message.message_str, + message_obj=message, + platform_meta=self.meta(), + session_id=message.session_id, + client=self.client + ) + + logger.info(f"处理消息: {message_event}") + + self.commit_event(message_event) \ No newline at end of file diff --git a/packages/astrbot_plugin_openai/main.py b/packages/astrbot_plugin_openai/main.py index 5dd245992..5a36b9777 100644 --- a/packages/astrbot_plugin_openai/main.py +++ b/packages/astrbot_plugin_openai/main.py @@ -61,7 +61,7 @@ class Main: fetch_website_content ) - async def remove_web_search_tools(self): + def remove_web_search_tools(self): self.context.unregister_llm_tool("web_search") self.context.unregister_llm_tool("fetch_website_content") diff --git a/requirements.txt b/requirements.txt index 3fce570d0..2703d5140 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ pydantic~=1.10.4 +vchat aiohttp openai qq-botpy