fix: 修复面板保存配置时报错的问题;修复频道私聊报错的问题

perf: 改善日志
This commit is contained in:
Soulter
2024-02-05 13:18:34 +08:00
parent d522d2a6a9
commit 2cf18972f3
9 changed files with 72 additions and 35 deletions
+7 -3
View File
@@ -9,6 +9,7 @@ import sys
import os
import threading
import time
import asyncio
def shutdown_bot(delay_s: int):
@@ -28,6 +29,9 @@ class DashBoardConfig():
class DashBoardHelper():
def __init__(self, global_object, config: dict):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.logger = global_object.logger
dashboard_data = global_object.dashboard_data
dashboard_data.configs = {
"data": []
@@ -41,13 +45,13 @@ class DashBoardHelper():
@self.dashboard.register("post_configs")
def on_post_configs(post_configs: dict):
try:
gu.log(f"收到配置更新请求", gu.LEVEL_INFO, tag="可视化面板")
self.logger.log(f"收到配置更新请求", gu.LEVEL_INFO, tag="可视化面板")
self.save_config(post_configs)
self.parse_default_config(self.dashboard_data, self.cc.get_all())
# 重启
threading.Thread(target=shutdown_bot, args=(2,), daemon=True).start()
except Exception as e:
gu.log(f"在保存配置时发生错误:{e}", gu.LEVEL_ERROR, tag="可视化面板")
self.logger.log(f"在保存配置时发生错误:{e}", gu.LEVEL_ERROR, tag="可视化面板")
raise e
@@ -524,7 +528,7 @@ class DashBoardHelper():
]
except Exception as e:
gu.log(f"配置文件解析错误:{e}", gu.LEVEL_ERROR)
self.logger.log(f"配置文件解析错误:{e}", gu.LEVEL_ERROR)
raise e
+1 -1
View File
@@ -29,7 +29,7 @@ class Response():
class AstrBotDashBoard():
def __init__(self, global_object):
self.loop = asyncio.new_event_loop()
self.loop = asyncio.get_event_loop()
asyncio.set_event_loop(self.loop)
self.dashboard_data = global_object.dashboard_data
self.dashboard_be = Flask(__name__, static_folder="dist", static_url_path="/")
+4 -8
View File
@@ -294,9 +294,6 @@ def initBot(cfg):
platform_str = "(未启动任何平台,请前往面板添加)"
logger.log(f"🎉 项目启动完成\n - 启动的LLM: {len(llm_instance)}\n - 启动的平台: {platform_str}\n - 启动的插件: {len(_global_object.cached_plugins)}")
if chosen_provider is None:
logger.log("没有启动任何语言模型。", gu.LEVEL_WARNING)
dashboard_thread.join()
async def cli():
@@ -333,7 +330,7 @@ async def cli_pack_message(prompt: str) -> NakuruGuildMessage:
def run_qqchan_bot(cfg: dict, global_object: GlobalObject):
try:
from model.platform.qq_official import QQOfficial
qqchannel_bot = QQOfficial(cfg=cfg, message_handler=oper_msg)
qqchannel_bot = QQOfficial(cfg=cfg, message_handler=oper_msg, global_object=global_object)
global_object.platform_qqchan = qqchannel_bot
qqchannel_bot.run()
except BaseException as e:
@@ -358,7 +355,7 @@ def run_gocq_bot(cfg: dict, _global_object: GlobalObject):
logger.log("检查完毕,未发现问题。", tag="QQ")
break
try:
qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg)
qq_gocq = QQGOCQ(cfg=cfg, message_handler=oper_msg, global_object=_global_object)
_global_object.platform_qq = qq_gocq
qq_gocq.run()
except BaseException as e:
@@ -421,7 +418,6 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
for i in message.message:
if isinstance(i, Plain):
message_str += i.text.strip()
logger.log(message_str, gu.LEVEL_INFO, tag=platform)
if message_str == "":
return MessageResult("Hi~")
@@ -488,8 +484,8 @@ async def oper_msg(message: Union[GroupMessage, FriendMessage, GuildMessage, Nak
check, msg = baidu_judge.judge(message_str)
if not check:
return MessageResult(f"你的提问得到的回复未通过【百度AI内容审核】服务, 不予回复。\n\n{msg}")
if chosen_provider == None:
return MessageResult(f"管理员未启动任何语言模型或者语言模型初始化时失败")
if chosen_provider == NONE_LLM:
return MessageResult("没有启动任何 LLM 并且未触发任何指令")
try:
if llm_wake_prefix != "" and not message_str.startswith(llm_wake_prefix):
return
-1
View File
@@ -5,7 +5,6 @@ try:
import git.exc
from git.repo import Repo
except BaseException as e:
gu.log("你正运行在无Git环境下,暂时将无法使用插件、热更新功能。")
has_git = False
import os
import sys
+33 -1
View File
@@ -1,7 +1,17 @@
import abc
import threading
import asyncio
from typing import Callable
from typing import Callable, Union
from nakuru import (
GuildMessage,
GroupMessage,
FriendMessage,
)
from ._nakuru_translation_layer import (
NakuruGuildMessage,
)
from nakuru.entities.components import Plain, At, Image, Node
class Platform():
def __init__(self, message_handler: callable) -> None:
@@ -38,6 +48,28 @@ class Platform():
发送消息(主动发送)同 send_msg()
'''
pass
def parse_message_outline(self, message: Union[GuildMessage, GroupMessage, FriendMessage, str]) -> NakuruGuildMessage:
'''
将消息解析成大纲消息形式。
如: xxxxx[图片]xxxxx
'''
if isinstance(message, str):
return message
ret = ''
try:
for node in message.message:
if isinstance(node, Plain):
ret += node.text
elif isinstance(node, At):
ret += f'[At: {node.name}/{node.qq}]'
elif isinstance(node, Image):
ret += f'[图片]'
except Exception as e:
pass
ret.replace('\n', '')
return ret
def new_sub_thread(self, func, args=()):
thread = threading.Thread(target=self._runner, args=(func, args), daemon=True)
+4 -3
View File
@@ -24,7 +24,7 @@ class FakeSource:
class QQGOCQ(Platform):
def __init__(self, cfg: dict, message_handler: callable) -> None:
def __init__(self, cfg: dict, message_handler: callable, global_object) -> None:
super().__init__(message_handler)
self.loop = asyncio.new_event_loop()
@@ -34,7 +34,7 @@ class QQGOCQ(Platform):
self.gocq_cnt = 0
self.cc = CmdConfig()
self.cfg = cfg
self.logger = gu.Logger()
self.logger: gu.Logger = global_object.logger
try:
self.nick_qq = cfg['nick_qq']
@@ -107,6 +107,7 @@ class QQGOCQ(Platform):
self.client.run()
async def handle_msg(self, message: Union[GroupMessage, FriendMessage, GuildMessage, Notify], is_group: bool):
self.logger.log(f"{message.user_id} -> {self.parse_message_outline(message)}", tag="QQ_GOCQ")
# 判断是否响应消息
resp = False
if not is_group:
@@ -178,7 +179,7 @@ class QQGOCQ(Platform):
self.gocq_cnt += 1
self.logger.log(f"{source.user_id} <- {res}", tag="GOCQ")
self.logger.log(f"{source.user_id} <- {self.parse_message_outline(res)}", tag="QQ_GOCQ")
if isinstance(source, int):
source = FakeSource("GroupMessage", source)
+20 -17
View File
@@ -17,6 +17,7 @@ from ._nakuru_translation_layer import(
gocq_compatible_receive,
gocq_compatible_send
)
from typing import Union
# QQ 机器人官方框架
class botClient(Client):
@@ -25,24 +26,19 @@ class botClient(Client):
# 收到频道消息
async def on_at_message_create(self, message: Message):
gu.log(str(message), gu.LEVEL_DEBUG, max_len=9999)
# 转换层
nakuru_guild_message = gocq_compatible_receive(message)
gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999)
# await self.platform.handle_msg(nakuru_guild_message, is_group=True)
self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, True))
# 收到私聊消息
async def on_direct_message_create(self, message: DirectMessage):
# 转换层
nakuru_guild_message = gocq_compatible_receive(message)
gu.log(f"转换后: {str(nakuru_guild_message)}", gu.LEVEL_DEBUG, max_len=9999)
# await self.platform.handle_msg(nakuru_guild_message, is_group=False)
self.platform.new_sub_thread(self.platform.handle_msg, (nakuru_guild_message, False))
class QQOfficial(Platform):
def __init__(self, cfg: dict, message_handler: callable) -> None:
def __init__(self, cfg: dict, message_handler: callable, global_object) -> None:
super().__init__(message_handler)
self.loop = asyncio.new_event_loop()
@@ -56,7 +52,7 @@ class QQOfficial(Platform):
self.token = cfg['qqbot']['token']
self.secret = cfg['qqbot_secret']
self.unique_session = cfg['uniqueSessionMode']
self.logger = gu.Logger()
self.logger: gu.Logger = global_object.logger
self.intents = botpy.Intents(
public_guild_messages=True,
@@ -87,10 +83,11 @@ class QQOfficial(Platform):
)
async def handle_msg(self, message: NakuruGuildMessage, is_group: bool):
_t = "/私聊" if not is_group else ""
self.logger.log(f"{message.sender.nickname}({message.sender.tiny_id}{_t}) -> {self.parse_message_outline(message)}", tag="QQ_OFFICIAL")
# 解析出 session_id
if self.unique_session or not is_group:
session_id = message.sender.user_id
session_id = message.sesnder.user_id
else:
session_id = message.channel_id
@@ -112,7 +109,7 @@ class QQOfficial(Platform):
if message_result is None:
return
self.reply_msg(message, message_result.result_message)
self.reply_msg(is_group, message, message_result.result_message)
if message_result.callback is not None:
message_result.callback()
@@ -121,12 +118,13 @@ class QQOfficial(Platform):
self.waiting[session_id] = message
def reply_msg(self,
message: NakuruGuildMessage,
res: list):
is_group: bool,
message: NakuruGuildMessage,
res: Union[str, list]):
'''
回复频道消息
'''
self.logger.log(f"{message.sender.nickname}({message.sender.tiny_id}) <- {res}", tag="QQ频道")
self.logger.log(f"{message.sender.nickname}({message.sender.tiny_id}) <- {self.parse_message_outline(res)}", tag="QQ_OFFICIAL")
self.qqchan_cnt += 1
plain_text = ''
@@ -162,13 +160,15 @@ class QQOfficial(Platform):
msg_ref = Reference(message_id=message.raw_message.id, ignore_get_message_error=False)
# 到这里,我们得到了 plain_textimage_pathmsg_ref
data = {
'channel_id': str(message.channel_id),
'content': plain_text,
'msg_id': message.message_id,
'message_reference': msg_ref
}
if is_group:
data['channel_id'] = message.channel_id
else:
data['guild_id'] = message.guild_id
if image_path != '':
data['file_image'] = image_path
@@ -207,8 +207,11 @@ class QQOfficial(Platform):
self._send_wrapper(**data)
def _send_wrapper(self, **kwargs):
# await self.client.api.post_message(**kwargs)
asyncio.run_coroutine_threadsafe(self.client.api.post_message(**kwargs), self.loop).result()
if 'channel_id' in kwargs:
asyncio.run_coroutine_threadsafe(self.client.api.post_message(**kwargs), self.loop).result()
else:
asyncio.run_coroutine_threadsafe(self.client.api.post_dms(**kwargs), self.loop).result()
def send_msg(self, channel_id: int, message_chain: list, message_id: int = None):
'''
+1
View File
@@ -156,6 +156,7 @@ def web_keyword_search_via_sougou(keyword) -> str:
if len(res) >= 5: # 限制5条
break
except Exception as e:
pass
gu.log(f"sougou parse err: {str(e)}", tag="web_keyword_search_via_sougou", level=gu.LEVEL_ERROR)
# 爬取网页内容
_detail_store = []
+2 -1
View File
@@ -124,7 +124,7 @@ class Logger:
for line in pres:
ret += f"\033[{fg};{bg}m{line}\033[0m\n"
try:
requests.post("http://localhost:6185/api/log", data=ret[:-1].encode())
requests.post("http://localhost:6185/api/log", data=ret[:-1].encode(), timeout=1)
except BaseException as e:
pass
self.history.append(ret)
@@ -132,6 +132,7 @@ class Logger:
self.history = self.history[-100:]
print(ret[:-1])
log = Logger()
def port_checker(port: int, host: str = "localhost"):
sk = socket.socket(socket.AF_INET,socket.SOCK_STREAM)