feat: 接入绿泡泡消息平台

This commit is contained in:
Soulter
2024-11-28 21:39:35 +08:00
parent f2efa022b4
commit 4d8d9ecfc2
10 changed files with 208 additions and 7 deletions
+20 -4
View File
@@ -44,6 +44,10 @@ class AiocqhttpPlatformConfig(PlatformConfig):
qq_id_whitelist: List[str] = field(default_factory=list)
qq_group_id_whitelist: List[str] = field(default_factory=list)
@dataclass
class WechatPlatformConfig(PlatformConfig):
wechat_id_whitelist: List[str] = field(default_factory=list)
@dataclass
class ModelConfig:
model: str = "gpt-4o"
@@ -147,14 +151,30 @@ class AstrBotConfig():
'''
self.config_version=data.get("version", 2)
self.platform=[]
left_platforms = ["qq_official", "aiocqhttp", "wechat"]
for p in data.get("platform", []):
if 'name' not in p:
logger.warning("A platform config missing name, skipping.")
continue
if p["name"] == "qq_official":
self.platform.append(QQOfficialPlatformConfig(**p))
left_platforms.remove(p["name"])
elif p["name"] == "aiocqhttp":
self.platform.append(AiocqhttpPlatformConfig(**p))
left_platforms.remove(p["name"])
elif p["name"] == "wechat":
self.platform.append(WechatPlatformConfig(**p))
left_platforms.remove(p["name"])
# 注入默认配置
for p in left_platforms:
if p == "qq_official":
self.platform.append(QQOfficialPlatformConfig(id="default", name=p))
elif p == "aiocqhttp":
self.platform.append(AiocqhttpPlatformConfig(id="default", name=p))
elif p == "wechat":
self.platform.append(WechatPlatformConfig(id="default", name=p))
self.platform_settings=PlatformSettings(**data.get("platform_settings", {}))
self.llm=[LLMConfig(**l) for l in data.get("llm", [])]
self.llm_settings=LLMSettings(**data.get("llm_settings", {}))
@@ -190,10 +210,6 @@ class AstrBotConfig():
config = DEFAULT_CONFIG_VERSION_2
else:
config = self.get_all()
# check if the config is outdated
if 'config_version' not in config: # version 1
config = self.migrate_config_1_2(config)
self.flush_config(config)
# 加载配置到对象
self.load_from_dict(config)
+7
View File
@@ -22,6 +22,12 @@ DEFAULT_CONFIG_VERSION_2 = {
"ws_reverse_port": 6199,
"qq_id_whitelist": [],
"qq_group_id_whitelist": []
},
{
"id": "default",
"name": "wechat",
"enable": False,
"wechat_id_whitelist": []
}
],
"platform_settings": {
@@ -105,6 +111,7 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": {"description": "反向 Websocket 端口", "type": "int", "hint": "aiocqhttp 适配器的反向 Websocket 端口。"},
"qq_id_whitelist": {"description": "QQ 号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 号发来的消息事件。为空时表示不启用白名单过滤。"},
"qq_group_id_whitelist": {"description": "QQ 群号白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的 QQ 群发来的消息事件。为空时表示不启用白名单过滤。"},
"wechat_id_whitelist": {"description": "微信私聊/群聊白名单", "type": "list", "items": {"type": "string"}, "hint": "填写后,将只处理所填写的微信私聊/群聊发来的消息事件。为空时表示不启用白名单过滤。使用 /wechatid 指令获取微信 ID(不是微信号)。"},
}
},
"platform_settings": {
+4 -1
View File
@@ -1,4 +1,4 @@
import asyncio, re
import asyncio, re, time
import inspect
import traceback
from typing import List, Union
@@ -137,7 +137,10 @@ class MessageEventHandler():
else:
break
if plain_str and len(plain_str) > 150:
render_start = time.time()
url = await html_renderer.render_t2i(plain_str, return_url=True)
if time.time() - render_start > 3:
logger.warning(f"图片转文本耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
if url:
result.chain = [Image.fromURL(url)]
+1 -1
View File
@@ -3,4 +3,4 @@ from enum import Enum
class MessageType(Enum):
GROUP_MESSAGE = 'GroupMessage' # 群组形式的消息
FRIEND_MESSAGE = 'FriendMessage' # 私聊、好友等单聊消息
OTHER_MESSAGE = 'OtherMessage' # 其他类型的消息,如系统消息等
+18
View File
@@ -0,0 +1,18 @@
from astrbot.api import Context, AstrMessageEvent, MessageEventResult
from .wechat_platform_adapter import WechatPlatformAdapter
from astrbot.api import logger
class Main:
def __init__(self, context: Context) -> None:
self.context = context
platforms_config = context.get_config().platform
settings = context.get_config().platform_settings
for platform in platforms_config:
if platform.name == "wechat" and platform.enable:
self.context.register_platform(WechatPlatformAdapter(platform, settings, context.get_event_queue()))
logger.info(f"已注册 wechat({platform.id}) 消息适配器。")
self.context.register_commands("astrbot_adapter_wechat", "wechatid", "查看微信ID", 1, self.get_wechat_id)
async def get_wechat_id(self, event: AstrMessageEvent):
event.set_result(MessageEventResult().message("这个会话的微信ID是" + event.message_obj.raw_message.from_.username))
@@ -0,0 +1,6 @@
name: astrbot_adapter_wechat # 插件名称
desc: 支持 Wechat(UOS) 的消息平台适配器
help:
version: v1.0.0 # 插件版本号。格式:v1.1.1 或者 v1.1
author: Soulter # 作者
repo: https://github.com/Soulter/AstrBot
@@ -0,0 +1,38 @@
import random, asyncio
from astrbot.core.utils.io import download_image_by_url
from astrbot.api import AstrMessageEvent, MessageChain, logger, AstrBotMessage, PlatformMetadata
from astrbot.api import Plain, Image
from vchat import Core
class WechatPlatformEvent(AstrMessageEvent):
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, client: Core):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(client: Core, message: MessageChain, user_name: str):
plain = ""
for comp in message.chain:
if isinstance(comp, Plain):
plain += comp.text
elif isinstance(comp, Image):
if comp.file and comp.file.startswith("file:///"):
file_path = comp.file.replace("file:///", "")
with open(file_path, "rb") as f:
await client.send_image(user_name, fd=f)
elif comp.file and comp.file.startswith("http"):
image_path = await download_image_by_url(comp.file)
with open(image_path, "rb") as f:
await client.send_image(user_name, fd=f)
else:
logger.error(f"不支持的 vchat(微信适配器) 消息类型: {comp}")
await asyncio.sleep(random.uniform(0.5, 1.5)) # 🤓
if plain:
await client.send_msg(plain, user_name)
async def send(self, message: MessageChain):
await WechatPlatformEvent.send_with_client(self.client, message, self.message_obj.raw_message.from_.username)
await super().send(message)
@@ -0,0 +1,112 @@
import sys, time, datetime, uuid
import asyncio
from astrbot.api import Platform
from astrbot.api import MessageChain, MessageEventResult, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from typing import Union, List, Dict
from nakuru.entities.components import *
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from .wechat_message_event import WechatPlatformEvent
from astrbot.core.config.astrbot_config import PlatformConfig, WechatPlatformConfig, PlatformSettings
from astrbot.core.utils.io import save_temp_img, download_image_by_url
from vchat import Core
from vchat import model
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class WechatPlatformAdapter(Platform):
def __init__(self, platform_config: WechatPlatformConfig, platform_settings: PlatformSettings, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
self.client_self_id = uuid.uuid4().hex[:8]
@override
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
from_username = session.session_id.split('$$')[0]
await WechatPlatformEvent.send_with_client(self.client, message_chain, from_username)
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"wechat",
"基于 VChat 的 Wechat 适配器",
)
@override
def run(self):
self.client = Core()
@self.client.msg_register(msg_types=model.ContentTypes.TEXT,
contact_type=model.ContactTypes.CHATROOM | model.ContactTypes.USER)
async def _(msg: model.Message):
if isinstance(msg.content, model.UselessContent):
return
if msg.create_time < self.start_time:
logger.debug(f"忽略旧消息: {msg}")
return
if self.config.wechat_id_whitelist and msg.from_.username not in self.config.wechat_id_whitelist:
logger.debug(f"忽略不在白名单的微信消息。username: {msg.from_.username}")
return
logger.info(f"收到消息: {msg.todict()}")
abmsg = self.convert_message(msg)
# await self.handle_msg(abmsg) # 不能直接调用,否则会阻塞
asyncio.create_task(self.handle_msg(abmsg))
# TODO: 对齐微信服务器时间
self.start_time = int(time.time())
return self._run()
async def _run(self):
await self.client.init()
await self.client.auto_login(hot_reload=True)
await self.client.run()
def convert_message(self, msg: model.Message) -> AstrBotMessage:
# credits: https://github.com/z2z63/astrbot_plugin_vchat/blob/master/main.py#L49
assert isinstance(msg.content, model.TextContent)
amsg = AstrBotMessage()
amsg.message = [Plain(msg.content.content)]
amsg.self_id = self.client_self_id
if msg.content.is_at_me:
amsg.message.insert(0, At(qq=amsg.self_id))
sender = msg.chatroom_sender or msg.from_
amsg.sender = MessageMember(sender.username, sender.nickname)
amsg.message_str = msg.content.content
amsg.message_id = msg.message_id
if isinstance(msg.from_, model.User):
amsg.type = MessageType.FRIEND_MESSAGE
elif isinstance(msg.from_, model.Chatroom):
amsg.type = MessageType.GROUP_MESSAGE
else:
logger.error(f"不支持的 Wechat 消息类型: {msg.from_}")
amsg.raw_message = msg
session_id = msg.from_.username + "$$" + msg.to.username
if msg.chatroom_sender is not None:
session_id += '$$' + msg.chatroom_sender.username
amsg.session_id = session_id
return amsg
async def handle_msg(self, message: AstrBotMessage):
message_event = WechatPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client
)
logger.info(f"处理消息: {message_event}")
self.commit_event(message_event)
+1 -1
View File
@@ -61,7 +61,7 @@ class Main:
fetch_website_content
)
async def remove_web_search_tools(self):
def remove_web_search_tools(self):
self.context.unregister_llm_tool("web_search")
self.context.unregister_llm_tool("fetch_website_content")
+1
View File
@@ -1,4 +1,5 @@
pydantic~=1.10.4
vchat
aiohttp
openai
qq-botpy